diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index e13cb958..32230729 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -41,7 +41,8 @@ impl TopLevelComposer { let int32 = primitives.0.int32; let int64 = primitives.0.int64; let float = primitives.0.float; - let num_ty = primitives.1.get_fresh_var_with_range(&[int32, int64, float]); + let boolean = primitives.0.bool; + let num_ty = primitives.1.get_fresh_var_with_range(&[int32, int64, float, boolean]); let var_map: HashMap<_, _> = vec![(num_ty.1, num_ty.0)].into_iter().collect(); let mut definition_ast_list = { @@ -83,9 +84,20 @@ impl TopLevelComposer { let int32 = ctx.primitives.int32; let int64 = ctx.primitives.int64; let float = ctx.primitives.float; + let boolean = ctx.primitives.bool; let arg_ty = fun.0.args[0].ty; let arg = args[0].1; - if ctx.unifier.unioned(arg_ty, int32) { + if ctx.unifier.unioned(arg_ty, boolean) { + Some( + ctx.builder + .build_int_s_extend( + arg.into_int_value(), + ctx.ctx.i32_type(), + "sext", + ) + .into(), + ) + } else if ctx.unifier.unioned(arg_ty, int32) { Some(arg) } else if ctx.unifier.unioned(arg_ty, int64) { Some( @@ -130,9 +142,12 @@ impl TopLevelComposer { let int32 = ctx.primitives.int32; let int64 = ctx.primitives.int64; let float = ctx.primitives.float; + let boolean = ctx.primitives.bool; let arg_ty = fun.0.args[0].ty; let arg = args[0].1; - if ctx.unifier.unioned(arg_ty, int32) { + if ctx.unifier.unioned(arg_ty, boolean) + || ctx.unifier.unioned(arg_ty, int32) + { Some( ctx.builder .build_int_s_extend( @@ -176,10 +191,12 @@ impl TopLevelComposer { |ctx, _, fun, args| { let int32 = ctx.primitives.int32; let int64 = ctx.primitives.int64; + let boolean = ctx.primitives.bool; let float = ctx.primitives.float; let arg_ty = fun.0.args[0].ty; let arg = args[0].1; - if ctx.unifier.unioned(arg_ty, int32) + if ctx.unifier.unioned(arg_ty, boolean) + || ctx.unifier.unioned(arg_ty, int32) || ctx.unifier.unioned(arg_ty, int64) { let arg = args[0].1.into_int_value(); @@ -489,11 +506,7 @@ impl TopLevelComposer { self.definition_ast_list.push((def, Some(ast))); } - let result_ty = if contains_constructor { - Some(constructor_ty) - } else { - None - }; + let result_ty = if contains_constructor { Some(constructor_ty) } else { None }; Ok((class_name, DefinitionId(class_def_id), result_ty)) } @@ -904,11 +917,12 @@ impl TopLevelComposer { .node .annotation .as_ref() - .ok_or_else(|| format!( - "function parameter `{}` at {} need type annotation", - x.node.arg, - x.location - ))? + .ok_or_else(|| { + format!( + "function parameter `{}` at {} need type annotation", + x.node.arg, x.location + ) + })? .as_ref(); let type_annotation = parse_ast_to_type_annotation_kinds( @@ -1110,11 +1124,12 @@ impl TopLevelComposer { .node .annotation .as_ref() - .ok_or_else(|| format!( - "type annotation for `{}` at {} needed", - x.node.arg, - x.location - ))? + .ok_or_else(|| { + format!( + "type annotation for `{}` at {} needed", + x.node.arg, x.location + ) + })? .as_ref(); parse_ast_to_type_annotation_kinds( class_resolver.as_ref(), @@ -1629,7 +1644,9 @@ impl TopLevelComposer { let ret_str = self.unifier.stringify( inst_ret, &mut |id| { - if let TopLevelDef::Class { name, .. } = &*def_ast_list[id].0.read() { + if let TopLevelDef::Class { name, .. } = + &*def_ast_list[id].0.read() + { name.to_string() } else { unreachable!("must be class id here") @@ -1639,7 +1656,9 @@ impl TopLevelComposer { ); return Err(format!( "expected return type of `{}` in function `{}` at {}", - ret_str, name, ast.as_ref().unwrap().location + ret_str, + name, + ast.as_ref().unwrap().location )); }