forked from M-Labs/nac3
1
0
Fork 0

core/codegen: refactor gen_{for,comprehension} to match on iter type

This commit is contained in:
lyken 2024-08-01 12:25:10 +08:00 committed by sb10q
parent 669c6aca6b
commit 894083c6a3
3 changed files with 235 additions and 191 deletions

View File

@ -995,8 +995,10 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
ctx.builder.position_at_end(init_bb); ctx.builder.position_at_end(init_bb);
let Comprehension { target, iter, ifs, .. } = &generators[0]; let Comprehension { target, iter, ifs, .. } = &generators[0];
let iter_ty = iter.custom.unwrap();
let iter_val = if let Some(v) = generator.gen_expr(ctx, iter)? { let iter_val = if let Some(v) = generator.gen_expr(ctx, iter)? {
v.to_basic_value_enum(ctx, generator, iter.custom.unwrap())? v.to_basic_value_enum(ctx, generator, iter_ty)?
} else { } else {
for bb in [test_bb, body_bb, cont_bb] { for bb in [test_bb, body_bb, cont_bb] {
ctx.builder.position_at_end(bb); ctx.builder.position_at_end(bb);
@ -1014,96 +1016,120 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
ctx.builder.build_store(index, zero_size_t).unwrap(); ctx.builder.build_store(index, zero_size_t).unwrap();
let elem_ty = ctx.get_llvm_type(generator, elt.custom.unwrap()); let elem_ty = ctx.get_llvm_type(generator, elt.custom.unwrap());
let is_range = ctx.unifier.unioned(iter.custom.unwrap(), ctx.primitives.range);
let list; let list;
if is_range { match &*ctx.unifier.get_ty(iter_ty) {
let iter_val = RangeValue::from_ptr_val(iter_val.into_pointer_value(), Some("range")); TypeEnum::TObj { obj_id, .. }
let (start, stop, step) = destructure_range(ctx, iter_val); if *obj_id == ctx.primitives.range.obj_id(&ctx.unifier).unwrap() =>
let diff = ctx.builder.build_int_sub(stop, start, "diff").unwrap(); {
// add 1 to the length as the value is rounded to zero let iter_val = RangeValue::from_ptr_val(iter_val.into_pointer_value(), Some("range"));
// the length may be 1 more than the actual length if the division is exact, but the let (start, stop, step) = destructure_range(ctx, iter_val);
// length is a upper bound only anyway so it does not matter. let diff = ctx.builder.build_int_sub(stop, start, "diff").unwrap();
let length = ctx.builder.build_int_signed_div(diff, step, "div").unwrap(); // add 1 to the length as the value is rounded to zero
let length = ctx.builder.build_int_add(length, int32.const_int(1, false), "add1").unwrap(); // the length may be 1 more than the actual length if the division is exact, but the
// in case length is non-positive // length is a upper bound only anyway so it does not matter.
let is_valid = let length = ctx.builder.build_int_signed_div(diff, step, "div").unwrap();
ctx.builder.build_int_compare(IntPredicate::SGT, length, zero_32, "check").unwrap(); let length =
ctx.builder.build_int_add(length, int32.const_int(1, false), "add1").unwrap();
// in case length is non-positive
let is_valid =
ctx.builder.build_int_compare(IntPredicate::SGT, length, zero_32, "check").unwrap();
let list_alloc_size = ctx let list_alloc_size = ctx
.builder .builder
.build_select( .build_select(
is_valid, is_valid,
ctx.builder.build_int_z_extend_or_bit_cast(length, size_t, "z_ext_len").unwrap(), ctx.builder
zero_size_t, .build_int_z_extend_or_bit_cast(length, size_t, "z_ext_len")
"listcomp.alloc_size", .unwrap(),
) zero_size_t,
.unwrap(); "listcomp.alloc_size",
list = allocate_list( )
generator, .unwrap();
ctx, list = allocate_list(
Some(elem_ty), generator,
list_alloc_size.into_int_value(), ctx,
Some("listcomp.addr"), Some(elem_ty),
); list_alloc_size.into_int_value(),
Some("listcomp.addr"),
);
let i = generator.gen_store_target(ctx, target, Some("i.addr"))?.unwrap(); let i = generator.gen_store_target(ctx, target, Some("i.addr"))?.unwrap();
ctx.builder ctx.builder
.build_store(i, ctx.builder.build_int_sub(start, step, "start_init").unwrap()) .build_store(i, ctx.builder.build_int_sub(start, step, "start_init").unwrap())
.unwrap(); .unwrap();
ctx.builder ctx.builder
.build_conditional_branch(gen_in_range_check(ctx, start, stop, step), test_bb, cont_bb) .build_conditional_branch(
.unwrap(); gen_in_range_check(ctx, start, stop, step),
test_bb,
cont_bb,
)
.unwrap();
ctx.builder.position_at_end(test_bb); ctx.builder.position_at_end(test_bb);
// add and test // add and test
let tmp = ctx let tmp = ctx
.builder .builder
.build_int_add( .build_int_add(
ctx.builder.build_load(i, "i").map(BasicValueEnum::into_int_value).unwrap(), ctx.builder.build_load(i, "i").map(BasicValueEnum::into_int_value).unwrap(),
step, step,
"start_loop", "start_loop",
) )
.unwrap(); .unwrap();
ctx.builder.build_store(i, tmp).unwrap(); ctx.builder.build_store(i, tmp).unwrap();
ctx.builder ctx.builder
.build_conditional_branch(gen_in_range_check(ctx, tmp, stop, step), body_bb, cont_bb) .build_conditional_branch(
.unwrap(); gen_in_range_check(ctx, tmp, stop, step),
body_bb,
cont_bb,
)
.unwrap();
ctx.builder.position_at_end(body_bb); ctx.builder.position_at_end(body_bb);
} else { }
let length = ctx TypeEnum::TObj { obj_id, .. }
.build_gep_and_load( if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
iter_val.into_pointer_value(), {
&[zero_size_t, int32.const_int(1, false)], let length = ctx
Some("length"), .build_gep_and_load(
) iter_val.into_pointer_value(),
.into_int_value(); &[zero_size_t, int32.const_int(1, false)],
list = allocate_list(generator, ctx, Some(elem_ty), length, Some("listcomp")); Some("length"),
)
.into_int_value();
list = allocate_list(generator, ctx, Some(elem_ty), length, Some("listcomp"));
let counter = generator.gen_var_alloc(ctx, size_t.into(), Some("counter.addr"))?; let counter = generator.gen_var_alloc(ctx, size_t.into(), Some("counter.addr"))?;
// counter = -1 // counter = -1
ctx.builder.build_store(counter, size_t.const_all_ones()).unwrap(); ctx.builder.build_store(counter, size_t.const_all_ones()).unwrap();
ctx.builder.build_unconditional_branch(test_bb).unwrap(); ctx.builder.build_unconditional_branch(test_bb).unwrap();
ctx.builder.position_at_end(test_bb); ctx.builder.position_at_end(test_bb);
let tmp = ctx.builder.build_load(counter, "i").map(BasicValueEnum::into_int_value).unwrap(); let tmp =
let tmp = ctx.builder.build_int_add(tmp, size_t.const_int(1, false), "inc").unwrap(); ctx.builder.build_load(counter, "i").map(BasicValueEnum::into_int_value).unwrap();
ctx.builder.build_store(counter, tmp).unwrap(); let tmp = ctx.builder.build_int_add(tmp, size_t.const_int(1, false), "inc").unwrap();
let cmp = ctx.builder.build_int_compare(IntPredicate::SLT, tmp, length, "cmp").unwrap(); ctx.builder.build_store(counter, tmp).unwrap();
ctx.builder.build_conditional_branch(cmp, body_bb, cont_bb).unwrap(); let cmp = ctx.builder.build_int_compare(IntPredicate::SLT, tmp, length, "cmp").unwrap();
ctx.builder.build_conditional_branch(cmp, body_bb, cont_bb).unwrap();
ctx.builder.position_at_end(body_bb); ctx.builder.position_at_end(body_bb);
let arr_ptr = ctx let arr_ptr = ctx
.build_gep_and_load( .build_gep_and_load(
iter_val.into_pointer_value(), iter_val.into_pointer_value(),
&[zero_size_t, zero_32], &[zero_size_t, zero_32],
Some("arr.addr"), Some("arr.addr"),
) )
.into_pointer_value(); .into_pointer_value();
let val = ctx.build_gep_and_load(arr_ptr, &[tmp], Some("val")); let val = ctx.build_gep_and_load(arr_ptr, &[tmp], Some("val"));
generator.gen_assign(ctx, target, val.into())?; generator.gen_assign(ctx, target, val.into())?;
}
_ => {
panic!(
"unsupported list comprehension iterator type: {}",
ctx.unifier.stringify(iter_ty)
);
}
} }
// Emits the content of `cont_bb` // Emits the content of `cont_bb`

View File

@ -315,9 +315,6 @@ pub fn gen_for<G: CodeGenerator>(
let orelse_bb = let orelse_bb =
if orelse.is_empty() { cont_bb } else { ctx.ctx.append_basic_block(current, "for.orelse") }; if orelse.is_empty() { cont_bb } else { ctx.ctx.append_basic_block(current, "for.orelse") };
// Whether the iterable is a range() expression
let is_iterable_range_expr = ctx.unifier.unioned(iter.custom.unwrap(), ctx.primitives.range);
// The BB containing the increment expression // The BB containing the increment expression
let incr_bb = ctx.ctx.append_basic_block(current, "for.incr"); let incr_bb = ctx.ctx.append_basic_block(current, "for.incr");
// The BB containing the loop condition check // The BB containing the loop condition check
@ -326,113 +323,132 @@ pub fn gen_for<G: CodeGenerator>(
// store loop bb information and restore it later // store loop bb information and restore it later
let loop_bb = ctx.loop_target.replace((incr_bb, cont_bb)); let loop_bb = ctx.loop_target.replace((incr_bb, cont_bb));
let iter_ty = iter.custom.unwrap();
let iter_val = if let Some(v) = generator.gen_expr(ctx, iter)? { let iter_val = if let Some(v) = generator.gen_expr(ctx, iter)? {
v.to_basic_value_enum(ctx, generator, iter.custom.unwrap())? v.to_basic_value_enum(ctx, generator, iter_ty)?
} else { } else {
return Ok(()); return Ok(());
}; };
if is_iterable_range_expr {
let iter_val = RangeValue::from_ptr_val(iter_val.into_pointer_value(), Some("range"));
// Internal variable for loop; Cannot be assigned
let i = generator.gen_var_alloc(ctx, int32.into(), Some("for.i.addr"))?;
// Variable declared in "target" expression of the loop; Can be reassigned *or* shadowed
let Some(target_i) = generator.gen_store_target(ctx, target, Some("for.target.addr"))?
else {
unreachable!()
};
let (start, stop, step) = destructure_range(ctx, iter_val);
ctx.builder.build_store(i, start).unwrap();
// Check "If step is zero, ValueError is raised."
let rangenez =
ctx.builder.build_int_compare(IntPredicate::NE, step, int32.const_zero(), "").unwrap();
ctx.make_assert(
generator,
rangenez,
"ValueError",
"range() arg 3 must not be zero",
[None, None, None],
ctx.current_loc,
);
ctx.builder.build_unconditional_branch(cond_bb).unwrap();
match &*ctx.unifier.get_ty(iter_ty) {
TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.range.obj_id(&ctx.unifier).unwrap() =>
{ {
ctx.builder.position_at_end(cond_bb); let iter_val = RangeValue::from_ptr_val(iter_val.into_pointer_value(), Some("range"));
ctx.builder // Internal variable for loop; Cannot be assigned
.build_conditional_branch( let i = generator.gen_var_alloc(ctx, int32.into(), Some("for.i.addr"))?;
gen_in_range_check( // Variable declared in "target" expression of the loop; Can be reassigned *or* shadowed
ctx, let Some(target_i) =
ctx.builder.build_load(i, "").map(BasicValueEnum::into_int_value).unwrap(), generator.gen_store_target(ctx, target, Some("for.target.addr"))?
stop, else {
step, unreachable!()
), };
body_bb, let (start, stop, step) = destructure_range(ctx, iter_val);
orelse_bb,
ctx.builder.build_store(i, start).unwrap();
// Check "If step is zero, ValueError is raised."
let rangenez = ctx
.builder
.build_int_compare(IntPredicate::NE, step, int32.const_zero(), "")
.unwrap();
ctx.make_assert(
generator,
rangenez,
"ValueError",
"range() arg 3 must not be zero",
[None, None, None],
ctx.current_loc,
);
ctx.builder.build_unconditional_branch(cond_bb).unwrap();
{
ctx.builder.position_at_end(cond_bb);
ctx.builder
.build_conditional_branch(
gen_in_range_check(
ctx,
ctx.builder
.build_load(i, "")
.map(BasicValueEnum::into_int_value)
.unwrap(),
stop,
step,
),
body_bb,
orelse_bb,
)
.unwrap();
}
ctx.builder.position_at_end(incr_bb);
let next_i = ctx
.builder
.build_int_add(
ctx.builder.build_load(i, "").map(BasicValueEnum::into_int_value).unwrap(),
step,
"inc",
) )
.unwrap(); .unwrap();
ctx.builder.build_store(i, next_i).unwrap();
ctx.builder.build_unconditional_branch(cond_bb).unwrap();
ctx.builder.position_at_end(body_bb);
ctx.builder
.build_store(
target_i,
ctx.builder.build_load(i, "").map(BasicValueEnum::into_int_value).unwrap(),
)
.unwrap();
generator.gen_block(ctx, body.iter())?;
} }
TypeEnum::TObj { obj_id, .. }
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
{
let index_addr = generator.gen_var_alloc(ctx, size_t.into(), Some("for.index.addr"))?;
ctx.builder.build_store(index_addr, size_t.const_zero()).unwrap();
let len = ctx
.build_gep_and_load(
iter_val.into_pointer_value(),
&[zero, int32.const_int(1, false)],
Some("len"),
)
.into_int_value();
ctx.builder.build_unconditional_branch(cond_bb).unwrap();
ctx.builder.position_at_end(incr_bb); ctx.builder.position_at_end(cond_bb);
let next_i = ctx let index = ctx
.builder .builder
.build_int_add( .build_load(index_addr, "for.index")
ctx.builder.build_load(i, "").map(BasicValueEnum::into_int_value).unwrap(), .map(BasicValueEnum::into_int_value)
step, .unwrap();
"inc", let cmp = ctx.builder.build_int_compare(IntPredicate::SLT, index, len, "cond").unwrap();
) ctx.builder.build_conditional_branch(cmp, body_bb, orelse_bb).unwrap();
.unwrap();
ctx.builder.build_store(i, next_i).unwrap();
ctx.builder.build_unconditional_branch(cond_bb).unwrap();
ctx.builder.position_at_end(body_bb); ctx.builder.position_at_end(incr_bb);
ctx.builder let index =
.build_store( ctx.builder.build_load(index_addr, "").map(BasicValueEnum::into_int_value).unwrap();
target_i, let inc = ctx.builder.build_int_add(index, size_t.const_int(1, true), "inc").unwrap();
ctx.builder.build_load(i, "").map(BasicValueEnum::into_int_value).unwrap(), ctx.builder.build_store(index_addr, inc).unwrap();
) ctx.builder.build_unconditional_branch(cond_bb).unwrap();
.unwrap();
generator.gen_block(ctx, body.iter())?;
} else {
let index_addr = generator.gen_var_alloc(ctx, size_t.into(), Some("for.index.addr"))?;
ctx.builder.build_store(index_addr, size_t.const_zero()).unwrap();
let len = ctx
.build_gep_and_load(
iter_val.into_pointer_value(),
&[zero, int32.const_int(1, false)],
Some("len"),
)
.into_int_value();
ctx.builder.build_unconditional_branch(cond_bb).unwrap();
ctx.builder.position_at_end(cond_bb); ctx.builder.position_at_end(body_bb);
let index = ctx let arr_ptr = ctx
.builder .build_gep_and_load(iter_val.into_pointer_value(), &[zero, zero], Some("arr.addr"))
.build_load(index_addr, "for.index") .into_pointer_value();
.map(BasicValueEnum::into_int_value) let index = ctx
.unwrap(); .builder
let cmp = ctx.builder.build_int_compare(IntPredicate::SLT, index, len, "cond").unwrap(); .build_load(index_addr, "for.index")
ctx.builder.build_conditional_branch(cmp, body_bb, orelse_bb).unwrap(); .map(BasicValueEnum::into_int_value)
.unwrap();
let val = ctx.build_gep_and_load(arr_ptr, &[index], Some("val"));
ctx.builder.position_at_end(incr_bb); generator.gen_assign(ctx, target, val.into())?;
let index = generator.gen_block(ctx, body.iter())?;
ctx.builder.build_load(index_addr, "").map(BasicValueEnum::into_int_value).unwrap(); }
let inc = ctx.builder.build_int_add(index, size_t.const_int(1, true), "inc").unwrap(); _ => {
ctx.builder.build_store(index_addr, inc).unwrap(); panic!("unsupported for loop iterator type: {}", ctx.unifier.stringify(iter_ty));
ctx.builder.build_unconditional_branch(cond_bb).unwrap(); }
ctx.builder.position_at_end(body_bb);
let arr_ptr = ctx
.build_gep_and_load(iter_val.into_pointer_value(), &[zero, zero], Some("arr.addr"))
.into_pointer_value();
let index = ctx
.builder
.build_load(index_addr, "for.index")
.map(BasicValueEnum::into_int_value)
.unwrap();
let val = ctx.build_gep_and_load(arr_ptr, &[index], Some("val"));
generator.gen_assign(ctx, target, val.into())?;
generator.gen_block(ctx, body.iter())?;
} }
for (k, (_, _, counter)) in &var_assignment { for (k, (_, _, counter)) in &var_assignment {

View File

@ -100,16 +100,18 @@ pub struct Inferencer<'a> {
pub in_handler: bool, pub in_handler: bool,
} }
type InferenceError = HashSet<String>;
struct NaiveFolder(); struct NaiveFolder();
impl Fold<()> for NaiveFolder { impl Fold<()> for NaiveFolder {
type TargetU = Option<Type>; type TargetU = Option<Type>;
type Error = HashSet<String>; type Error = InferenceError;
fn map_user(&mut self, (): ()) -> Result<Self::TargetU, Self::Error> { fn map_user(&mut self, (): ()) -> Result<Self::TargetU, Self::Error> {
Ok(None) Ok(None)
} }
} }
fn report_error<T>(msg: &str, location: Location) -> Result<T, HashSet<String>> { fn report_error<T>(msg: &str, location: Location) -> Result<T, InferenceError> {
Err(HashSet::from([format!("{msg} at {location}")])) Err(HashSet::from([format!("{msg} at {location}")]))
} }
@ -117,13 +119,13 @@ fn report_type_error<T>(
kind: TypeErrorKind, kind: TypeErrorKind,
loc: Option<Location>, loc: Option<Location>,
unifier: &Unifier, unifier: &Unifier,
) -> Result<T, HashSet<String>> { ) -> Result<T, InferenceError> {
Err(HashSet::from([TypeError::new(kind, loc).to_display(unifier).to_string()])) Err(HashSet::from([TypeError::new(kind, loc).to_display(unifier).to_string()]))
} }
impl<'a> Fold<()> for Inferencer<'a> { impl<'a> Fold<()> for Inferencer<'a> {
type TargetU = Option<Type>; type TargetU = Option<Type>;
type Error = HashSet<String>; type Error = InferenceError;
fn map_user(&mut self, (): ()) -> Result<Self::TargetU, Self::Error> { fn map_user(&mut self, (): ()) -> Result<Self::TargetU, Self::Error> {
Ok(None) Ok(None)
@ -612,22 +614,22 @@ impl<'a> Fold<()> for Inferencer<'a> {
} }
} }
type InferenceResult = Result<Type, HashSet<String>>; type InferenceResult = Result<Type, InferenceError>;
impl<'a> Inferencer<'a> { impl<'a> Inferencer<'a> {
/// Constrain a <: b /// Constrain a <: b
/// Currently implemented as unification /// Currently implemented as unification
fn constrain(&mut self, a: Type, b: Type, location: &Location) -> Result<(), HashSet<String>> { fn constrain(&mut self, a: Type, b: Type, location: &Location) -> Result<(), InferenceError> {
self.unify(a, b, location) self.unify(a, b, location)
} }
fn unify(&mut self, a: Type, b: Type, location: &Location) -> Result<(), HashSet<String>> { fn unify(&mut self, a: Type, b: Type, location: &Location) -> Result<(), InferenceError> {
self.unifier.unify(a, b).map_err(|e| { self.unifier.unify(a, b).map_err(|e| {
HashSet::from([e.at(Some(*location)).to_display(self.unifier).to_string()]) HashSet::from([e.at(Some(*location)).to_display(self.unifier).to_string()])
}) })
} }
fn infer_pattern(&mut self, pattern: &ast::Expr<()>) -> Result<(), HashSet<String>> { fn infer_pattern(&mut self, pattern: &ast::Expr<()>) -> Result<(), InferenceError> {
match &pattern.node { match &pattern.node {
ExprKind::Name { id, .. } => { ExprKind::Name { id, .. } => {
if !self.defined_identifiers.contains(id) { if !self.defined_identifiers.contains(id) {
@ -716,7 +718,7 @@ impl<'a> Inferencer<'a> {
location: Location, location: Location,
args: Arguments, args: Arguments,
body: ast::Expr<()>, body: ast::Expr<()>,
) -> Result<ast::Expr<Option<Type>>, HashSet<String>> { ) -> Result<ast::Expr<Option<Type>>, InferenceError> {
if !args.posonlyargs.is_empty() if !args.posonlyargs.is_empty()
|| args.vararg.is_some() || args.vararg.is_some()
|| !args.kwonlyargs.is_empty() || !args.kwonlyargs.is_empty()
@ -787,7 +789,7 @@ impl<'a> Inferencer<'a> {
location: Location, location: Location,
elt: ast::Expr<()>, elt: ast::Expr<()>,
mut generators: Vec<Comprehension>, mut generators: Vec<Comprehension>,
) -> Result<ast::Expr<Option<Type>>, HashSet<String>> { ) -> Result<ast::Expr<Option<Type>>, InferenceError> {
if generators.len() != 1 { if generators.len() != 1 {
return report_error( return report_error(
"Only 1 generator statement for list comprehension is supported", "Only 1 generator statement for list comprehension is supported",
@ -893,7 +895,7 @@ impl<'a> Inferencer<'a> {
id: StrRef, id: StrRef,
arg_index: usize, arg_index: usize,
shape_expr: Located<ExprKind>, shape_expr: Located<ExprKind>,
) -> Result<(u64, ast::Expr<Option<Type>>), HashSet<String>> { ) -> Result<(u64, ast::Expr<Option<Type>>), InferenceError> {
/* /*
### Further explanation ### Further explanation
@ -1030,7 +1032,7 @@ impl<'a> Inferencer<'a> {
func: &ast::Expr<()>, func: &ast::Expr<()>,
args: &mut Vec<ast::Expr<()>>, args: &mut Vec<ast::Expr<()>>,
keywords: &[Located<ast::KeywordData>], keywords: &[Located<ast::KeywordData>],
) -> Result<Option<ast::Expr<Option<Type>>>, HashSet<String>> { ) -> Result<Option<ast::Expr<Option<Type>>>, InferenceError> {
let Located { location: func_location, node: ExprKind::Name { id, ctx }, .. } = func else { let Located { location: func_location, node: ExprKind::Name { id, ctx }, .. } = func else {
return Ok(None); return Ok(None);
}; };
@ -1588,7 +1590,7 @@ impl<'a> Inferencer<'a> {
func: ast::Expr<()>, func: ast::Expr<()>,
mut args: Vec<ast::Expr<()>>, mut args: Vec<ast::Expr<()>>,
keywords: Vec<Located<ast::KeywordData>>, keywords: Vec<Located<ast::KeywordData>>,
) -> Result<ast::Expr<Option<Type>>, HashSet<String>> { ) -> Result<ast::Expr<Option<Type>>, InferenceError> {
if let Some(spec_call_func) = if let Some(spec_call_func) =
self.try_fold_special_call(location, &func, &mut args, &keywords)? self.try_fold_special_call(location, &func, &mut args, &keywords)?
{ {