diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 3300d3c8..d5602bc7 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -19,6 +19,7 @@ use crate::{ use inkwell::{ AddressSpace, attributes::{Attribute, AttributeLoc}, + IntPredicate, types::{AnyType, BasicType, BasicTypeEnum}, values::{BasicValueEnum, FunctionValue, IntValue, PointerValue} }; @@ -924,7 +925,7 @@ pub fn gen_comprehension<'ctx, 'a, G: CodeGenerator>( let length = ctx.builder.build_int_add(length, int32.const_int(1, false), "add1"); // in case length is non-positive let is_valid = - ctx.builder.build_int_compare(inkwell::IntPredicate::SGT, length, zero_32, "check"); + ctx.builder.build_int_compare(IntPredicate::SGT, length, zero_32, "check"); let normal = ctx.ctx.append_basic_block(current, "listcomp.normal_list"); let empty = ctx.ctx.append_basic_block(current, "listcomp.empty_list"); let list_init = ctx.ctx.append_basic_block(current, "listcomp.list_init"); @@ -964,7 +965,7 @@ pub fn gen_comprehension<'ctx, 'a, G: CodeGenerator>( ctx.builder.position_at_end(test_bb); let sign = - ctx.builder.build_int_compare(inkwell::IntPredicate::SGT, step, zero_32, "sign"); + ctx.builder.build_int_compare(IntPredicate::SGT, step, zero_32, "sign"); // add and test let tmp = ctx.builder.build_int_add( ctx.builder.build_load(i, "i").into_int_value(), @@ -973,9 +974,9 @@ pub fn gen_comprehension<'ctx, 'a, G: CodeGenerator>( ); ctx.builder.build_store(i, tmp); // if step > 0, continue when i < end - let cmp1 = ctx.builder.build_int_compare(inkwell::IntPredicate::SLT, tmp, end, "cmp1"); + let cmp1 = ctx.builder.build_int_compare(IntPredicate::SLT, tmp, end, "cmp1"); // if step < 0, continue when i > end - let cmp2 = ctx.builder.build_int_compare(inkwell::IntPredicate::SGT, tmp, end, "cmp2"); + let cmp2 = ctx.builder.build_int_compare(IntPredicate::SGT, tmp, end, "cmp2"); let pos = ctx.builder.build_and(sign, cmp1, "pos"); let neg = ctx.builder.build_and(ctx.builder.build_not(sign, "inv"), cmp2, "neg"); ctx.builder.build_conditional_branch( @@ -1005,7 +1006,7 @@ pub fn gen_comprehension<'ctx, 'a, G: CodeGenerator>( let tmp = ctx.builder.build_load(counter, "i").into_int_value(); let tmp = ctx.builder.build_int_add(tmp, size_t.const_int(1, false), "inc"); ctx.builder.build_store(counter, tmp); - let cmp = ctx.builder.build_int_compare(inkwell::IntPredicate::SLT, tmp, length, "cmp"); + let cmp = ctx.builder.build_int_compare(IntPredicate::SLT, tmp, length, "cmp"); ctx.builder.build_conditional_branch(cmp, body_bb, cont_bb); ctx.builder.position_at_end(body_bb); @@ -1359,7 +1360,7 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator>( ast::Unaryop::Not => ctx .builder .build_int_compare( - inkwell::IntPredicate::EQ, + IntPredicate::EQ, val, val.get_type().const_zero(), "not", @@ -1392,9 +1393,14 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator>( .fold(Ok(None), |prev: Result, String>, (lhs, rhs, op)| { let ty = ctx.unifier.get_representative(lhs.custom.unwrap()); let current = - if [ctx.primitives.int32, ctx.primitives.int64, ctx.primitives.bool] + if [ctx.primitives.int32, ctx.primitives.int64, ctx.primitives.uint32, ctx.primitives.uint64, ctx.primitives.bool] .contains(&ty) { + let use_unsigned_ops = [ + ctx.primitives.uint32, + ctx.primitives.uint64, + ].contains(&ty); + let (lhs, rhs) = if let ( BasicValueEnum::IntValue(lhs), BasicValueEnum::IntValue(rhs), @@ -1412,15 +1418,34 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator>( } else { unreachable!() }; + let op = match op { - ast::Cmpop::Eq | ast::Cmpop::Is => inkwell::IntPredicate::EQ, - ast::Cmpop::NotEq => inkwell::IntPredicate::NE, - ast::Cmpop::Lt => inkwell::IntPredicate::SLT, - ast::Cmpop::LtE => inkwell::IntPredicate::SLE, - ast::Cmpop::Gt => inkwell::IntPredicate::SGT, - ast::Cmpop::GtE => inkwell::IntPredicate::SGE, + ast::Cmpop::Eq | ast::Cmpop::Is => IntPredicate::EQ, + ast::Cmpop::NotEq => IntPredicate::NE, + _ if ty == ctx.primitives.bool => unreachable!(), + ast::Cmpop::Lt => if use_unsigned_ops { + IntPredicate::ULT + } else { + IntPredicate::SLT + }, + ast::Cmpop::LtE => if use_unsigned_ops { + IntPredicate::ULE + } else { + IntPredicate::SLE + }, + ast::Cmpop::Gt => if use_unsigned_ops { + IntPredicate::UGT + } else { + IntPredicate::SGT + }, + ast::Cmpop::GtE => if use_unsigned_ops { + IntPredicate::UGE + } else { + IntPredicate::SGE + }, _ => unreachable!(), }; + ctx.builder.build_int_compare(op, lhs, rhs, "cmp") } else if ty == ctx.primitives.float { let (lhs, rhs) = if let ( @@ -1655,7 +1680,7 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator>( ctx.builder .build_select( ctx.builder.build_int_compare( - inkwell::IntPredicate::SLT, + IntPredicate::SLT, step, zero, "is_neg", @@ -1696,7 +1721,7 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator>( ); // handle negative index let is_negative = ctx.builder.build_int_compare( - inkwell::IntPredicate::SLT, + IntPredicate::SLT, raw_index, generator.get_size_type(ctx.ctx).const_zero(), "is_neg", @@ -1709,7 +1734,7 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator>( // unsigned less than is enough, because negative index after adjustment is // bigger than the length (for unsigned cmp) let bound_check = ctx.builder.build_int_compare( - inkwell::IntPredicate::ULT, + IntPredicate::ULT, index, len, "inbound", diff --git a/nac3standalone/demo/demo.c b/nac3standalone/demo/demo.c index 4369b704..1c107fc1 100644 --- a/nac3standalone/demo/demo.c +++ b/nac3standalone/demo/demo.c @@ -1,3 +1,4 @@ +#include #include #include #include @@ -11,6 +12,10 @@ #error "Unsupported platform - Platform is not 32-bit or 64-bit" #endif +void output_bool(const bool x) { + puts(x ? "True" : "False"); +} + void output_int32(const int32_t x) { printf("%d\n", x); } @@ -31,7 +36,7 @@ void output_float64(const double x) { printf("%f\n", x); } void output_asciiart(const int32_t x) { - const char* chars = " .,-:;i+hHM$*#@ "; + static const char *chars = " .,-:;i+hHM$*#@ "; if (x < 0) { putchar('\n'); } else { @@ -40,12 +45,12 @@ void output_asciiart(const int32_t x) { } struct cslice { - const void* data; + const void *data; usize len; }; -void output_int32_list(struct cslice* slice) { - const int32_t* data = (const int32_t*) slice->data; +void output_int32_list(struct cslice *slice) { + const int32_t *data = (const int32_t *) slice->data; putchar('['); for (usize i = 0; i < slice->len; ++i) { @@ -59,8 +64,8 @@ void output_int32_list(struct cslice* slice) { putchar('\n'); } -void output_str(struct cslice* slice) { - const char* data = (const char*) slice->data; +void output_str(struct cslice *slice) { + const char *data = (const char *) slice->data; for (usize i = 0; i < slice->len; ++i) { putchar(data[i]); diff --git a/nac3standalone/demo/interpret_demo.py b/nac3standalone/demo/interpret_demo.py index 69c0cf05..0a1007f0 100755 --- a/nac3standalone/demo/interpret_demo.py +++ b/nac3standalone/demo/interpret_demo.py @@ -58,6 +58,7 @@ def patch(module): elif name == "output_float64": return output_float elif name in { + "output_bool", "output_int32", "output_int64", "output_int32_list", diff --git a/nac3standalone/demo/src/demo_test.py b/nac3standalone/demo/src/demo_test.py index f75f8690..acc89bbd 100644 --- a/nac3standalone/demo/src/demo_test.py +++ b/nac3standalone/demo/src/demo_test.py @@ -1,3 +1,7 @@ +@extern +def output_bool(x: bool): + ... + @extern def output_int32(x: int32): ... @@ -30,6 +34,10 @@ def output_asciiart(x: int32): def output_str(x: str): ... +def test_output_bool(): + output_bool(True) + output_bool(False) + def test_output_int32(): output_int32(-128) @@ -63,6 +71,7 @@ def test_output_str_family(): output_str("hello world") def run() -> int32: + test_output_bool() test_output_int32() test_output_int64() test_output_uint32()