forked from M-Labs/nac3
core/codegen: refactor gen_{for,comprehension} to match on iter type
This commit is contained in:
parent
669c6aca6b
commit
894083c6a3
|
@ -995,8 +995,10 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
|
||||||
ctx.builder.position_at_end(init_bb);
|
ctx.builder.position_at_end(init_bb);
|
||||||
|
|
||||||
let Comprehension { target, iter, ifs, .. } = &generators[0];
|
let Comprehension { target, iter, ifs, .. } = &generators[0];
|
||||||
|
|
||||||
|
let iter_ty = iter.custom.unwrap();
|
||||||
let iter_val = if let Some(v) = generator.gen_expr(ctx, iter)? {
|
let iter_val = if let Some(v) = generator.gen_expr(ctx, iter)? {
|
||||||
v.to_basic_value_enum(ctx, generator, iter.custom.unwrap())?
|
v.to_basic_value_enum(ctx, generator, iter_ty)?
|
||||||
} else {
|
} else {
|
||||||
for bb in [test_bb, body_bb, cont_bb] {
|
for bb in [test_bb, body_bb, cont_bb] {
|
||||||
ctx.builder.position_at_end(bb);
|
ctx.builder.position_at_end(bb);
|
||||||
|
@ -1014,96 +1016,120 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
|
||||||
ctx.builder.build_store(index, zero_size_t).unwrap();
|
ctx.builder.build_store(index, zero_size_t).unwrap();
|
||||||
|
|
||||||
let elem_ty = ctx.get_llvm_type(generator, elt.custom.unwrap());
|
let elem_ty = ctx.get_llvm_type(generator, elt.custom.unwrap());
|
||||||
let is_range = ctx.unifier.unioned(iter.custom.unwrap(), ctx.primitives.range);
|
|
||||||
let list;
|
let list;
|
||||||
|
|
||||||
if is_range {
|
match &*ctx.unifier.get_ty(iter_ty) {
|
||||||
let iter_val = RangeValue::from_ptr_val(iter_val.into_pointer_value(), Some("range"));
|
TypeEnum::TObj { obj_id, .. }
|
||||||
let (start, stop, step) = destructure_range(ctx, iter_val);
|
if *obj_id == ctx.primitives.range.obj_id(&ctx.unifier).unwrap() =>
|
||||||
let diff = ctx.builder.build_int_sub(stop, start, "diff").unwrap();
|
{
|
||||||
// add 1 to the length as the value is rounded to zero
|
let iter_val = RangeValue::from_ptr_val(iter_val.into_pointer_value(), Some("range"));
|
||||||
// the length may be 1 more than the actual length if the division is exact, but the
|
let (start, stop, step) = destructure_range(ctx, iter_val);
|
||||||
// length is a upper bound only anyway so it does not matter.
|
let diff = ctx.builder.build_int_sub(stop, start, "diff").unwrap();
|
||||||
let length = ctx.builder.build_int_signed_div(diff, step, "div").unwrap();
|
// add 1 to the length as the value is rounded to zero
|
||||||
let length = ctx.builder.build_int_add(length, int32.const_int(1, false), "add1").unwrap();
|
// the length may be 1 more than the actual length if the division is exact, but the
|
||||||
// in case length is non-positive
|
// length is a upper bound only anyway so it does not matter.
|
||||||
let is_valid =
|
let length = ctx.builder.build_int_signed_div(diff, step, "div").unwrap();
|
||||||
ctx.builder.build_int_compare(IntPredicate::SGT, length, zero_32, "check").unwrap();
|
let length =
|
||||||
|
ctx.builder.build_int_add(length, int32.const_int(1, false), "add1").unwrap();
|
||||||
|
// in case length is non-positive
|
||||||
|
let is_valid =
|
||||||
|
ctx.builder.build_int_compare(IntPredicate::SGT, length, zero_32, "check").unwrap();
|
||||||
|
|
||||||
let list_alloc_size = ctx
|
let list_alloc_size = ctx
|
||||||
.builder
|
.builder
|
||||||
.build_select(
|
.build_select(
|
||||||
is_valid,
|
is_valid,
|
||||||
ctx.builder.build_int_z_extend_or_bit_cast(length, size_t, "z_ext_len").unwrap(),
|
ctx.builder
|
||||||
zero_size_t,
|
.build_int_z_extend_or_bit_cast(length, size_t, "z_ext_len")
|
||||||
"listcomp.alloc_size",
|
.unwrap(),
|
||||||
)
|
zero_size_t,
|
||||||
.unwrap();
|
"listcomp.alloc_size",
|
||||||
list = allocate_list(
|
)
|
||||||
generator,
|
.unwrap();
|
||||||
ctx,
|
list = allocate_list(
|
||||||
Some(elem_ty),
|
generator,
|
||||||
list_alloc_size.into_int_value(),
|
ctx,
|
||||||
Some("listcomp.addr"),
|
Some(elem_ty),
|
||||||
);
|
list_alloc_size.into_int_value(),
|
||||||
|
Some("listcomp.addr"),
|
||||||
|
);
|
||||||
|
|
||||||
let i = generator.gen_store_target(ctx, target, Some("i.addr"))?.unwrap();
|
let i = generator.gen_store_target(ctx, target, Some("i.addr"))?.unwrap();
|
||||||
ctx.builder
|
ctx.builder
|
||||||
.build_store(i, ctx.builder.build_int_sub(start, step, "start_init").unwrap())
|
.build_store(i, ctx.builder.build_int_sub(start, step, "start_init").unwrap())
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
ctx.builder
|
ctx.builder
|
||||||
.build_conditional_branch(gen_in_range_check(ctx, start, stop, step), test_bb, cont_bb)
|
.build_conditional_branch(
|
||||||
.unwrap();
|
gen_in_range_check(ctx, start, stop, step),
|
||||||
|
test_bb,
|
||||||
|
cont_bb,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
ctx.builder.position_at_end(test_bb);
|
ctx.builder.position_at_end(test_bb);
|
||||||
// add and test
|
// add and test
|
||||||
let tmp = ctx
|
let tmp = ctx
|
||||||
.builder
|
.builder
|
||||||
.build_int_add(
|
.build_int_add(
|
||||||
ctx.builder.build_load(i, "i").map(BasicValueEnum::into_int_value).unwrap(),
|
ctx.builder.build_load(i, "i").map(BasicValueEnum::into_int_value).unwrap(),
|
||||||
step,
|
step,
|
||||||
"start_loop",
|
"start_loop",
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
ctx.builder.build_store(i, tmp).unwrap();
|
ctx.builder.build_store(i, tmp).unwrap();
|
||||||
ctx.builder
|
ctx.builder
|
||||||
.build_conditional_branch(gen_in_range_check(ctx, tmp, stop, step), body_bb, cont_bb)
|
.build_conditional_branch(
|
||||||
.unwrap();
|
gen_in_range_check(ctx, tmp, stop, step),
|
||||||
|
body_bb,
|
||||||
|
cont_bb,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
ctx.builder.position_at_end(body_bb);
|
ctx.builder.position_at_end(body_bb);
|
||||||
} else {
|
}
|
||||||
let length = ctx
|
TypeEnum::TObj { obj_id, .. }
|
||||||
.build_gep_and_load(
|
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
|
||||||
iter_val.into_pointer_value(),
|
{
|
||||||
&[zero_size_t, int32.const_int(1, false)],
|
let length = ctx
|
||||||
Some("length"),
|
.build_gep_and_load(
|
||||||
)
|
iter_val.into_pointer_value(),
|
||||||
.into_int_value();
|
&[zero_size_t, int32.const_int(1, false)],
|
||||||
list = allocate_list(generator, ctx, Some(elem_ty), length, Some("listcomp"));
|
Some("length"),
|
||||||
|
)
|
||||||
|
.into_int_value();
|
||||||
|
list = allocate_list(generator, ctx, Some(elem_ty), length, Some("listcomp"));
|
||||||
|
|
||||||
let counter = generator.gen_var_alloc(ctx, size_t.into(), Some("counter.addr"))?;
|
let counter = generator.gen_var_alloc(ctx, size_t.into(), Some("counter.addr"))?;
|
||||||
// counter = -1
|
// counter = -1
|
||||||
ctx.builder.build_store(counter, size_t.const_all_ones()).unwrap();
|
ctx.builder.build_store(counter, size_t.const_all_ones()).unwrap();
|
||||||
ctx.builder.build_unconditional_branch(test_bb).unwrap();
|
ctx.builder.build_unconditional_branch(test_bb).unwrap();
|
||||||
|
|
||||||
ctx.builder.position_at_end(test_bb);
|
ctx.builder.position_at_end(test_bb);
|
||||||
let tmp = ctx.builder.build_load(counter, "i").map(BasicValueEnum::into_int_value).unwrap();
|
let tmp =
|
||||||
let tmp = ctx.builder.build_int_add(tmp, size_t.const_int(1, false), "inc").unwrap();
|
ctx.builder.build_load(counter, "i").map(BasicValueEnum::into_int_value).unwrap();
|
||||||
ctx.builder.build_store(counter, tmp).unwrap();
|
let tmp = ctx.builder.build_int_add(tmp, size_t.const_int(1, false), "inc").unwrap();
|
||||||
let cmp = ctx.builder.build_int_compare(IntPredicate::SLT, tmp, length, "cmp").unwrap();
|
ctx.builder.build_store(counter, tmp).unwrap();
|
||||||
ctx.builder.build_conditional_branch(cmp, body_bb, cont_bb).unwrap();
|
let cmp = ctx.builder.build_int_compare(IntPredicate::SLT, tmp, length, "cmp").unwrap();
|
||||||
|
ctx.builder.build_conditional_branch(cmp, body_bb, cont_bb).unwrap();
|
||||||
|
|
||||||
ctx.builder.position_at_end(body_bb);
|
ctx.builder.position_at_end(body_bb);
|
||||||
let arr_ptr = ctx
|
let arr_ptr = ctx
|
||||||
.build_gep_and_load(
|
.build_gep_and_load(
|
||||||
iter_val.into_pointer_value(),
|
iter_val.into_pointer_value(),
|
||||||
&[zero_size_t, zero_32],
|
&[zero_size_t, zero_32],
|
||||||
Some("arr.addr"),
|
Some("arr.addr"),
|
||||||
)
|
)
|
||||||
.into_pointer_value();
|
.into_pointer_value();
|
||||||
let val = ctx.build_gep_and_load(arr_ptr, &[tmp], Some("val"));
|
let val = ctx.build_gep_and_load(arr_ptr, &[tmp], Some("val"));
|
||||||
generator.gen_assign(ctx, target, val.into())?;
|
generator.gen_assign(ctx, target, val.into())?;
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
panic!(
|
||||||
|
"unsupported list comprehension iterator type: {}",
|
||||||
|
ctx.unifier.stringify(iter_ty)
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Emits the content of `cont_bb`
|
// Emits the content of `cont_bb`
|
||||||
|
|
|
@ -315,9 +315,6 @@ pub fn gen_for<G: CodeGenerator>(
|
||||||
let orelse_bb =
|
let orelse_bb =
|
||||||
if orelse.is_empty() { cont_bb } else { ctx.ctx.append_basic_block(current, "for.orelse") };
|
if orelse.is_empty() { cont_bb } else { ctx.ctx.append_basic_block(current, "for.orelse") };
|
||||||
|
|
||||||
// Whether the iterable is a range() expression
|
|
||||||
let is_iterable_range_expr = ctx.unifier.unioned(iter.custom.unwrap(), ctx.primitives.range);
|
|
||||||
|
|
||||||
// The BB containing the increment expression
|
// The BB containing the increment expression
|
||||||
let incr_bb = ctx.ctx.append_basic_block(current, "for.incr");
|
let incr_bb = ctx.ctx.append_basic_block(current, "for.incr");
|
||||||
// The BB containing the loop condition check
|
// The BB containing the loop condition check
|
||||||
|
@ -326,113 +323,132 @@ pub fn gen_for<G: CodeGenerator>(
|
||||||
// store loop bb information and restore it later
|
// store loop bb information and restore it later
|
||||||
let loop_bb = ctx.loop_target.replace((incr_bb, cont_bb));
|
let loop_bb = ctx.loop_target.replace((incr_bb, cont_bb));
|
||||||
|
|
||||||
|
let iter_ty = iter.custom.unwrap();
|
||||||
let iter_val = if let Some(v) = generator.gen_expr(ctx, iter)? {
|
let iter_val = if let Some(v) = generator.gen_expr(ctx, iter)? {
|
||||||
v.to_basic_value_enum(ctx, generator, iter.custom.unwrap())?
|
v.to_basic_value_enum(ctx, generator, iter_ty)?
|
||||||
} else {
|
} else {
|
||||||
return Ok(());
|
return Ok(());
|
||||||
};
|
};
|
||||||
if is_iterable_range_expr {
|
|
||||||
let iter_val = RangeValue::from_ptr_val(iter_val.into_pointer_value(), Some("range"));
|
|
||||||
// Internal variable for loop; Cannot be assigned
|
|
||||||
let i = generator.gen_var_alloc(ctx, int32.into(), Some("for.i.addr"))?;
|
|
||||||
// Variable declared in "target" expression of the loop; Can be reassigned *or* shadowed
|
|
||||||
let Some(target_i) = generator.gen_store_target(ctx, target, Some("for.target.addr"))?
|
|
||||||
else {
|
|
||||||
unreachable!()
|
|
||||||
};
|
|
||||||
let (start, stop, step) = destructure_range(ctx, iter_val);
|
|
||||||
|
|
||||||
ctx.builder.build_store(i, start).unwrap();
|
|
||||||
|
|
||||||
// Check "If step is zero, ValueError is raised."
|
|
||||||
let rangenez =
|
|
||||||
ctx.builder.build_int_compare(IntPredicate::NE, step, int32.const_zero(), "").unwrap();
|
|
||||||
ctx.make_assert(
|
|
||||||
generator,
|
|
||||||
rangenez,
|
|
||||||
"ValueError",
|
|
||||||
"range() arg 3 must not be zero",
|
|
||||||
[None, None, None],
|
|
||||||
ctx.current_loc,
|
|
||||||
);
|
|
||||||
ctx.builder.build_unconditional_branch(cond_bb).unwrap();
|
|
||||||
|
|
||||||
|
match &*ctx.unifier.get_ty(iter_ty) {
|
||||||
|
TypeEnum::TObj { obj_id, .. }
|
||||||
|
if *obj_id == ctx.primitives.range.obj_id(&ctx.unifier).unwrap() =>
|
||||||
{
|
{
|
||||||
ctx.builder.position_at_end(cond_bb);
|
let iter_val = RangeValue::from_ptr_val(iter_val.into_pointer_value(), Some("range"));
|
||||||
ctx.builder
|
// Internal variable for loop; Cannot be assigned
|
||||||
.build_conditional_branch(
|
let i = generator.gen_var_alloc(ctx, int32.into(), Some("for.i.addr"))?;
|
||||||
gen_in_range_check(
|
// Variable declared in "target" expression of the loop; Can be reassigned *or* shadowed
|
||||||
ctx,
|
let Some(target_i) =
|
||||||
ctx.builder.build_load(i, "").map(BasicValueEnum::into_int_value).unwrap(),
|
generator.gen_store_target(ctx, target, Some("for.target.addr"))?
|
||||||
stop,
|
else {
|
||||||
step,
|
unreachable!()
|
||||||
),
|
};
|
||||||
body_bb,
|
let (start, stop, step) = destructure_range(ctx, iter_val);
|
||||||
orelse_bb,
|
|
||||||
|
ctx.builder.build_store(i, start).unwrap();
|
||||||
|
|
||||||
|
// Check "If step is zero, ValueError is raised."
|
||||||
|
let rangenez = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_compare(IntPredicate::NE, step, int32.const_zero(), "")
|
||||||
|
.unwrap();
|
||||||
|
ctx.make_assert(
|
||||||
|
generator,
|
||||||
|
rangenez,
|
||||||
|
"ValueError",
|
||||||
|
"range() arg 3 must not be zero",
|
||||||
|
[None, None, None],
|
||||||
|
ctx.current_loc,
|
||||||
|
);
|
||||||
|
ctx.builder.build_unconditional_branch(cond_bb).unwrap();
|
||||||
|
|
||||||
|
{
|
||||||
|
ctx.builder.position_at_end(cond_bb);
|
||||||
|
ctx.builder
|
||||||
|
.build_conditional_branch(
|
||||||
|
gen_in_range_check(
|
||||||
|
ctx,
|
||||||
|
ctx.builder
|
||||||
|
.build_load(i, "")
|
||||||
|
.map(BasicValueEnum::into_int_value)
|
||||||
|
.unwrap(),
|
||||||
|
stop,
|
||||||
|
step,
|
||||||
|
),
|
||||||
|
body_bb,
|
||||||
|
orelse_bb,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.builder.position_at_end(incr_bb);
|
||||||
|
let next_i = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_add(
|
||||||
|
ctx.builder.build_load(i, "").map(BasicValueEnum::into_int_value).unwrap(),
|
||||||
|
step,
|
||||||
|
"inc",
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
ctx.builder.build_store(i, next_i).unwrap();
|
||||||
|
ctx.builder.build_unconditional_branch(cond_bb).unwrap();
|
||||||
|
|
||||||
|
ctx.builder.position_at_end(body_bb);
|
||||||
|
ctx.builder
|
||||||
|
.build_store(
|
||||||
|
target_i,
|
||||||
|
ctx.builder.build_load(i, "").map(BasicValueEnum::into_int_value).unwrap(),
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
generator.gen_block(ctx, body.iter())?;
|
||||||
}
|
}
|
||||||
|
TypeEnum::TObj { obj_id, .. }
|
||||||
|
if *obj_id == ctx.primitives.list.obj_id(&ctx.unifier).unwrap() =>
|
||||||
|
{
|
||||||
|
let index_addr = generator.gen_var_alloc(ctx, size_t.into(), Some("for.index.addr"))?;
|
||||||
|
ctx.builder.build_store(index_addr, size_t.const_zero()).unwrap();
|
||||||
|
let len = ctx
|
||||||
|
.build_gep_and_load(
|
||||||
|
iter_val.into_pointer_value(),
|
||||||
|
&[zero, int32.const_int(1, false)],
|
||||||
|
Some("len"),
|
||||||
|
)
|
||||||
|
.into_int_value();
|
||||||
|
ctx.builder.build_unconditional_branch(cond_bb).unwrap();
|
||||||
|
|
||||||
ctx.builder.position_at_end(incr_bb);
|
ctx.builder.position_at_end(cond_bb);
|
||||||
let next_i = ctx
|
let index = ctx
|
||||||
.builder
|
.builder
|
||||||
.build_int_add(
|
.build_load(index_addr, "for.index")
|
||||||
ctx.builder.build_load(i, "").map(BasicValueEnum::into_int_value).unwrap(),
|
.map(BasicValueEnum::into_int_value)
|
||||||
step,
|
.unwrap();
|
||||||
"inc",
|
let cmp = ctx.builder.build_int_compare(IntPredicate::SLT, index, len, "cond").unwrap();
|
||||||
)
|
ctx.builder.build_conditional_branch(cmp, body_bb, orelse_bb).unwrap();
|
||||||
.unwrap();
|
|
||||||
ctx.builder.build_store(i, next_i).unwrap();
|
|
||||||
ctx.builder.build_unconditional_branch(cond_bb).unwrap();
|
|
||||||
|
|
||||||
ctx.builder.position_at_end(body_bb);
|
ctx.builder.position_at_end(incr_bb);
|
||||||
ctx.builder
|
let index =
|
||||||
.build_store(
|
ctx.builder.build_load(index_addr, "").map(BasicValueEnum::into_int_value).unwrap();
|
||||||
target_i,
|
let inc = ctx.builder.build_int_add(index, size_t.const_int(1, true), "inc").unwrap();
|
||||||
ctx.builder.build_load(i, "").map(BasicValueEnum::into_int_value).unwrap(),
|
ctx.builder.build_store(index_addr, inc).unwrap();
|
||||||
)
|
ctx.builder.build_unconditional_branch(cond_bb).unwrap();
|
||||||
.unwrap();
|
|
||||||
generator.gen_block(ctx, body.iter())?;
|
|
||||||
} else {
|
|
||||||
let index_addr = generator.gen_var_alloc(ctx, size_t.into(), Some("for.index.addr"))?;
|
|
||||||
ctx.builder.build_store(index_addr, size_t.const_zero()).unwrap();
|
|
||||||
let len = ctx
|
|
||||||
.build_gep_and_load(
|
|
||||||
iter_val.into_pointer_value(),
|
|
||||||
&[zero, int32.const_int(1, false)],
|
|
||||||
Some("len"),
|
|
||||||
)
|
|
||||||
.into_int_value();
|
|
||||||
ctx.builder.build_unconditional_branch(cond_bb).unwrap();
|
|
||||||
|
|
||||||
ctx.builder.position_at_end(cond_bb);
|
ctx.builder.position_at_end(body_bb);
|
||||||
let index = ctx
|
let arr_ptr = ctx
|
||||||
.builder
|
.build_gep_and_load(iter_val.into_pointer_value(), &[zero, zero], Some("arr.addr"))
|
||||||
.build_load(index_addr, "for.index")
|
.into_pointer_value();
|
||||||
.map(BasicValueEnum::into_int_value)
|
let index = ctx
|
||||||
.unwrap();
|
.builder
|
||||||
let cmp = ctx.builder.build_int_compare(IntPredicate::SLT, index, len, "cond").unwrap();
|
.build_load(index_addr, "for.index")
|
||||||
ctx.builder.build_conditional_branch(cmp, body_bb, orelse_bb).unwrap();
|
.map(BasicValueEnum::into_int_value)
|
||||||
|
.unwrap();
|
||||||
|
let val = ctx.build_gep_and_load(arr_ptr, &[index], Some("val"));
|
||||||
|
|
||||||
ctx.builder.position_at_end(incr_bb);
|
generator.gen_assign(ctx, target, val.into())?;
|
||||||
let index =
|
generator.gen_block(ctx, body.iter())?;
|
||||||
ctx.builder.build_load(index_addr, "").map(BasicValueEnum::into_int_value).unwrap();
|
}
|
||||||
let inc = ctx.builder.build_int_add(index, size_t.const_int(1, true), "inc").unwrap();
|
_ => {
|
||||||
ctx.builder.build_store(index_addr, inc).unwrap();
|
panic!("unsupported for loop iterator type: {}", ctx.unifier.stringify(iter_ty));
|
||||||
ctx.builder.build_unconditional_branch(cond_bb).unwrap();
|
}
|
||||||
|
|
||||||
ctx.builder.position_at_end(body_bb);
|
|
||||||
let arr_ptr = ctx
|
|
||||||
.build_gep_and_load(iter_val.into_pointer_value(), &[zero, zero], Some("arr.addr"))
|
|
||||||
.into_pointer_value();
|
|
||||||
let index = ctx
|
|
||||||
.builder
|
|
||||||
.build_load(index_addr, "for.index")
|
|
||||||
.map(BasicValueEnum::into_int_value)
|
|
||||||
.unwrap();
|
|
||||||
let val = ctx.build_gep_and_load(arr_ptr, &[index], Some("val"));
|
|
||||||
generator.gen_assign(ctx, target, val.into())?;
|
|
||||||
generator.gen_block(ctx, body.iter())?;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for (k, (_, _, counter)) in &var_assignment {
|
for (k, (_, _, counter)) in &var_assignment {
|
||||||
|
|
|
@ -100,16 +100,18 @@ pub struct Inferencer<'a> {
|
||||||
pub in_handler: bool,
|
pub in_handler: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type InferenceError = HashSet<String>;
|
||||||
|
|
||||||
struct NaiveFolder();
|
struct NaiveFolder();
|
||||||
impl Fold<()> for NaiveFolder {
|
impl Fold<()> for NaiveFolder {
|
||||||
type TargetU = Option<Type>;
|
type TargetU = Option<Type>;
|
||||||
type Error = HashSet<String>;
|
type Error = InferenceError;
|
||||||
fn map_user(&mut self, (): ()) -> Result<Self::TargetU, Self::Error> {
|
fn map_user(&mut self, (): ()) -> Result<Self::TargetU, Self::Error> {
|
||||||
Ok(None)
|
Ok(None)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn report_error<T>(msg: &str, location: Location) -> Result<T, HashSet<String>> {
|
fn report_error<T>(msg: &str, location: Location) -> Result<T, InferenceError> {
|
||||||
Err(HashSet::from([format!("{msg} at {location}")]))
|
Err(HashSet::from([format!("{msg} at {location}")]))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -117,13 +119,13 @@ fn report_type_error<T>(
|
||||||
kind: TypeErrorKind,
|
kind: TypeErrorKind,
|
||||||
loc: Option<Location>,
|
loc: Option<Location>,
|
||||||
unifier: &Unifier,
|
unifier: &Unifier,
|
||||||
) -> Result<T, HashSet<String>> {
|
) -> Result<T, InferenceError> {
|
||||||
Err(HashSet::from([TypeError::new(kind, loc).to_display(unifier).to_string()]))
|
Err(HashSet::from([TypeError::new(kind, loc).to_display(unifier).to_string()]))
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> Fold<()> for Inferencer<'a> {
|
impl<'a> Fold<()> for Inferencer<'a> {
|
||||||
type TargetU = Option<Type>;
|
type TargetU = Option<Type>;
|
||||||
type Error = HashSet<String>;
|
type Error = InferenceError;
|
||||||
|
|
||||||
fn map_user(&mut self, (): ()) -> Result<Self::TargetU, Self::Error> {
|
fn map_user(&mut self, (): ()) -> Result<Self::TargetU, Self::Error> {
|
||||||
Ok(None)
|
Ok(None)
|
||||||
|
@ -612,22 +614,22 @@ impl<'a> Fold<()> for Inferencer<'a> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type InferenceResult = Result<Type, HashSet<String>>;
|
type InferenceResult = Result<Type, InferenceError>;
|
||||||
|
|
||||||
impl<'a> Inferencer<'a> {
|
impl<'a> Inferencer<'a> {
|
||||||
/// Constrain a <: b
|
/// Constrain a <: b
|
||||||
/// Currently implemented as unification
|
/// Currently implemented as unification
|
||||||
fn constrain(&mut self, a: Type, b: Type, location: &Location) -> Result<(), HashSet<String>> {
|
fn constrain(&mut self, a: Type, b: Type, location: &Location) -> Result<(), InferenceError> {
|
||||||
self.unify(a, b, location)
|
self.unify(a, b, location)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn unify(&mut self, a: Type, b: Type, location: &Location) -> Result<(), HashSet<String>> {
|
fn unify(&mut self, a: Type, b: Type, location: &Location) -> Result<(), InferenceError> {
|
||||||
self.unifier.unify(a, b).map_err(|e| {
|
self.unifier.unify(a, b).map_err(|e| {
|
||||||
HashSet::from([e.at(Some(*location)).to_display(self.unifier).to_string()])
|
HashSet::from([e.at(Some(*location)).to_display(self.unifier).to_string()])
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn infer_pattern(&mut self, pattern: &ast::Expr<()>) -> Result<(), HashSet<String>> {
|
fn infer_pattern(&mut self, pattern: &ast::Expr<()>) -> Result<(), InferenceError> {
|
||||||
match &pattern.node {
|
match &pattern.node {
|
||||||
ExprKind::Name { id, .. } => {
|
ExprKind::Name { id, .. } => {
|
||||||
if !self.defined_identifiers.contains(id) {
|
if !self.defined_identifiers.contains(id) {
|
||||||
|
@ -716,7 +718,7 @@ impl<'a> Inferencer<'a> {
|
||||||
location: Location,
|
location: Location,
|
||||||
args: Arguments,
|
args: Arguments,
|
||||||
body: ast::Expr<()>,
|
body: ast::Expr<()>,
|
||||||
) -> Result<ast::Expr<Option<Type>>, HashSet<String>> {
|
) -> Result<ast::Expr<Option<Type>>, InferenceError> {
|
||||||
if !args.posonlyargs.is_empty()
|
if !args.posonlyargs.is_empty()
|
||||||
|| args.vararg.is_some()
|
|| args.vararg.is_some()
|
||||||
|| !args.kwonlyargs.is_empty()
|
|| !args.kwonlyargs.is_empty()
|
||||||
|
@ -787,7 +789,7 @@ impl<'a> Inferencer<'a> {
|
||||||
location: Location,
|
location: Location,
|
||||||
elt: ast::Expr<()>,
|
elt: ast::Expr<()>,
|
||||||
mut generators: Vec<Comprehension>,
|
mut generators: Vec<Comprehension>,
|
||||||
) -> Result<ast::Expr<Option<Type>>, HashSet<String>> {
|
) -> Result<ast::Expr<Option<Type>>, InferenceError> {
|
||||||
if generators.len() != 1 {
|
if generators.len() != 1 {
|
||||||
return report_error(
|
return report_error(
|
||||||
"Only 1 generator statement for list comprehension is supported",
|
"Only 1 generator statement for list comprehension is supported",
|
||||||
|
@ -893,7 +895,7 @@ impl<'a> Inferencer<'a> {
|
||||||
id: StrRef,
|
id: StrRef,
|
||||||
arg_index: usize,
|
arg_index: usize,
|
||||||
shape_expr: Located<ExprKind>,
|
shape_expr: Located<ExprKind>,
|
||||||
) -> Result<(u64, ast::Expr<Option<Type>>), HashSet<String>> {
|
) -> Result<(u64, ast::Expr<Option<Type>>), InferenceError> {
|
||||||
/*
|
/*
|
||||||
### Further explanation
|
### Further explanation
|
||||||
|
|
||||||
|
@ -1030,7 +1032,7 @@ impl<'a> Inferencer<'a> {
|
||||||
func: &ast::Expr<()>,
|
func: &ast::Expr<()>,
|
||||||
args: &mut Vec<ast::Expr<()>>,
|
args: &mut Vec<ast::Expr<()>>,
|
||||||
keywords: &[Located<ast::KeywordData>],
|
keywords: &[Located<ast::KeywordData>],
|
||||||
) -> Result<Option<ast::Expr<Option<Type>>>, HashSet<String>> {
|
) -> Result<Option<ast::Expr<Option<Type>>>, InferenceError> {
|
||||||
let Located { location: func_location, node: ExprKind::Name { id, ctx }, .. } = func else {
|
let Located { location: func_location, node: ExprKind::Name { id, ctx }, .. } = func else {
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
};
|
};
|
||||||
|
@ -1588,7 +1590,7 @@ impl<'a> Inferencer<'a> {
|
||||||
func: ast::Expr<()>,
|
func: ast::Expr<()>,
|
||||||
mut args: Vec<ast::Expr<()>>,
|
mut args: Vec<ast::Expr<()>>,
|
||||||
keywords: Vec<Located<ast::KeywordData>>,
|
keywords: Vec<Located<ast::KeywordData>>,
|
||||||
) -> Result<ast::Expr<Option<Type>>, HashSet<String>> {
|
) -> Result<ast::Expr<Option<Type>>, InferenceError> {
|
||||||
if let Some(spec_call_func) =
|
if let Some(spec_call_func) =
|
||||||
self.try_fold_special_call(location, &func, &mut args, &keywords)?
|
self.try_fold_special_call(location, &func, &mut args, &keywords)?
|
||||||
{
|
{
|
||||||
|
|
Loading…
Reference in New Issue