diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 38cdd1c3a..112956704 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -925,7 +925,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { FuncArg { name: "n".into(), ty: num_ty.0, default_value: None }, ], ret: num_ty.0, - vars: var_map, + vars: var_map.clone(), })), var_id: Default::default(), instance_to_symbol: Default::default(), @@ -978,6 +978,68 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { )))), loc: None, })), + Arc::new(RwLock::new(TopLevelDef::Function { + name: "max".into(), + simple_name: "max".into(), + signature: primitives.1.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![ + FuncArg { name: "m".into(), ty: num_ty.0, default_value: None }, + FuncArg { name: "n".into(), ty: num_ty.0, default_value: None }, + ], + ret: num_ty.0, + vars: var_map, + })), + var_id: Default::default(), + instance_to_symbol: Default::default(), + instance_to_stmt: Default::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |ctx, _, fun, args, generator| { + let boolean = ctx.primitives.bool; + let int32 = ctx.primitives.int32; + let int64 = ctx.primitives.int64; + let uint32 = ctx.primitives.uint32; + let uint64 = ctx.primitives.uint64; + let float = ctx.primitives.float; + let llvm_i1 = ctx.ctx.bool_type().as_basic_type_enum(); + let llvm_i32 = ctx.ctx.i32_type().as_basic_type_enum(); + let llvm_i64 = ctx.ctx.i64_type().as_basic_type_enum(); + let llvm_f64 = ctx.ctx.f64_type().as_basic_type_enum(); + let m_ty = fun.0.args[0].ty; + let n_ty = fun.0.args[1].ty; + let m_val = args[0].1.clone().to_basic_value_enum(ctx, generator)?; + let n_val = args[1].1.clone().to_basic_value_enum(ctx, generator)?; + let mut is_type = |a: Type, b: Type| ctx.unifier.unioned(a, b); + let (fun_name, arg_ty) = if is_type(m_ty, n_ty) && is_type(n_ty, boolean) { + ("llvm.umax.i1", llvm_i1) + } else if is_type(m_ty, n_ty) && is_type(n_ty, int32) { + ("llvm.smax.i32", llvm_i32) + } else if is_type(m_ty, n_ty) && is_type(n_ty, int64) { + ("llvm.smax.i64", llvm_i64) + } else if is_type(m_ty, n_ty) && is_type(n_ty, uint32) { + ("llvm.umax.i32", llvm_i32) + } else if is_type(m_ty, n_ty) && is_type(n_ty, uint64) { + ("llvm.umax.i64", llvm_i64) + } else if is_type(m_ty, n_ty) && is_type(n_ty, float) { + ("llvm.maxnum.f64", llvm_f64) + } else { + unreachable!(); + }; + let intrinsic = ctx.module.get_function(fun_name).unwrap_or_else(|| { + let fn_type = arg_ty.fn_type(&[arg_ty.into(), arg_ty.into()], false); + ctx.module.add_function(fun_name, fn_type, None) + }); + let val = ctx + .builder + .build_call(intrinsic, &[m_val.into(), n_val.into()], "max") + .try_as_basic_value() + .left() + .unwrap(); + Ok(val.into()) + }, + )))), + loc: None, + })), ]; let ast_list: Vec>> = (0..top_level_def_list.len()).map(|_| None).collect(); @@ -1002,6 +1064,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { "ceil64", "len", "min", + "max", ], ) }