[core] codegen/expr: Implement comparison of tuples

This commit is contained in:
David Mak 2024-08-20 19:49:52 +08:00
parent c6e3ecaeb9
commit 1ab9a118d6
1 changed files with 107 additions and 0 deletions

View File

@ -40,6 +40,7 @@ use nac3parser::ast::{
self, Boolop, Cmpop, Comprehension, Constant, Expr, ExprKind, Location, Operator, StrRef, self, Boolop, Cmpop, Comprehension, Constant, Expr, ExprKind, Location, Operator, StrRef,
Unaryop, Unaryop,
}; };
use std::cmp::min;
use std::iter::{repeat, repeat_with}; use std::iter::{repeat, repeat_with};
use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip}; use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip};
@ -2207,6 +2208,112 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
} }
ctx.ctx.bool_type().get_poison() ctx.ctx.bool_type().get_poison()
} else if [left_ty, right_ty].iter().any(|ty| matches!(&*ctx.unifier.get_ty_immutable(*ty), TypeEnum::TTuple { .. })) {
let TypeEnum::TTuple { ty: left_tys, .. } = &*ctx.unifier.get_ty_immutable(left_ty) else {
return Err(format!("'{}' not supported between instances of '{}' and '{}'", op.op_info().symbol, ctx.unifier.stringify(left_ty), ctx.unifier.stringify(right_ty)))
};
let TypeEnum::TTuple { ty: right_tys, .. } = &*ctx.unifier.get_ty_immutable(right_ty) else {
return Err(format!("'{}' not supported between instances of '{}' and '{}'", op.op_info().symbol, ctx.unifier.stringify(left_ty), ctx.unifier.stringify(right_ty)))
};
let llvm_i1 = ctx.ctx.bool_type();
let llvm_i32 = ctx.ctx.i32_type();
// Assume `true` by default, similar to reduce-like operations
let cmp_addr = generator.gen_var_alloc(ctx, llvm_i1.into(), None).unwrap();
ctx.builder.build_store(cmp_addr, llvm_i1.const_all_ones()).unwrap();
let current_bb = ctx.builder.get_insert_block().unwrap();
let post_foreach_cmp = ctx.ctx.insert_basic_block_after(current_bb, "foreach.cmp.end");
// Generate comparison between each element:
// for i in 0..min(lhs.len(), rhs.len()) {
// if !(lhs op rhs) {
// cmp = false;
// goto end;
// }
// }
let min_len = min(left_tys.len(), right_tys.len());
for i in 0..min_len {
let current_bb = ctx.builder.get_insert_block().unwrap();
let bb = ctx.ctx.insert_basic_block_after(current_bb, &format!("foreach.cmp.tuple.{i}e"));
ctx.builder.build_unconditional_branch(bb).unwrap();
ctx.builder.position_at_end(bb);
let left_ty = left_tys[i];
let left_elem = ctx.build_in_bounds_gep_and_load(
lhs.into_pointer_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(i as u64, false)],
None,
);
let right_ty = right_tys[i];
let right_elem = ctx.build_in_bounds_gep_and_load(
rhs.into_pointer_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(i as u64, false)],
None,
);
gen_if_callback(
generator,
ctx,
|generator, ctx| {
gen_cmpop_expr_with_values(
generator,
ctx,
(Some(left_ty), left_elem),
&[*op],
&[(Some(right_ty), right_elem)],
)
.transpose()
.unwrap()
.and_then(|v| v.to_basic_value_enum(ctx, generator, ctx.primitives.bool))
.map(BasicValueEnum::into_int_value)
},
|_, ctx| {
ctx.builder.build_store(cmp_addr, llvm_i1.const_zero()).unwrap();
ctx.builder.build_unconditional_branch(post_foreach_cmp).unwrap();
Ok(())
},
|_, _| Ok(()),
)?;
}
// All elements checked - Finally check length of tuples
gen_if_callback(
generator,
ctx,
|_, ctx| {
let cmp = ctx.builder.build_load(cmp_addr, "").map(BasicValueEnum::into_int_value).unwrap();
let llvm_cmpop = match op {
Cmpop::Eq => IntPredicate::EQ,
Cmpop::NotEq => IntPredicate::NE,
Cmpop::Lt => IntPredicate::ULT,
Cmpop::LtE => IntPredicate::ULE,
Cmpop::Gt => IntPredicate::UGT,
Cmpop::GtE => IntPredicate::UGE,
_ => unreachable!(),
};
Ok(ctx.builder.build_int_compare(
llvm_cmpop,
cmp,
llvm_i1.const_all_ones(),
"",
).unwrap())
},
|_, ctx| {
ctx.builder.build_store(cmp_addr, llvm_i1.const_zero()).unwrap();
Ok(())
},
|_, _| Ok(()),
)?;
ctx.builder.build_unconditional_branch(post_foreach_cmp).unwrap();
ctx.builder.position_at_end(post_foreach_cmp);
ctx.builder.build_load(cmp_addr, "").map(BasicValueEnum::into_int_value).unwrap()
} else { } else {
return Err(format!("'{}' not supported between instances of '{}' and '{}'", op.op_info().symbol, ctx.unifier.stringify(left_ty), ctx.unifier.stringify(right_ty))) return Err(format!("'{}' not supported between instances of '{}' and '{}'", op.op_info().symbol, ctx.unifier.stringify(left_ty), ctx.unifier.stringify(right_ty)))
}; };