diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 954367c8..f2097dd0 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -14,10 +14,7 @@ use strum::IntoEnumIterator; use crate::{ codegen::{ builtin_fns, - classes::{ - ArrayLikeValue, NDArrayValue, ProxyType, ProxyValue, RangeType, RangeValue, - TypedArrayLikeAccessor, - }, + classes::{ArrayLikeValue, NDArrayValue, ProxyValue, RangeValue, TypedArrayLikeAccessor}, expr::destructure_range, irrt::*, numpy::*, @@ -460,9 +457,10 @@ impl<'a> BuiltinBuilder<'a> { | PrimDef::Float | PrimDef::Bool | PrimDef::Str - | PrimDef::Range | PrimDef::None => Self::build_simple_primitive_class(prim), + PrimDef::Range | PrimDef::FunRangeInit => self.build_range_class_related(prim), + PrimDef::Exception => self.build_exception_class_related(prim), PrimDef::Option @@ -494,7 +492,6 @@ impl<'a> BuiltinBuilder<'a> { | PrimDef::FunNpEye | PrimDef::FunNpIdentity => self.build_ndarray_other_factory_function(prim), - PrimDef::FunRange => self.build_range_function(), PrimDef::FunStr => self.build_str_function(), PrimDef::FunFloor | PrimDef::FunFloor64 | PrimDef::FunCeil | PrimDef::FunCeil64 => { @@ -599,7 +596,6 @@ impl<'a> BuiltinBuilder<'a> { PrimDef::Float, PrimDef::Bool, PrimDef::Str, - PrimDef::Range, PrimDef::None, ], ); @@ -607,6 +603,165 @@ impl<'a> BuiltinBuilder<'a> { TopLevelComposer::make_top_level_class_def(prim.id(), None, prim.name().into(), None, None) } + fn build_range_class_related(&mut self, prim: PrimDef) -> TopLevelDef { + debug_assert_prim_is_allowed(prim, &[PrimDef::Range, PrimDef::FunRangeInit]); + + let PrimitiveStore { int32, range, .. } = *self.primitives; + + let make_ctor_signature = |unifier: &mut Unifier| { + unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![ + FuncArg { name: "start".into(), ty: int32, default_value: None }, + FuncArg { + name: "stop".into(), + ty: int32, + // placeholder + default_value: Some(SymbolValue::I32(0)), + }, + FuncArg { + name: "step".into(), + ty: int32, + default_value: Some(SymbolValue::I32(1)), + }, + ], + ret: range, + vars: VarMap::default(), + })) + }; + + match prim { + PrimDef::Range => { + let fields = vec![ + ("start".into(), int32, true), + ("stop".into(), int32, true), + ("step".into(), int32, true), + ]; + let ctor_signature = make_ctor_signature(self.unifier); + + TopLevelDef::Class { + name: prim.name().into(), + object_id: prim.id(), + type_vars: Vec::default(), + fields, + attributes: Vec::default(), + methods: vec![("__init__".into(), ctor_signature, PrimDef::FunRangeInit.id())], + ancestors: Vec::default(), + constructor: Some(ctor_signature), + resolver: None, + loc: None, + } + } + + PrimDef::FunRangeInit => TopLevelDef::Function { + name: prim.name().into(), + simple_name: prim.simple_name().into(), + signature: make_ctor_signature(self.unifier), + var_id: Vec::default(), + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |ctx, obj, _, args, generator| { + let (zelf_ty, zelf) = obj.unwrap(); + let zelf = + zelf.to_basic_value_enum(ctx, generator, zelf_ty)?.into_pointer_value(); + let zelf = RangeValue::from_ptr_val(zelf, Some("range")); + + let mut start = None; + let mut stop = None; + let mut step = None; + let int32 = ctx.ctx.i32_type(); + let ty_i32 = ctx.primitives.int32; + for (i, arg) in args.iter().enumerate() { + if arg.0 == Some("start".into()) { + start = Some( + arg.1 + .clone() + .to_basic_value_enum(ctx, generator, ty_i32)? + .into_int_value(), + ); + } else if arg.0 == Some("stop".into()) { + stop = Some( + arg.1 + .clone() + .to_basic_value_enum(ctx, generator, ty_i32)? + .into_int_value(), + ); + } else if arg.0 == Some("step".into()) { + step = Some( + arg.1 + .clone() + .to_basic_value_enum(ctx, generator, ty_i32)? + .into_int_value(), + ); + } else if i == 0 { + start = Some( + arg.1 + .clone() + .to_basic_value_enum(ctx, generator, ty_i32)? + .into_int_value(), + ); + } else if i == 1 { + stop = Some( + arg.1 + .clone() + .to_basic_value_enum(ctx, generator, ty_i32)? + .into_int_value(), + ); + } else if i == 2 { + step = Some( + arg.1 + .clone() + .to_basic_value_enum(ctx, generator, ty_i32)? + .into_int_value(), + ); + } + } + let step = match step { + Some(step) => { + // assert step != 0, throw exception if not + let not_zero = ctx + .builder + .build_int_compare( + IntPredicate::NE, + step, + step.get_type().const_zero(), + "range_step_ne", + ) + .unwrap(); + ctx.make_assert( + generator, + not_zero, + "0:ValueError", + "range() step must not be zero", + [None, None, None], + ctx.current_loc, + ); + step + } + None => int32.const_int(1, false), + }; + let stop = stop.unwrap_or_else(|| { + let v = start.unwrap(); + start = None; + v + }); + let start = start.unwrap_or_else(|| int32.const_zero()); + + zelf.store_start(ctx, start); + zelf.store_end(ctx, stop); + zelf.store_step(ctx, step); + + Ok(Some(zelf.as_base_value().into())) + }, + )))), + loc: None, + }, + + _ => unreachable!(), + } + } + /// Build the class `Exception` and its associated methods. fn build_exception_class_related(&self, prim: PrimDef) -> TopLevelDef { // NOTE: currently only contains the class `Exception` @@ -1170,131 +1325,6 @@ impl<'a> BuiltinBuilder<'a> { } } - /// Build the `range()` function. - fn build_range_function(&mut self) -> TopLevelDef { - let prim = PrimDef::FunRange; - - let PrimitiveStore { int32, range, .. } = *self.primitives; - - TopLevelDef::Function { - name: prim.name().into(), - simple_name: prim.simple_name().into(), - signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: vec![ - FuncArg { name: "start".into(), ty: int32, default_value: None }, - FuncArg { - name: "stop".into(), - ty: int32, - // placeholder - default_value: Some(SymbolValue::I32(0)), - }, - FuncArg { - name: "step".into(), - ty: int32, - default_value: Some(SymbolValue::I32(1)), - }, - ], - ret: range, - vars: VarMap::default(), - })), - var_id: Vec::default(), - instance_to_symbol: HashMap::default(), - instance_to_stmt: HashMap::default(), - resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new( - |ctx, _, _, args, generator| { - let mut start = None; - let mut stop = None; - let mut step = None; - let int32 = ctx.ctx.i32_type(); - let ty_i32 = ctx.primitives.int32; - for (i, arg) in args.iter().enumerate() { - if arg.0 == Some("start".into()) { - start = Some( - arg.1 - .clone() - .to_basic_value_enum(ctx, generator, ty_i32)? - .into_int_value(), - ); - } else if arg.0 == Some("stop".into()) { - stop = Some( - arg.1 - .clone() - .to_basic_value_enum(ctx, generator, ty_i32)? - .into_int_value(), - ); - } else if arg.0 == Some("step".into()) { - step = Some( - arg.1 - .clone() - .to_basic_value_enum(ctx, generator, ty_i32)? - .into_int_value(), - ); - } else if i == 0 { - start = Some( - arg.1 - .clone() - .to_basic_value_enum(ctx, generator, ty_i32)? - .into_int_value(), - ); - } else if i == 1 { - stop = Some( - arg.1 - .clone() - .to_basic_value_enum(ctx, generator, ty_i32)? - .into_int_value(), - ); - } else if i == 2 { - step = Some( - arg.1 - .clone() - .to_basic_value_enum(ctx, generator, ty_i32)? - .into_int_value(), - ); - } - } - let step = match step { - Some(step) => { - // assert step != 0, throw exception if not - let not_zero = ctx - .builder - .build_int_compare( - IntPredicate::NE, - step, - step.get_type().const_zero(), - "range_step_ne", - ) - .unwrap(); - ctx.make_assert( - generator, - not_zero, - "0:ValueError", - "range() step must not be zero", - [None, None, None], - ctx.current_loc, - ); - step - } - None => int32.const_int(1, false), - }; - let stop = stop.unwrap_or_else(|| { - let v = start.unwrap(); - start = None; - v - }); - let start = start.unwrap_or_else(|| int32.const_zero()); - - let ptr = RangeType::new(ctx.ctx).new_value(generator, ctx, Some("range")); - ptr.store_start(ctx, start); - ptr.store_end(ctx, stop); - ptr.store_step(ctx, step); - Ok(Some(ptr.as_base_value().into())) - }, - )))), - loc: None, - } - } - /// Build the `str()` function. fn build_str_function(&mut self) -> TopLevelDef { let prim = PrimDef::FunStr; diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 9fef74eb..76d454cf 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -49,7 +49,7 @@ pub enum PrimDef { FunRound, FunRound64, FunNpRound, - FunRange, + FunRangeInit, FunStr, FunBool, FunFloor, @@ -203,7 +203,7 @@ impl PrimDef { PrimDef::FunRound => fun("round", None), PrimDef::FunRound64 => fun("round64", None), PrimDef::FunNpRound => fun("np_round", None), - PrimDef::FunRange => fun("range", None), + PrimDef::FunRangeInit => fun("range.__init__", Some("__init__")), PrimDef::FunStr => fun("str", None), PrimDef::FunBool => fun("bool", None), PrimDef::FunFloor => fun("floor", None), diff --git a/nac3standalone/demo/demo.c b/nac3standalone/demo/demo.c index e674a5f5..28c64338 100644 --- a/nac3standalone/demo/demo.c +++ b/nac3standalone/demo/demo.c @@ -44,6 +44,18 @@ void output_float64(double x) { } } +void output_range(int32_t range[3]) { + printf("range("); + if (range[0] != 0) { + printf("%d, ", range[0]); + } + printf("%d", range[1]); + if (range[2] != 1) { + printf(", %d", range[2]); + } + puts(")"); +} + void output_asciiart(int32_t x) { static const char *chars = " .,-:;i+hHM$*#@ "; if (x < 0) { @@ -79,6 +91,10 @@ void output_str(struct cslice *slice) { for (usize i = 0; i < slice->len; ++i) { putchar(data[i]); } +} + +void output_strln(struct cslice *slice) { + output_str(slice); putchar('\n'); } diff --git a/nac3standalone/demo/interpret_demo.py b/nac3standalone/demo/interpret_demo.py index 1b68bea6..5167ae8e 100755 --- a/nac3standalone/demo/interpret_demo.py +++ b/nac3standalone/demo/interpret_demo.py @@ -107,6 +107,9 @@ def patch(module): def output_float(x): print("%f" % x) + def output_strln(x): + print(x, end='') + def dbg_stack_address(_): return 0 @@ -120,6 +123,8 @@ def patch(module): return output_asciiart elif name == "output_float64": return output_float + elif name == "output_str": + return output_strln elif name in { "output_bool", "output_int32", @@ -127,7 +132,7 @@ def patch(module): "output_int32_list", "output_uint32", "output_uint64", - "output_str", + "output_strln", }: return print elif name == "dbg_stack_address": diff --git a/nac3standalone/demo/src/classes.py b/nac3standalone/demo/src/classes.py index b00fa776..ff66064e 100644 --- a/nac3standalone/demo/src/classes.py +++ b/nac3standalone/demo/src/classes.py @@ -7,7 +7,7 @@ def output_int64(x: int64): ... @extern -def output_str(x: str): +def output_strln(x: str): ... @@ -33,7 +33,7 @@ class A: class Initless: def foo(self): - output_str("hello") + output_strln("hello") def run() -> int32: a = A(10) diff --git a/nac3standalone/demo/src/demo_test.py b/nac3standalone/demo/src/demo_test.py index bc20c879..951bbafd 100644 --- a/nac3standalone/demo/src/demo_test.py +++ b/nac3standalone/demo/src/demo_test.py @@ -22,6 +22,10 @@ def output_uint64(x: uint64): def output_float64(x: float): ... +@extern +def output_range(x: range): + ... + @extern def output_int32_list(x: list[int32]): ... @@ -34,6 +38,10 @@ def output_asciiart(x: int32): def output_str(x: str): ... +@extern +def output_strln(x: str): + ... + def test_output_bool(): output_bool(True) output_bool(False) @@ -59,6 +67,15 @@ def test_output_float64(): output_float64(16.25) output_float64(-16.25) +def test_output_range(): + r = range(1, 100, 5) + output_int32(r.start) + output_int32(r.stop) + output_int32(r.step) + output_range(range(10)) + output_range(range(1, 10)) + output_range(range(1, 10, 2)) + def test_output_asciiart(): for i in range(17): output_asciiart(i) @@ -68,7 +85,8 @@ def test_output_int32_list(): output_int32_list([0, 1, 3, 5, 10]) def test_output_str_family(): - output_str("hello world") + output_str("hello") + output_strln(" world") def run() -> int32: test_output_bool() @@ -77,6 +95,7 @@ def run() -> int32: test_output_uint32() test_output_uint64() test_output_float64() + test_output_range() test_output_asciiart() test_output_int32_list() test_output_str_family() diff --git a/nac3standalone/demo/src/loop_try_break.py b/nac3standalone/demo/src/loop_try_break.py index 2592bacd..8b19da18 100644 --- a/nac3standalone/demo/src/loop_try_break.py +++ b/nac3standalone/demo/src/loop_try_break.py @@ -23,11 +23,12 @@ def run() -> int32: output_int32(x) output_str(" * ") output_float64(n / x) + output_str("\n") except: # Assume this is intended to catch x == 0 break else: # loop fell through without finding a factor output_int32(n) - output_str(" is a prime number") + output_str(" is a prime number\n") return 0 \ No newline at end of file