Compare commits

..

13 Commits

Author SHA1 Message Date
David Mak 685475c6f5 core: Require matching operand sign-ness of bitwise shift operators 2023-11-06 14:20:13 +08:00
David Mak 0cc6d4b82b core: Add compile-time error and runtime assertion for negative shifts 2023-11-04 14:16:18 +08:00
David Mak c2ab6b58ff artiq: Implement `with legacy_parallel` block 2023-11-04 13:42:44 +08:00
David Mak 0a84f7ac31 Add CodeGenerator::gen_block and refactor to use it 2023-11-04 13:42:44 +08:00
David Mak fd787ca3f5 core: Remove trunc
The behavior of trunc is already implemented by casts and is therefore
redundant.
2023-11-04 13:35:53 +08:00
David Mak 4dbe07a0c0 core: Revert breaking changes to round-family functions
These functions should return ints as the math.* functions do instead of
following the convention of numpy.* functions.
2023-11-04 13:35:53 +08:00
David Mak 2e055e8ab1 core: Replace rint implementation with LLVM intrinsic 2023-11-04 13:35:53 +08:00
David Mak 9d737743c1 standalone: Add regression test for numeric primitive operations 2023-11-03 16:24:26 +08:00
David Mak c6b9aefe00 core: Fix int32-to-uint64 conversion
This conversion should be sign-extended.
2023-11-03 16:24:26 +08:00
David Mak 8ad09748d0 core: Fix conversion from float to unsigned types
These conversions also need to wraparound.
2023-11-03 16:24:26 +08:00
David Mak 7a5a2db842 core: Fix handling of float-to-int32 casts
Out-of-bound conversions should be wrapped around.
2023-11-03 16:24:26 +08:00
David Mak 447eb9c387 standalone: Fix output format string for output_uint* 2023-11-03 16:24:26 +08:00
David Mak 92d6f0a5d3 core: Implement bitwise not for unsigned ints and fix implementation 2023-11-03 16:24:26 +08:00
13 changed files with 658 additions and 178 deletions

View File

@ -26,6 +26,24 @@ use std::{
sync::Arc,
};
/// The parallelism mode within a block.
#[derive(Copy, Clone, Eq, PartialEq)]
enum ParallelMode {
/// No parallelism is currently registered for this context.
None,
/// Legacy (or shallow) parallelism. Default before NAC3.
///
/// Each statement within the `with` block is treated as statements to be executed in parallel.
Legacy,
/// Deep parallelism. Default since NAC3.
///
/// Each function call within the `with` block (except those within a nested `sequential` block)
/// are treated to be executed in parallel.
Deep
}
pub struct ArtiqCodeGenerator<'a> {
name: String,
@ -41,6 +59,12 @@ pub struct ArtiqCodeGenerator<'a> {
/// Variable for tracking the end of a `with parallel` block.
end: Option<Expr<Option<Type>>>,
timeline: &'a (dyn TimeFns + Sync),
/// The [ParallelMode] of the current parallel context.
///
/// The current parallel context refers to the nearest `with parallel` or `with legacy_parallel`
/// statement, which is used to determine when and how the timeline should be updated.
parallel_mode: ParallelMode,
}
impl<'a> ArtiqCodeGenerator<'a> {
@ -57,6 +81,7 @@ impl<'a> ArtiqCodeGenerator<'a> {
start: None,
end: None,
timeline,
parallel_mode: ParallelMode::None,
}
}
@ -141,6 +166,31 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> {
}
}
fn gen_block<'ctx, 'a, 'c, I: Iterator<Item=&'c Stmt<Option<Type>>>>(
&mut self,
ctx: &mut CodeGenContext<'ctx, 'a>,
stmts: I
) -> Result<(), String> where Self: Sized {
// Legacy parallel emits timeline end-update/timeline-reset after each top-level statement
// in the parallel block
if self.parallel_mode == ParallelMode::Legacy {
for stmt in stmts {
self.gen_stmt(ctx, stmt)?;
if ctx.is_terminated() {
break;
}
self.timeline_update_end_max(ctx, self.end.clone(), Some("end"))?;
self.timeline_reset_start(ctx)?;
}
Ok(())
} else {
gen_block(self, ctx, stmts)
}
}
fn gen_call<'ctx, 'a>(
&mut self,
ctx: &mut CodeGenContext<'ctx, 'a>,
@ -150,8 +200,11 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> {
) -> Result<Option<BasicValueEnum<'ctx>>, String> {
let result = gen_call(self, ctx, obj, fun, params)?;
self.timeline_update_end_max(ctx, self.end.clone(), Some("end"))?;
self.timeline_reset_start(ctx)?;
// Deep parallel emits timeline end-update/timeline-reset after each function call
if self.parallel_mode == ParallelMode::Deep {
self.timeline_update_end_max(ctx, self.end.clone(), Some("end"))?;
self.timeline_reset_start(ctx)?;
}
Ok(result)
}
@ -179,9 +232,10 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> {
// - If there is a end variable, it indicates that we are (indirectly) inside a
// parallel block, and we should update the max end value.
if let ExprKind::Name { id, ctx: name_ctx } = &item.context_expr.node {
if id == &"parallel".into() {
if id == &"parallel".into() || id == &"legacy_parallel".into() {
let old_start = self.start.take();
let old_end = self.end.take();
let old_parallel_mode = self.parallel_mode;
let now = if let Some(old_start) = &old_start {
self.gen_expr(ctx, old_start)?
@ -224,8 +278,13 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> {
ctx.builder.build_store(end, now);
self.end = Some(end_expr);
self.name_counter += 1;
self.parallel_mode = match id.to_string().as_str() {
"parallel" => ParallelMode::Deep,
"legacy_parallel" => ParallelMode::Legacy,
_ => unreachable!(),
};
gen_block(self, ctx, body.iter())?;
self.gen_block(ctx, body.iter())?;
let current = ctx.builder.get_insert_block().unwrap();
@ -258,8 +317,9 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> {
// inside a parallel block, should update the outer max now_mu
self.timeline_update_end_max(ctx, old_end.clone(), Some("outer.end"))?;
self.start = old_start;
self.parallel_mode = old_parallel_mode;
self.end = old_end;
self.start = old_start;
if reset_position {
ctx.builder.position_at_end(current);
@ -267,12 +327,20 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> {
return Ok(());
} else if id == &"sequential".into() {
// For deep parallel, temporarily take away start to avoid function calls in
// the block from resetting the timeline.
// This does not affect legacy parallel, as the timeline will be reset after
// this block finishes execution.
let start = self.start.take();
gen_block(self, ctx, body.iter())?;
self.gen_block(ctx, body.iter())?;
self.start = start;
// Reset the timeline when we are exiting the sequential block
self.timeline_reset_start(ctx)?;
// Legacy parallel does not need this, since it will be reset after codegen
// for this statement is completed
if self.parallel_mode == ParallelMode::Deep {
self.timeline_reset_start(ctx)?;
}
return Ok(());
}

View File

@ -1381,25 +1381,16 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator>(
}
_ => val.into(),
}
} else if [ctx.primitives.int32, ctx.primitives.int64].contains(&ty) {
} else if [ctx.primitives.int32, ctx.primitives.int64, ctx.primitives.uint32, ctx.primitives.uint64].contains(&ty) {
let val = val.into_int_value();
match op {
ast::Unaryop::USub => ctx.builder.build_int_neg(val, "neg").into(),
ast::Unaryop::Invert => ctx.builder.build_not(val, "not").into(),
ast::Unaryop::Not => ctx
.builder
.build_int_compare(
IntPredicate::EQ,
val,
val.get_type().const_zero(),
"not",
)
.into(),
ast::Unaryop::Not => ctx.builder.build_xor(val, val.get_type().const_all_ones(), "not").into(),
_ => val.into(),
}
} else if ty == ctx.primitives.float {
let val =
if let BasicValueEnum::FloatValue(val) = val { val } else { unreachable!() };
let val = val.into_float_value();
match op {
ast::Unaryop::USub => ctx.builder.build_float_neg(val, "neg").into(),
ast::Unaryop::Not => ctx

View File

@ -169,6 +169,7 @@ pub trait CodeGenerator {
}
/// Generate code for a statement
///
/// Return true if the statement must early return
fn gen_stmt<'ctx, 'a>(
&mut self,
@ -181,6 +182,18 @@ pub trait CodeGenerator {
gen_stmt(self, ctx, stmt)
}
/// Generates code for a block statement.
fn gen_block<'ctx, 'a, 'b, I: Iterator<Item = &'b Stmt<Option<Type>>>>(
&mut self,
ctx: &mut CodeGenContext<'ctx, 'a>,
stmts: I,
) -> Result<(), String>
where
Self: Sized,
{
gen_block(self, ctx, stmts)
}
/// See [bool_to_i1].
fn bool_to_i1<'ctx, 'a>(
&self,

View File

@ -1,5 +1,4 @@
use crate::{
codegen::stmt::gen_block,
symbol_resolver::{StaticValue, SymbolResolver},
toplevel::{TopLevelContext, TopLevelDef},
typecheck::{
@ -859,7 +858,7 @@ pub fn gen_func<'ctx, G: CodeGenerator>(
) -> Result<(Builder<'ctx>, Module<'ctx>, FunctionValue<'ctx>), (Builder<'ctx>, String)> {
let body = task.body.clone();
gen_func_impl(context, generator, registry, builder, module, task, |generator, ctx| {
gen_block(generator, ctx, body.iter())
generator.gen_block(ctx, body.iter())
})
}

View File

@ -322,7 +322,7 @@ pub fn gen_for<'ctx, 'a, G: CodeGenerator>(
ctx.builder.position_at_end(body_bb);
ctx.builder.build_store(target_i, ctx.builder.build_load(i, "").into_int_value());
gen_block(generator, ctx, body.iter())?;
generator.gen_block(ctx, body.iter())?;
} else {
let index_addr = generator.gen_var_alloc(ctx, size_t.into(), Some("for.index.addr"))?;
ctx.builder.build_store(index_addr, size_t.const_zero());
@ -353,7 +353,7 @@ pub fn gen_for<'ctx, 'a, G: CodeGenerator>(
let index = ctx.builder.build_load(index_addr, "for.index").into_int_value();
let val = ctx.build_gep_and_load(arr_ptr, &[index], Some("val"));
generator.gen_assign(ctx, target, val.into())?;
gen_block(generator, ctx, body.iter())?;
generator.gen_block(ctx, body.iter())?;
}
for (k, (_, _, counter)) in var_assignment.iter() {
@ -369,7 +369,7 @@ pub fn gen_for<'ctx, 'a, G: CodeGenerator>(
if !orelse.is_empty() {
ctx.builder.position_at_end(orelse_bb);
gen_block(generator, ctx, orelse.iter())?;
generator.gen_block(ctx, orelse.iter())?;
if !ctx.is_terminated() {
ctx.builder.build_unconditional_branch(cont_bb);
}
@ -423,7 +423,7 @@ pub fn gen_while<'ctx, 'a, G: CodeGenerator>(
unreachable!()
};
ctx.builder.position_at_end(body_bb);
gen_block(generator, ctx, body.iter())?;
generator.gen_block(ctx, body.iter())?;
for (k, (_, _, counter)) in var_assignment.iter() {
let (_, static_val, counter2) = ctx.var_assignment.get_mut(k).unwrap();
if counter != counter2 {
@ -435,7 +435,7 @@ pub fn gen_while<'ctx, 'a, G: CodeGenerator>(
}
if !orelse.is_empty() {
ctx.builder.position_at_end(orelse_bb);
gen_block(generator, ctx, orelse.iter())?;
generator.gen_block(ctx, orelse.iter())?;
if !ctx.is_terminated() {
ctx.builder.build_unconditional_branch(cont_bb);
}
@ -487,7 +487,7 @@ pub fn gen_if<'ctx, 'a, G: CodeGenerator>(
unreachable!()
};
ctx.builder.position_at_end(body_bb);
gen_block(generator, ctx, body.iter())?;
generator.gen_block(ctx, body.iter())?;
for (k, (_, _, counter)) in var_assignment.iter() {
let (_, static_val, counter2) = ctx.var_assignment.get_mut(k).unwrap();
if counter != counter2 {
@ -503,7 +503,7 @@ pub fn gen_if<'ctx, 'a, G: CodeGenerator>(
}
if !orelse.is_empty() {
ctx.builder.position_at_end(orelse_bb);
gen_block(generator, ctx, orelse.iter())?;
generator.gen_block(ctx, orelse.iter())?;
if !ctx.is_terminated() {
if cont_bb.is_none() {
cont_bb = Some(ctx.ctx.append_basic_block(current, "cont"));
@ -792,9 +792,9 @@ pub fn gen_try<'ctx, 'a, G: CodeGenerator>(
}
let old_clauses = ctx.outer_catch_clauses.replace((all_clauses, dispatcher, exn));
let old_unwind = ctx.unwind_target.replace(landingpad);
gen_block(generator, ctx, body.iter())?;
generator.gen_block(ctx, body.iter())?;
if ctx.builder.get_insert_block().unwrap().get_terminator().is_none() {
gen_block(generator, ctx, orelse.iter())?;
generator.gen_block(ctx, orelse.iter())?;
}
let body = ctx.builder.get_insert_block().unwrap();
// reset old_clauses and old_unwind
@ -896,7 +896,7 @@ pub fn gen_try<'ctx, 'a, G: CodeGenerator>(
ctx.var_assignment.insert(*name, (exn_store, None, 0));
ctx.builder.build_store(exn_store, exn.as_basic_value());
}
gen_block(generator, ctx, body.iter())?;
generator.gen_block(ctx, body.iter())?;
let current = ctx.builder.get_insert_block().unwrap();
// only need to call end catch if not terminated
// otherwise, we already handled in return/break/continue/raise
@ -979,7 +979,7 @@ pub fn gen_try<'ctx, 'a, G: CodeGenerator>(
// exception path
let cleanup = cleanup.unwrap();
ctx.builder.position_at_end(cleanup);
gen_block(generator, ctx, finalbody.iter())?;
generator.gen_block(ctx, finalbody.iter())?;
if !ctx.is_terminated() {
ctx.build_call_or_invoke(resume, &[], "resume");
ctx.builder.build_unreachable();
@ -991,7 +991,7 @@ pub fn gen_try<'ctx, 'a, G: CodeGenerator>(
final_targets.push(tail);
let finalizer = ctx.ctx.append_basic_block(current_fun, "try.finally");
ctx.builder.position_at_end(finalizer);
gen_block(generator, ctx, finalbody.iter())?;
generator.gen_block(ctx, finalbody.iter())?;
if !ctx.is_terminated() {
let dest = ctx.builder.build_load(final_state, "final_dest");
ctx.builder.build_indirect_branch(dest, &final_targets);

View File

@ -477,12 +477,16 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|ctx, _, fun, args, generator| {
let int32 = ctx.primitives.int32;
let int64 = ctx.primitives.int64;
let uint32 = ctx.primitives.uint32;
let uint64 = ctx.primitives.uint64;
let float = ctx.primitives.float;
let boolean = ctx.primitives.bool;
let PrimitiveStore {
int32,
int64,
uint32,
uint64,
float,
bool: boolean,
..
} = ctx.primitives;
let arg_ty = fun.0.args[0].ty;
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
Ok(if ctx.unifier.unioned(arg_ty, boolean) {
@ -512,15 +516,21 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
.into(),
)
} else if ctx.unifier.unioned(arg_ty, float) {
let val = ctx
let to_int64 = ctx
.builder
.build_float_to_signed_int(
arg.into_float_value(),
ctx.ctx.i64_type(),
"",
);
let val = ctx.builder
.build_int_truncate(
to_int64,
ctx.ctx.i32_type(),
"fptosi",
)
.into();
Some(val)
"conv",
);
Some(val.into())
} else {
unreachable!()
})
@ -542,12 +552,16 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|ctx, _, fun, args, generator| {
let int32 = ctx.primitives.int32;
let int64 = ctx.primitives.int64;
let uint32 = ctx.primitives.uint32;
let uint64 = ctx.primitives.uint64;
let float = ctx.primitives.float;
let boolean = ctx.primitives.bool;
let PrimitiveStore {
int32,
int64,
uint32,
uint64,
float,
bool: boolean,
..
} = ctx.primitives;
let arg_ty = fun.0.args[0].ty;
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
Ok(
@ -609,12 +623,16 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|ctx, _, fun, args, generator| {
let int32 = ctx.primitives.int32;
let int64 = ctx.primitives.int64;
let uint32 = ctx.primitives.uint32;
let uint64 = ctx.primitives.uint64;
let float = ctx.primitives.float;
let boolean = ctx.primitives.bool;
let PrimitiveStore {
int32,
int64,
uint32,
uint64,
float,
bool: boolean,
..
} = ctx.primitives;
let arg_ty = fun.0.args[0].ty;
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
let res = if ctx.unifier.unioned(arg_ty, boolean) {
@ -632,13 +650,34 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
.build_int_truncate(arg.into_int_value(), ctx.ctx.i32_type(), "trunc")
.into()
} else if ctx.unifier.unioned(arg_ty, float) {
ctx.builder
let llvm_i32 = ctx.ctx.i32_type();
let llvm_i64 = ctx.ctx.i64_type();
let arg = arg.into_float_value();
let arg_gez = ctx.builder
.build_float_compare(FloatPredicate::OGE, arg, arg.get_type().const_zero(), "");
let to_int32 = ctx.builder
.build_float_to_signed_int(
arg,
llvm_i32,
""
);
let to_uint64 = ctx.builder
.build_float_to_unsigned_int(
arg.into_float_value(),
ctx.ctx.i32_type(),
"ftoi",
)
.into()
arg,
llvm_i64,
""
);
let val = ctx.builder.build_select(
arg_gez,
ctx.builder.build_int_truncate(to_uint64, llvm_i32, ""),
to_int32,
"conv"
);
val.into()
} else {
unreachable!();
};
@ -661,33 +700,60 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|ctx, _, fun, args, generator| {
let int32 = ctx.primitives.int32;
let int64 = ctx.primitives.int64;
let uint32 = ctx.primitives.uint32;
let uint64 = ctx.primitives.uint64;
let float = ctx.primitives.float;
let boolean = ctx.primitives.bool;
let PrimitiveStore {
int32,
int64,
uint32,
uint64,
float,
bool: boolean,
..
} = ctx.primitives;
let arg_ty = fun.0.args[0].ty;
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
let res = if ctx.unifier.unioned(arg_ty, int32)
|| ctx.unifier.unioned(arg_ty, uint32)
let res = if ctx.unifier.unioned(arg_ty, uint32)
|| ctx.unifier.unioned(arg_ty, boolean)
{
ctx.builder
.build_int_z_extend(arg.into_int_value(), ctx.ctx.i64_type(), "zext")
.into()
} else if ctx.unifier.unioned(arg_ty, int32) {
ctx.builder
.build_int_s_extend(arg.into_int_value(), ctx.ctx.i64_type(), "sext")
.into()
} else if ctx.unifier.unioned(arg_ty, int64)
|| ctx.unifier.unioned(arg_ty, uint64)
{
arg
} else if ctx.unifier.unioned(arg_ty, float) {
ctx.builder
let llvm_i64 = ctx.ctx.i64_type();
let arg = arg.into_float_value();
let arg_gez = ctx.builder
.build_float_compare(FloatPredicate::OGE, arg, arg.get_type().const_zero(), "");
let to_int64 = ctx.builder
.build_float_to_signed_int(
arg,
llvm_i64,
""
);
let to_uint64 = ctx.builder
.build_float_to_unsigned_int(
arg.into_float_value(),
ctx.ctx.i64_type(),
"ftoi",
)
.into()
arg,
llvm_i64,
""
);
let val = ctx.builder.build_select(
arg_gez,
to_uint64,
to_int64,
"conv"
);
val.into()
} else {
unreachable!();
};
@ -710,16 +776,20 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|ctx, _, fun, args, generator| {
let int32 = ctx.primitives.int32;
let int64 = ctx.primitives.int64;
let boolean = ctx.primitives.bool;
let float = ctx.primitives.float;
let PrimitiveStore {
int32,
int64,
uint32,
uint64,
float,
bool: boolean,
..
} = ctx.primitives;
let arg_ty = fun.0.args[0].ty;
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
Ok(
if ctx.unifier.unioned(arg_ty, boolean)
|| ctx.unifier.unioned(arg_ty, int32)
|| ctx.unifier.unioned(arg_ty, int64)
if [boolean, int32, int64].iter().any(|ty| ctx.unifier.unioned(arg_ty, *ty))
{
let arg = arg.into_int_value();
let val = ctx
@ -727,6 +797,13 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
.build_signed_int_to_float(arg, ctx.ctx.f64_type(), "sitofp")
.into();
Some(val)
} else if [uint32, uint64].iter().any(|ty| ctx.unifier.unioned(arg_ty, *ty)) {
let arg = arg.into_int_value();
let val = ctx
.builder
.build_unsigned_int_to_float(arg, ctx.ctx.f64_type(), "uitofp")
.into();
Some(val)
} else if ctx.unifier.unioned(arg_ty, float) {
Some(arg)
} else {
@ -737,6 +814,66 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
)))),
loc: None,
})),
create_fn_by_codegen(
primitives,
&var_map,
"round",
int32,
&[(float, "n")],
Box::new(|ctx, _, _, args, generator| {
let llvm_f64 = ctx.ctx.f64_type();
let llvm_i32 = ctx.ctx.i32_type();
let arg = args[0].1.clone()
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?;
let intrinsic_fn = ctx.module.get_function("llvm.round.f64").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("llvm.round.f64", fn_type, None)
});
let val = ctx
.builder
.build_call(intrinsic_fn, &[arg.into()], "")
.try_as_basic_value()
.left()
.unwrap();
let val_toint = ctx.builder
.build_float_to_signed_int(val.into_float_value(), llvm_i32, "round");
Ok(Some(val_toint.into()))
}),
),
create_fn_by_codegen(
primitives,
&var_map,
"round64",
int64,
&[(float, "n")],
Box::new(|ctx, _, _, args, generator| {
let llvm_f64 = ctx.ctx.f64_type();
let llvm_i64 = ctx.ctx.i64_type();
let arg = args[0].1.clone()
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?;
let intrinsic_fn = ctx.module.get_function("llvm.round.f64").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("llvm.round.f64", fn_type, None)
});
let val = ctx
.builder
.build_call(intrinsic_fn, &[arg.into()], "")
.try_as_basic_value()
.left()
.unwrap();
let val_toint = ctx.builder
.build_float_to_signed_int(val.into_float_value(), llvm_i64, "round");
Ok(Some(val_toint.into()))
}),
),
Arc::new(RwLock::new(TopLevelDef::Function {
name: "range".into(),
simple_name: "range".into(),
@ -919,21 +1056,125 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
)))),
loc: None,
})),
create_fn_by_intrinsic(
create_fn_by_codegen(
primitives,
&var_map,
"floor",
float,
&[(float, "x")],
"llvm.floor.f64",
int32,
&[(float, "n")],
Box::new(|ctx, _, _, args, generator| {
let llvm_f64 = ctx.ctx.f64_type();
let llvm_i32 = ctx.ctx.i32_type();
let arg = args[0].1.clone()
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?;
let intrinsic_fn = ctx.module.get_function("llvm.floor.f64").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("llvm.floor.f64", fn_type, None)
});
let val = ctx
.builder
.build_call(intrinsic_fn, &[arg.into()], "")
.try_as_basic_value()
.left()
.unwrap();
let val_toint = ctx.builder
.build_float_to_signed_int(val.into_float_value(), llvm_i32, "floor");
Ok(Some(val_toint.into()))
}),
),
create_fn_by_intrinsic(
create_fn_by_codegen(
primitives,
&var_map,
"floor64",
int64,
&[(float, "n")],
Box::new(|ctx, _, _, args, generator| {
let llvm_f64 = ctx.ctx.f64_type();
let llvm_i64 = ctx.ctx.i64_type();
let arg = args[0].1.clone()
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?;
let intrinsic_fn = ctx.module.get_function("llvm.floor.f64").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("llvm.floor.f64", fn_type, None)
});
let val = ctx
.builder
.build_call(intrinsic_fn, &[arg.into()], "")
.try_as_basic_value()
.left()
.unwrap();
let val_toint = ctx.builder
.build_float_to_signed_int(val.into_float_value(), llvm_i64, "floor");
Ok(Some(val_toint.into()))
}),
),
create_fn_by_codegen(
primitives,
&var_map,
"ceil",
float,
&[(float, "x")],
"llvm.ceil.f64",
int32,
&[(float, "n")],
Box::new(|ctx, _, _, args, generator| {
let llvm_f64 = ctx.ctx.f64_type();
let llvm_i32 = ctx.ctx.i32_type();
let arg = args[0].1.clone()
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?;
let intrinsic_fn = ctx.module.get_function("llvm.ceil.f64").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("llvm.ceil.f64", fn_type, None)
});
let val = ctx
.builder
.build_call(intrinsic_fn, &[arg.into()], "")
.try_as_basic_value()
.left()
.unwrap();
let val_toint = ctx.builder
.build_float_to_signed_int(val.into_float_value(), llvm_i32, "ceil");
Ok(Some(val_toint.into()))
}),
),
create_fn_by_codegen(
primitives,
&var_map,
"ceil64",
int64,
&[(float, "n")],
Box::new(|ctx, _, _, args, generator| {
let llvm_f64 = ctx.ctx.f64_type();
let llvm_i64 = ctx.ctx.i64_type();
let arg = args[0].1.clone()
.to_basic_value_enum(ctx, generator, ctx.primitives.float)?;
let intrinsic_fn = ctx.module.get_function("llvm.ceil.f64").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("llvm.ceil.f64", fn_type, None)
});
let val = ctx
.builder
.build_call(intrinsic_fn, &[arg.into()], "")
.try_as_basic_value()
.left()
.unwrap();
let val_toint = ctx.builder
.build_float_to_signed_int(val.into_float_value(), llvm_i64, "ceil");
Ok(Some(val_toint.into()))
}),
),
Arc::new(RwLock::new({
let list_var = primitives.1.get_fresh_var(Some("L".into()), None);
@ -1284,14 +1525,6 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
&[(float, "x")],
"llvm.fabs.f64",
),
create_fn_by_intrinsic(
primitives,
&var_map,
"trunc",
float,
&[(float, "x")],
"llvm.trunc.f64",
),
create_fn_by_intrinsic(
primitives,
&var_map,
@ -1300,49 +1533,13 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
&[(float, "x")],
"llvm.sqrt.f64",
),
create_fn_by_codegen(
create_fn_by_intrinsic(
primitives,
&var_map,
"rint",
float,
&[(float, "x")],
Box::new(|ctx, _, fun, args, generator| {
let float = ctx.primitives.float;
let llvm_f64 = ctx.ctx.f64_type();
let x_ty = fun.0.args[0].ty;
let x_val = args[0].1.clone()
.to_basic_value_enum(ctx, generator, x_ty)?;
assert!(ctx.unifier.unioned(x_ty, float));
let intrinsic_fn = ctx.module.get_function("llvm.round.f64").unwrap_or_else(|| {
let fn_type = llvm_f64.fn_type(&[llvm_f64.into()], false);
ctx.module.add_function("llvm.round.f64", fn_type, None)
});
// rint(x) == round(x * 0.5) * 2.0
// %0 = fmul f64 %x, 0.5
let x_half = ctx.builder
.build_float_mul(x_val.into_float_value(), llvm_f64.const_float(0.5), "");
// %1 = call f64 @llvm.round.f64(f64 %0)
let round = ctx.builder
.build_call(
intrinsic_fn,
&vec![x_half.into()],
"",
)
.try_as_basic_value()
.left()
.unwrap();
// %2 = fmul f64 %1, 2.0
let val = ctx.builder
.build_float_mul(round.into_float_value(), llvm_f64.const_float(2.0).into(), "rint");
Ok(Some(val.into()))
}),
"llvm.roundeven.f64",
),
create_fn_by_extern(
primitives,
@ -1774,11 +1971,15 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
"uint32",
"uint64",
"float",
"round",
"round64",
"range",
"str",
"bool",
"floor",
"floor64",
"ceil",
"ceil64",
"len",
"min",
"max",
@ -1793,7 +1994,6 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
"log10",
"log2",
"fabs",
"trunc",
"sqrt",
"rint",
"tan",

View File

@ -224,10 +224,15 @@ pub fn impl_bitwise_arithmetic(unifier: &mut Unifier, store: &PrimitiveStore, ty
/// LShift, RShift
pub fn impl_bitwise_shift(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
impl_binop(unifier, store, ty, &[store.int32], ty, &[ast::Operator::LShift, ast::Operator::RShift]);
if !unifier.unioned(ty, store.int32) {
impl_binop(unifier, store, ty, &[ty], ty, &[ast::Operator::LShift, ast::Operator::RShift]);
}
let rhs_ty = if [store.int32, store.int64].contains(&ty) {
store.int32
} else if [store.uint32, store.uint64].contains(&ty) {
store.uint32
} else {
unreachable!()
};
impl_binop(unifier, store, ty, &[rhs_ty], ty, &[ast::Operator::LShift, ast::Operator::RShift]);
}
/// Div

View File

@ -34,11 +34,11 @@ void output_int64(int64_t x) {
}
void output_uint32(uint32_t x) {
printf("%d\n", x);
printf("%u\n", x);
}
void output_uint64(uint64_t x) {
printf("%ld\n", x);
printf("%lu\n", x);
}
void output_float64(double x) {

View File

@ -3,6 +3,7 @@
import sys
import importlib.util
import importlib.machinery
import math
import numpy as np
import pathlib
import scipy
@ -43,6 +44,12 @@ def Some(v: T) -> Option[T]:
none = Option(None)
def round_away_zero(x):
if x >= 0.0:
return math.floor(x + 0.5)
else:
return math.ceil(x - 0.5)
def patch(module):
def dbl_nan():
return np.nan
@ -98,6 +105,14 @@ def patch(module):
module.Some = Some
module.none = none
# Builtin Math functions
module.round = round_away_zero
module.round64 = round_away_zero
module.floor = math.floor
module.floor64 = math.floor
module.ceil = math.ceil
module.ceil64 = math.ceil
# NumPy Math functions
module.isnan = np.isnan
module.isinf = np.isinf
@ -109,8 +124,6 @@ def patch(module):
module.log10 = np.log10
module.log2 = np.log2
module.fabs = np.fabs
module.floor = np.floor
module.ceil = np.ceil
module.trunc = np.trunc
module.sqrt = np.sqrt
module.rint = np.rint

View File

@ -1,16 +0,0 @@
@extern
def output_int64(x: int64):
...
def int64_min() -> int64:
return int64(-9223372036854775808)
def int64_max() -> int64:
return int64(9223372036854775807)
def run() -> int32:
output_int64(int64(1) << 32)
output_int64(int64_min() >> 63)
output_int64(int64_max() >> 63)
return 0

View File

@ -2,6 +2,14 @@
def output_bool(x: bool):
...
@extern
def output_int32(x: int32):
...
@extern
def output_int64(x: int64):
...
@extern
def output_float64(x: float):
...
@ -20,6 +28,14 @@ def dbl_pi() -> float:
def dbl_e() -> float:
return 2.71828182845904523536028747135266249775724709369995
def test_round():
for x in [-1.5, -0.5, 0.5, 1.5]:
output_int32(round(x))
def test_round64():
for x in [-1.5, -0.5, 0.5, 1.5]:
output_int64(round64(x))
def test_isnan():
for x in [dbl_nan(), 0.0, dbl_inf()]:
output_bool(isnan(x))
@ -64,16 +80,20 @@ def test_fabs():
output_float64(fabs(x))
def test_floor():
for x in [-1.5, -0.5, 0.5, 1.5, dbl_inf(), -dbl_inf(), dbl_nan()]:
output_float64(floor(x))
for x in [-1.5, -0.5, 0.5, 1.5]:
output_int32(floor(x))
def test_floor64():
for x in [-1.5, -0.5, 0.5, 1.5]:
output_int64(floor64(x))
def test_ceil():
for x in [-1.5, -0.5, 0.5, 1.5, dbl_inf(), -dbl_inf(), dbl_nan()]:
output_float64(ceil(x))
for x in [-1.5, -0.5, 0.5, 1.5]:
output_int32(ceil(x))
def test_trunc():
for x in [-1.5, -0.5, 0.5, 1.5, dbl_inf(), -dbl_inf(), dbl_nan()]:
output_float64(trunc(x))
def test_ceil64():
for x in [-1.5, -0.5, 0.5, 1.5]:
output_int64(ceil64(x))
def test_sqrt():
for x in [1.0, 2.0, 4.0, dbl_inf(), -dbl_inf(), dbl_nan()]:
@ -192,6 +212,8 @@ def test_nextafter():
output_float64(nextafter(x1, x2))
def run() -> int32:
test_round()
test_round64()
test_isnan()
test_isinf()
test_sin()
@ -203,8 +225,9 @@ def run() -> int32:
test_log2()
test_fabs()
test_floor()
test_floor64()
test_ceil()
test_trunc()
test_ceil64()
test_sqrt()
test_rint()
test_tan()

View File

@ -0,0 +1,184 @@
@extern
def output_bool(x: bool):
...
@extern
def output_int32(x: int32):
...
@extern
def output_int64(x: int64):
...
@extern
def output_uint32(x: uint32):
...
@extern
def output_uint64(x: uint64):
...
@extern
def output_float64(x: float):
...
def u32_min() -> uint32:
return uint32(0)
def u32_max() -> uint32:
return ~uint32(0)
def i32_min() -> int32:
return int32(1 << 31)
def i32_max() -> int32:
return int32(~(1 << 31))
def u64_min() -> uint64:
return uint64(0)
def u64_max() -> uint64:
return ~uint64(0)
def i64_min() -> int64:
return int64(1) << int64(63)
def i64_max() -> int64:
return ~(int64(1) << int64(63))
def test_u32_bnot():
output_uint32(~uint32(0))
def test_u64_bnot():
output_uint64(~uint64(0))
def test_conv_from_i32():
for x in [
i32_min(),
i32_min() + 1,
-1,
0,
1,
i32_max() - 1,
i32_max()
]:
output_int64(int64(x))
output_uint32(uint32(x))
output_uint64(uint64(x))
output_float64(float(x))
def test_conv_from_u32():
for x in [
u32_min(),
u32_min() + uint32(1),
u32_max() - uint32(1),
u32_max()
]:
output_uint64(uint64(x))
output_int32(int32(x))
output_int64(int64(x))
output_float64(float(x))
def test_conv_from_i64():
for x in [
i64_min(),
i64_min() + int64(1),
int64(-1),
int64(0),
int64(1),
i64_max() - int64(1),
i64_max()
]:
output_int32(int32(x))
output_uint64(uint64(x))
output_uint32(uint32(x))
output_float64(float(x))
def test_conv_from_u64():
for x in [
u64_min(),
u64_min() + uint64(1),
u64_max() - uint64(1),
u64_max()
]:
output_uint32(uint32(x))
output_int64(int64(x))
output_int32(int32(x))
output_float64(float(x))
def test_f64toi32():
for x in [
float(i32_min()) - 1.0,
float(i32_min()),
float(i32_min()) + 1.0,
-1.5,
-0.5,
0.5,
1.5,
float(i32_max()) - 1.0,
float(i32_max()),
float(i32_max()) + 1.0
]:
output_int32(int32(x))
def test_f64toi64():
for x in [
float(i64_min()),
float(i64_min()) + 1.0,
-1.5,
-0.5,
0.5,
1.5,
# 2^53 is the highest integral power-of-two of which uint64 and float have a one-to-one correspondence
float(uint64(2) ** uint64(52)) - 1.0,
float(uint64(2) ** uint64(52)),
float(uint64(2) ** uint64(52)) + 1.0,
]:
output_int64(int64(x))
def test_f64tou32():
for x in [
-1.5,
float(u32_min()) - 1.0,
-0.5,
float(u32_min()),
0.5,
float(u32_min()) + 1.0,
1.5,
float(u32_max()) - 1.0,
float(u32_max()),
float(u32_max()) + 1.0
]:
output_uint32(uint32(x))
def test_f64tou64():
for x in [
-1.5,
float(u64_min()) - 1.0,
-0.5,
float(u64_min()),
0.5,
float(u64_min()) + 1.0,
1.5,
# 2^53 is the highest integral power-of-two of which uint64 and float have a one-to-one correspondence
float(uint64(2) ** uint64(52)) - 1.0,
float(uint64(2) ** uint64(52)),
float(uint64(2) ** uint64(52)) + 1.0,
]:
output_uint64(uint64(x))
def run() -> int32:
test_u32_bnot()
test_u64_bnot()
test_conv_from_i32()
test_conv_from_u32()
test_conv_from_i64()
test_conv_from_u64()
test_f64toi32()
test_f64toi64()
test_f64tou32()
test_f64tou64()
return 0

View File

@ -108,8 +108,8 @@ def test_int64():
output_int64(a | b)
output_int64(a ^ b)
output_int64(a & b)
output_int64(a << b)
output_int64(a >> b)
output_int64(a << int32(b))
output_int64(a >> int32(b))
output_float64(a / b)
a += b
output_int64(a)
@ -127,9 +127,9 @@ def test_int64():
output_int64(a)
a &= b
output_int64(a)
a <<= b
a <<= int32(b)
output_int64(a)
a >>= b
a >>= int32(b)
output_int64(a)
def test_uint64():
@ -143,8 +143,8 @@ def test_uint64():
output_uint64(a | b)
output_uint64(a ^ b)
output_uint64(a & b)
output_uint64(a << b)
output_uint64(a >> b)
output_uint64(a << uint32(b))
output_uint64(a >> uint32(b))
output_float64(a / b)
a += b
output_uint64(a)
@ -162,9 +162,9 @@ def test_uint64():
output_uint64(a)
a &= b
output_uint64(a)
a <<= b
a <<= uint32(b)
output_uint64(a)
a >>= b
a >>= uint32(b)
output_uint64(a)
class A: