forked from M-Labs/nac3
[core] codegen/expr: Implement comparison of tuples
This commit is contained in:
parent
33929bda24
commit
4d80ba38b7
@ -40,6 +40,7 @@ use nac3parser::ast::{
|
||||
self, Boolop, Cmpop, Comprehension, Constant, Expr, ExprKind, Location, Operator, StrRef,
|
||||
Unaryop,
|
||||
};
|
||||
use std::cmp::min;
|
||||
use std::iter::{repeat, repeat_with};
|
||||
use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip};
|
||||
|
||||
@ -2303,6 +2304,119 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
};
|
||||
|
||||
gen_list_cmpop(generator, ctx)?
|
||||
} else if [left_ty, right_ty].iter().any(|ty| matches!(&*ctx.unifier.get_ty_immutable(*ty), TypeEnum::TTuple { .. })) {
|
||||
let TypeEnum::TTuple { ty: left_tys, .. } = &*ctx.unifier.get_ty_immutable(left_ty) else {
|
||||
return Err(format!("'{}' not supported between instances of '{}' and '{}'", op.op_info().symbol, ctx.unifier.stringify(left_ty), ctx.unifier.stringify(right_ty)))
|
||||
};
|
||||
let TypeEnum::TTuple { ty: right_tys, .. } = &*ctx.unifier.get_ty_immutable(right_ty) else {
|
||||
return Err(format!("'{}' not supported between instances of '{}' and '{}'", op.op_info().symbol, ctx.unifier.stringify(left_ty), ctx.unifier.stringify(right_ty)))
|
||||
};
|
||||
|
||||
if ![Cmpop::Eq, Cmpop::NotEq].contains(op) {
|
||||
todo!("Only __eq__ and __ne__ is implemented for tuples")
|
||||
}
|
||||
|
||||
let llvm_i1 = ctx.ctx.bool_type();
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
|
||||
// Assume `true` by default
|
||||
let cmp_addr = generator.gen_var_alloc(ctx, llvm_i1.into(), None).unwrap();
|
||||
ctx.builder.build_store(cmp_addr, llvm_i1.const_all_ones()).unwrap();
|
||||
|
||||
let current_bb = ctx.builder.get_insert_block().unwrap();
|
||||
let post_foreach_cmp = ctx.ctx.insert_basic_block_after(current_bb, "foreach.cmp.end");
|
||||
|
||||
ctx.builder.position_at_end(post_foreach_cmp);
|
||||
let cmp_phi = ctx.builder.build_phi(llvm_i1, "").unwrap();
|
||||
ctx.builder.position_at_end(current_bb);
|
||||
|
||||
// Generate comparison between each element
|
||||
let min_len = min(left_tys.len(), right_tys.len());
|
||||
for i in 0..min_len {
|
||||
let current_bb = ctx.builder.get_insert_block().unwrap();
|
||||
let bb = ctx.ctx.insert_basic_block_after(current_bb, &format!("foreach.cmp.tuple.{i}e"));
|
||||
ctx.builder.build_unconditional_branch(bb).unwrap();
|
||||
|
||||
ctx.builder.position_at_end(bb);
|
||||
let left_ty = left_tys[i];
|
||||
let left_elem = {
|
||||
let plhs = generator.gen_var_alloc(ctx, lhs.get_type(), None).unwrap();
|
||||
ctx.builder.build_store(plhs, *lhs).unwrap();
|
||||
|
||||
ctx.build_in_bounds_gep_and_load(
|
||||
plhs,
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_int(i as u64, false)],
|
||||
None,
|
||||
)
|
||||
};
|
||||
let right_ty = right_tys[i];
|
||||
let right_elem = {
|
||||
let prhs = generator.gen_var_alloc(ctx, rhs.get_type(), None).unwrap();
|
||||
ctx.builder.build_store(prhs, *rhs).unwrap();
|
||||
|
||||
ctx.build_in_bounds_gep_and_load(
|
||||
prhs,
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_int(i as u64, false)],
|
||||
None,
|
||||
)
|
||||
};
|
||||
|
||||
gen_if_callback(
|
||||
generator,
|
||||
ctx,
|
||||
|generator, ctx| {
|
||||
// Defer the `not` operation until the end - a != b <=> !(a == b)
|
||||
let op = if *op == Cmpop::NotEq { Cmpop::Eq } else { *op };
|
||||
|
||||
let cmp = gen_cmpop_expr_with_values(
|
||||
generator,
|
||||
ctx,
|
||||
(Some(left_ty), left_elem),
|
||||
&[op],
|
||||
&[(Some(right_ty), right_elem)],
|
||||
)
|
||||
.transpose()
|
||||
.unwrap()
|
||||
.and_then(|v| {
|
||||
v.to_basic_value_enum(ctx, generator, ctx.primitives.bool)
|
||||
})
|
||||
.map(BasicValueEnum::into_int_value)?;
|
||||
|
||||
Ok(ctx.builder.build_not(cmp, "").unwrap())
|
||||
},
|
||||
|_, ctx| {
|
||||
let bb = ctx.builder.get_insert_block().unwrap();
|
||||
cmp_phi.add_incoming(&[(&llvm_i1.const_zero(), bb)]);
|
||||
ctx.builder.build_unconditional_branch(post_foreach_cmp).unwrap();
|
||||
|
||||
Ok(())
|
||||
},
|
||||
|_, _| Ok(()),
|
||||
)?;
|
||||
}
|
||||
|
||||
// Length of tuples is checked last as operators do not short-circuit by tuple
|
||||
// length in Python:
|
||||
//
|
||||
// >>> (1, 2) < ("a",)
|
||||
// TypeError: '<' not supported between instances of 'int' and 'str'
|
||||
let bb = ctx.builder.get_insert_block().unwrap();
|
||||
let is_len_eq = llvm_i1.const_int(
|
||||
u64::from(left_tys.len() == right_tys.len()),
|
||||
false,
|
||||
);
|
||||
cmp_phi.add_incoming(&[(&is_len_eq, bb)]);
|
||||
ctx.builder.build_unconditional_branch(post_foreach_cmp).unwrap();
|
||||
|
||||
ctx.builder.position_at_end(post_foreach_cmp);
|
||||
let cmp_phi = cmp_phi.as_basic_value().into_int_value();
|
||||
|
||||
// Invert the final value if __ne__
|
||||
if *op == Cmpop::NotEq {
|
||||
ctx.builder.build_not(cmp_phi, "").unwrap()
|
||||
} else {
|
||||
cmp_phi
|
||||
}
|
||||
} else if [left_ty, right_ty].iter().any(|ty| matches!(&*ctx.unifier.get_ty_immutable(*ty), TypeEnum::TVar { .. })) {
|
||||
if ctx.registry.llvm_options.opt_level == OptimizationLevel::None {
|
||||
ctx.make_assert(
|
||||
|
Loading…
Reference in New Issue
Block a user