diff --git a/nac3core/Cargo.toml b/nac3core/Cargo.toml index a25763bd..a3803487 100644 --- a/nac3core/Cargo.toml +++ b/nac3core/Cargo.toml @@ -14,8 +14,8 @@ indexmap = "2.2" parking_lot = "0.12" rayon = "1.8" nac3parser = { path = "../nac3parser" } -strum = "0.26.2" -strum_macros = "0.26.4" +strum = "0.26" +strum_macros = "0.26" [dependencies.inkwell] version = "0.4" diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 0817fce2..50d70eb7 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -9,7 +9,7 @@ use crate::{ irrt::*, llvm_intrinsics::{ call_expect, call_float_floor, call_float_pow, call_float_powi, call_int_smax, - call_memcpy_generic, + call_int_umin, call_memcpy_generic, }, need_sret, numpy, stmt::{ @@ -40,6 +40,7 @@ use nac3parser::ast::{ self, Boolop, Cmpop, Comprehension, Constant, Expr, ExprKind, Location, Operator, StrRef, Unaryop, }; +use std::cmp::min; use std::iter::{repeat, repeat_with}; use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip}; @@ -2024,6 +2025,115 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( _ => unreachable!(), }; ctx.builder.build_float_compare(op, lhs, rhs, "cmp").unwrap() + } else if left_ty == ctx.primitives.str { + assert!(ctx.unifier.unioned(left_ty, right_ty)); + + let llvm_i1 = ctx.ctx.bool_type(); + let llvm_i32 = ctx.ctx.i32_type(); + let llvm_usize = generator.get_size_type(ctx.ctx); + + let lhs = lhs.into_struct_value(); + let rhs = rhs.into_struct_value(); + + let plhs = generator.gen_var_alloc(ctx, lhs.get_type().into(), None).unwrap(); + ctx.builder.build_store(plhs, lhs).unwrap(); + let prhs = generator.gen_var_alloc(ctx, lhs.get_type().into(), None).unwrap(); + ctx.builder.build_store(prhs, rhs).unwrap(); + + let lhs_len = ctx.build_in_bounds_gep_and_load( + plhs, + &[llvm_i32.const_zero(), llvm_i32.const_int(1, false)], + None, + ).into_int_value(); + let rhs_len = ctx.build_in_bounds_gep_and_load( + prhs, + &[llvm_i32.const_zero(), llvm_i32.const_int(1, false)], + None, + ).into_int_value(); + + let len = call_int_umin(ctx, lhs_len, rhs_len, None); + + let current_bb = ctx.builder.get_insert_block().unwrap(); + let post_foreach_cmp = ctx.ctx.insert_basic_block_after(current_bb, "foreach.cmp.end"); + + ctx.builder.position_at_end(post_foreach_cmp); + let cmp_phi = ctx.builder.build_phi(llvm_i1, "").unwrap(); + ctx.builder.position_at_end(current_bb); + + gen_for_callback_incrementing( + generator, + ctx, + None, + llvm_usize.const_zero(), + (len, false), + |generator, ctx, _, i| { + let lhs_char = { + let plhs_data = ctx.build_in_bounds_gep_and_load( + plhs, + &[llvm_i32.const_zero(), llvm_i32.const_zero()], + None, + ).into_pointer_value(); + + ctx.build_in_bounds_gep_and_load( + plhs_data, + &[i], + None + ).into_int_value() + }; + let rhs_char = { + let prhs_data = ctx.build_in_bounds_gep_and_load( + prhs, + &[llvm_i32.const_zero(), llvm_i32.const_zero()], + None, + ).into_pointer_value(); + + ctx.build_in_bounds_gep_and_load( + prhs_data, + &[i], + None + ).into_int_value() + }; + + gen_if_callback( + generator, + ctx, + |_, ctx| { + Ok(ctx.builder.build_int_compare(IntPredicate::NE, lhs_char, rhs_char, "").unwrap()) + }, + |_, ctx| { + let bb = ctx.builder.get_insert_block().unwrap(); + cmp_phi.add_incoming(&[(&llvm_i1.const_zero(), bb)]); + ctx.builder.build_unconditional_branch(post_foreach_cmp).unwrap(); + + Ok(()) + }, + |_, _| Ok(()), + )?; + + Ok(()) + }, + llvm_usize.const_int(1, false), + )?; + + let bb = ctx.builder.get_insert_block().unwrap(); + let is_len_eq = ctx.builder.build_int_compare( + IntPredicate::EQ, + lhs_len, + rhs_len, + "", + ).unwrap(); + cmp_phi.add_incoming(&[(&is_len_eq, bb)]); + ctx.builder.build_unconditional_branch(post_foreach_cmp).unwrap(); + + ctx.builder.position_at_end(post_foreach_cmp); + let cmp_phi = cmp_phi.as_basic_value().into_int_value(); + + // Invert the final value if __ne__ + if *op == Cmpop::NotEq { + ctx.builder.build_not(cmp_phi, "").unwrap() + } else { + cmp_phi + } } else if [left_ty, right_ty] .iter() .any(|ty| ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id())) @@ -2194,8 +2304,121 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( }; gen_list_cmpop(generator, ctx)? + } else if [left_ty, right_ty].iter().any(|ty| matches!(&*ctx.unifier.get_ty_immutable(*ty), TypeEnum::TTuple { .. })) { + let TypeEnum::TTuple { ty: left_tys, .. } = &*ctx.unifier.get_ty_immutable(left_ty) else { + return Err(format!("'{}' not supported between instances of '{}' and '{}'", op.op_info().symbol, ctx.unifier.stringify(left_ty), ctx.unifier.stringify(right_ty))) + }; + let TypeEnum::TTuple { ty: right_tys, .. } = &*ctx.unifier.get_ty_immutable(right_ty) else { + return Err(format!("'{}' not supported between instances of '{}' and '{}'", op.op_info().symbol, ctx.unifier.stringify(left_ty), ctx.unifier.stringify(right_ty))) + }; + + if ![Cmpop::Eq, Cmpop::NotEq].contains(op) { + todo!("Only __eq__ and __ne__ is implemented for tuples") + } + + let llvm_i1 = ctx.ctx.bool_type(); + let llvm_i32 = ctx.ctx.i32_type(); + + // Assume `true` by default + let cmp_addr = generator.gen_var_alloc(ctx, llvm_i1.into(), None).unwrap(); + ctx.builder.build_store(cmp_addr, llvm_i1.const_all_ones()).unwrap(); + + let current_bb = ctx.builder.get_insert_block().unwrap(); + let post_foreach_cmp = ctx.ctx.insert_basic_block_after(current_bb, "foreach.cmp.end"); + + ctx.builder.position_at_end(post_foreach_cmp); + let cmp_phi = ctx.builder.build_phi(llvm_i1, "").unwrap(); + ctx.builder.position_at_end(current_bb); + + // Generate comparison between each element + let min_len = min(left_tys.len(), right_tys.len()); + for i in 0..min_len { + let current_bb = ctx.builder.get_insert_block().unwrap(); + let bb = ctx.ctx.insert_basic_block_after(current_bb, &format!("foreach.cmp.tuple.{i}e")); + ctx.builder.build_unconditional_branch(bb).unwrap(); + + ctx.builder.position_at_end(bb); + let left_ty = left_tys[i]; + let left_elem = { + let plhs = generator.gen_var_alloc(ctx, lhs.get_type(), None).unwrap(); + ctx.builder.build_store(plhs, *lhs).unwrap(); + + ctx.build_in_bounds_gep_and_load( + plhs, + &[llvm_i32.const_zero(), llvm_i32.const_int(i as u64, false)], + None, + ) + }; + let right_ty = right_tys[i]; + let right_elem = { + let prhs = generator.gen_var_alloc(ctx, rhs.get_type(), None).unwrap(); + ctx.builder.build_store(prhs, *rhs).unwrap(); + + ctx.build_in_bounds_gep_and_load( + prhs, + &[llvm_i32.const_zero(), llvm_i32.const_int(i as u64, false)], + None, + ) + }; + + gen_if_callback( + generator, + ctx, + |generator, ctx| { + // Defer the `not` operation until the end - a != b <=> !(a == b) + let op = if *op == Cmpop::NotEq { Cmpop::Eq } else { *op }; + + let cmp = gen_cmpop_expr_with_values( + generator, + ctx, + (Some(left_ty), left_elem), + &[op], + &[(Some(right_ty), right_elem)], + ) + .transpose() + .unwrap() + .and_then(|v| { + v.to_basic_value_enum(ctx, generator, ctx.primitives.bool) + }) + .map(BasicValueEnum::into_int_value)?; + + Ok(ctx.builder.build_not(cmp, "").unwrap()) + }, + |_, ctx| { + let bb = ctx.builder.get_insert_block().unwrap(); + cmp_phi.add_incoming(&[(&llvm_i1.const_zero(), bb)]); + ctx.builder.build_unconditional_branch(post_foreach_cmp).unwrap(); + + Ok(()) + }, + |_, _| Ok(()), + )?; + } + + // Length of tuples is checked last as operators do not short-circuit by tuple + // length in Python: + // + // >>> (1, 2) < ("a",) + // TypeError: '<' not supported between instances of 'int' and 'str' + let bb = ctx.builder.get_insert_block().unwrap(); + let is_len_eq = llvm_i1.const_int( + u64::from(left_tys.len() == right_tys.len()), + false, + ); + cmp_phi.add_incoming(&[(&is_len_eq, bb)]); + ctx.builder.build_unconditional_branch(post_foreach_cmp).unwrap(); + + ctx.builder.position_at_end(post_foreach_cmp); + let cmp_phi = cmp_phi.as_basic_value().into_int_value(); + + // Invert the final value if __ne__ + if *op == Cmpop::NotEq { + ctx.builder.build_not(cmp_phi, "").unwrap() + } else { + cmp_phi + } } else if [left_ty, right_ty].iter().any(|ty| matches!(&*ctx.unifier.get_ty_immutable(*ty), TypeEnum::TVar { .. })) { - if ctx.registry.llvm_options.opt_level != OptimizationLevel::None { + if ctx.registry.llvm_options.opt_level == OptimizationLevel::None { ctx.make_assert( generator, ctx.ctx.bool_type().const_all_ones(), @@ -2208,7 +2431,10 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( ctx.ctx.bool_type().get_poison() } else { - unimplemented!() + return Err(format!("'{}' not supported between instances of '{}' and '{}'", + op.op_info().symbol, + ctx.unifier.stringify(left_ty), + ctx.unifier.stringify(right_ty))) }; Ok(prev?.map(|v| ctx.builder.build_and(v, current, "cmp").unwrap()).or(Some(current))) diff --git a/nac3core/src/codegen/generator.rs b/nac3core/src/codegen/generator.rs index e0e89c30..6a6b46dd 100644 --- a/nac3core/src/codegen/generator.rs +++ b/nac3core/src/codegen/generator.rs @@ -59,7 +59,7 @@ pub trait CodeGenerator { /// function is a class method. /// /// Note that this function should check if the function is generated in another thread (due to - /// possible race condition), see the default implementation for an example. + /// possible race condition), see the default implementation for an example. fn gen_func_instance<'ctx>( &mut self, ctx: &mut CodeGenContext<'ctx, '_>, diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index 53b3aba6..91e62e94 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -568,7 +568,8 @@ pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> Flo /// /// * `dims` - An [`ArrayLikeIndexer`] containing the size of each dimension. /// * `range` - The dimension index to begin and end (exclusively) calculating the dimensions for, -/// or [`None`] if starting from the first dimension and ending at the last dimension respectively. +/// or [`None`] if starting from the first dimension and ending at the last dimension +/// respectively. pub fn call_ndarray_calc_size<'ctx, G, Dims>( generator: &G, ctx: &CodeGenContext<'ctx, '_>, diff --git a/nac3core/src/codegen/llvm_intrinsics.rs b/nac3core/src/codegen/llvm_intrinsics.rs index ade0f917..c4a0d430 100644 --- a/nac3core/src/codegen/llvm_intrinsics.rs +++ b/nac3core/src/codegen/llvm_intrinsics.rs @@ -206,7 +206,8 @@ pub fn call_memcpy_generic<'ctx>( /// * `$name:ident`: Optional name to be assigned to the llvm build call (Option<&str>) /// * `$llvm_name:literal`: Name of underlying llvm intrinsic function /// * `$map_fn:ident`: Mapping function to be applied on `BasicValue` (`BasicValue` -> Function Return Type). -/// Use `BasicValueEnum::into_int_value` for Integer return type and `BasicValueEnum::into_float_value` for Float return type +/// Use `BasicValueEnum::into_int_value` for Integer return type and +/// `BasicValueEnum::into_float_value` for Float return type /// * `$llvm_ty:ident`: Type of first operand /// * `,($val:ident)*`: Comma separated list of operands macro_rules! generate_llvm_intrinsic_fn_body { diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index f46b50d9..71a2d52a 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -580,11 +580,11 @@ fn get_llvm_abi_type<'ctx, G: CodeGenerator + ?Sized>( ) -> BasicTypeEnum<'ctx> { // If the type is used in the definition of a function, return `i1` instead of `i8` for ABI // consistency. - return if unifier.unioned(ty, primitives.bool) { + if unifier.unioned(ty, primitives.bool) { ctx.bool_type().into() } else { get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, ty) - }; + } } /// Whether `sret` is needed for a return value with type `ty`. diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 6c526928..d58b566b 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -2144,7 +2144,8 @@ pub fn ndarray_transpose<'ctx, G: CodeGenerator + ?Sized>( /// 1. A list of `int32`; e.g., `np.reshape(arr, [600, -1, 3])` /// 2. A tuple of `int32`; e.g., `np.reshape(arr, (-1, 800, 3))` /// 3. A scalar `int32`; e.g., `np.reshape(arr, 3)` -/// Note that unlike other generating functions, one of the dimesions in the shape can be negative +/// +/// Note that unlike other generating functions, one of the dimensions in the shape can be negative. pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, diff --git a/nac3core/src/typecheck/magic_methods.rs b/nac3core/src/typecheck/magic_methods.rs index b5a0608c..325f837a 100644 --- a/nac3core/src/typecheck/magic_methods.rs +++ b/nac3core/src/typecheck/magic_methods.rs @@ -680,6 +680,7 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie bool: bool_t, uint32: uint32_t, uint64: uint64_t, + str: str_t, list: list_t, ndarray: ndarray_t, .. @@ -725,6 +726,9 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie impl_sign(unifier, store, bool_t, Some(int32_t)); impl_eq(unifier, store, bool_t, &[bool_t, ndarray_bool_t], None); + /* str ========= */ + impl_cmpop(unifier, store, str_t, &[str_t], &[Cmpop::Eq, Cmpop::NotEq], Some(bool_t)); + /* list ======== */ impl_binop(unifier, store, list_t, &[list_t], Some(list_t), &[Operator::Add]); impl_binop(unifier, store, list_t, &[int32_t, int64_t], Some(list_t), &[Operator::Mult]); diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index 99a282f2..93ccd9fb 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -1,3 +1,11 @@ +use super::magic_methods::{Binop, HasOpInfo}; +use super::type_error::{TypeError, TypeErrorKind}; +use super::unification_table::{UnificationKey, UnificationTable}; +use crate::symbol_resolver::SymbolValue; +use crate::toplevel::helper::PrimDef; +use crate::toplevel::{DefinitionId, TopLevelContext, TopLevelDef}; +use crate::typecheck::magic_methods::OpInfo; +use crate::typecheck::type_inferencer::PrimitiveStore; use indexmap::IndexMap; use itertools::{repeat_n, Itertools}; use nac3parser::ast::{Cmpop, Location, StrRef, Unaryop}; @@ -9,15 +17,6 @@ use std::rc::Rc; use std::sync::{Arc, Mutex}; use std::{borrow::Cow, collections::HashSet}; -use super::magic_methods::Binop; -use super::type_error::{TypeError, TypeErrorKind}; -use super::unification_table::{UnificationKey, UnificationTable}; -use crate::symbol_resolver::SymbolValue; -use crate::toplevel::helper::PrimDef; -use crate::toplevel::{DefinitionId, TopLevelContext, TopLevelDef}; -use crate::typecheck::magic_methods::OpInfo; -use crate::typecheck::type_inferencer::PrimitiveStore; - #[cfg(test)] mod test; @@ -1008,8 +1007,18 @@ impl Unifier { self.unify_impl(v.ty, ty[ind as usize], false) .map_err(|e| e.at(v.loc))?; } - RecordKey::Str(_) => { - return Err(TypeError::new(TypeErrorKind::NoSuchField(*k, b), v.loc)) + RecordKey::Str(s) => { + let tuple_fns = [ + Cmpop::Eq.op_info().method_name, + Cmpop::NotEq.op_info().method_name, + ]; + + if !tuple_fns.into_iter().any(|op| s.to_string() == op) { + return Err(TypeError::new( + TypeErrorKind::NoSuchField(*k, b), + v.loc, + )); + } } } } diff --git a/nac3standalone/demo/src/str.py b/nac3standalone/demo/src/str.py new file mode 100644 index 00000000..d28f05e2 --- /dev/null +++ b/nac3standalone/demo/src/str.py @@ -0,0 +1,30 @@ +@extern +def output_bool(x: bool): + ... + + +def str_eq(): + output_bool("" == "") + output_bool("a" == "") + output_bool("a" == "b") + output_bool("b" == "a") + output_bool("a" == "a") + output_bool("test string" == "test string") + output_bool("test string1" == "test string2") + + +def str_ne(): + output_bool("" != "") + output_bool("a" != "") + output_bool("a" != "b") + output_bool("b" != "a") + output_bool("a" != "a") + output_bool("test string" != "test string") + output_bool("test string1" != "test string2") + + +def run() -> int32: + str_eq() + str_ne() + + return 0 diff --git a/nac3standalone/demo/src/tuple.py b/nac3standalone/demo/src/tuple.py index 6947171d..988739f7 100644 --- a/nac3standalone/demo/src/tuple.py +++ b/nac3standalone/demo/src/tuple.py @@ -1,3 +1,7 @@ +@extern +def output_bool(b: bool): + ... + @extern def output_int32_list(x: list[int32]): ... @@ -13,6 +17,41 @@ class A: self.a = a self.b = b + +def test_tuple_eq(): + # 0-len + output_bool(() == ()) + # 1-len + output_bool((1,) == ()) + output_bool(() == (1,)) + output_bool((1,) == (1,)) + output_bool((1,) == (2,)) + # # 2-len + output_bool((1, 2) == ()) + output_bool(() == (1, 2)) + output_bool((1,) == (1, 2)) + output_bool((1, 2) == (1,)) + output_bool((2, 2) == (1, 2)) + output_bool((1, 2) == (2, 2)) + + +def test_tuple_ne(): + # 0-len + output_bool(() != ()) + # 1-len + output_bool((1,) != ()) + output_bool(() != (1,)) + output_bool((1,) != (1,)) + output_bool((1,) != (2,)) + # 2-len + output_bool((1, 2) != ()) + output_bool(() != (1, 2)) + output_bool((1,) != (1, 2)) + output_bool((1, 2) != (1,)) + output_bool((2, 2) != (1, 2)) + output_bool((1, 2) != (2, 2)) + + def run() -> int32: data = [0, 1, 2, 3] @@ -33,4 +72,7 @@ def run() -> int32: output_int32(len((1, 2, 3, 4))) output_int32(len((1, 2, 3, 4, 5))) + test_tuple_eq() + test_tuple_ne() + return 0 \ No newline at end of file