From 2cfb7a7e10c77a623b74a03881a011e0ebe679a3 Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 8 Jul 2024 14:22:19 +0800 Subject: [PATCH] core: Refactor range function into constructor --- nac3core/src/toplevel/builtins.rs | 294 ++++++++++++++++-------------- nac3core/src/toplevel/helper.rs | 4 +- 2 files changed, 164 insertions(+), 134 deletions(-) diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 954367c80..f2097dd0d 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -14,10 +14,7 @@ use strum::IntoEnumIterator; use crate::{ codegen::{ builtin_fns, - classes::{ - ArrayLikeValue, NDArrayValue, ProxyType, ProxyValue, RangeType, RangeValue, - TypedArrayLikeAccessor, - }, + classes::{ArrayLikeValue, NDArrayValue, ProxyValue, RangeValue, TypedArrayLikeAccessor}, expr::destructure_range, irrt::*, numpy::*, @@ -460,9 +457,10 @@ impl<'a> BuiltinBuilder<'a> { | PrimDef::Float | PrimDef::Bool | PrimDef::Str - | PrimDef::Range | PrimDef::None => Self::build_simple_primitive_class(prim), + PrimDef::Range | PrimDef::FunRangeInit => self.build_range_class_related(prim), + PrimDef::Exception => self.build_exception_class_related(prim), PrimDef::Option @@ -494,7 +492,6 @@ impl<'a> BuiltinBuilder<'a> { | PrimDef::FunNpEye | PrimDef::FunNpIdentity => self.build_ndarray_other_factory_function(prim), - PrimDef::FunRange => self.build_range_function(), PrimDef::FunStr => self.build_str_function(), PrimDef::FunFloor | PrimDef::FunFloor64 | PrimDef::FunCeil | PrimDef::FunCeil64 => { @@ -599,7 +596,6 @@ impl<'a> BuiltinBuilder<'a> { PrimDef::Float, PrimDef::Bool, PrimDef::Str, - PrimDef::Range, PrimDef::None, ], ); @@ -607,6 +603,165 @@ impl<'a> BuiltinBuilder<'a> { TopLevelComposer::make_top_level_class_def(prim.id(), None, prim.name().into(), None, None) } + fn build_range_class_related(&mut self, prim: PrimDef) -> TopLevelDef { + debug_assert_prim_is_allowed(prim, &[PrimDef::Range, PrimDef::FunRangeInit]); + + let PrimitiveStore { int32, range, .. } = *self.primitives; + + let make_ctor_signature = |unifier: &mut Unifier| { + unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![ + FuncArg { name: "start".into(), ty: int32, default_value: None }, + FuncArg { + name: "stop".into(), + ty: int32, + // placeholder + default_value: Some(SymbolValue::I32(0)), + }, + FuncArg { + name: "step".into(), + ty: int32, + default_value: Some(SymbolValue::I32(1)), + }, + ], + ret: range, + vars: VarMap::default(), + })) + }; + + match prim { + PrimDef::Range => { + let fields = vec![ + ("start".into(), int32, true), + ("stop".into(), int32, true), + ("step".into(), int32, true), + ]; + let ctor_signature = make_ctor_signature(self.unifier); + + TopLevelDef::Class { + name: prim.name().into(), + object_id: prim.id(), + type_vars: Vec::default(), + fields, + attributes: Vec::default(), + methods: vec![("__init__".into(), ctor_signature, PrimDef::FunRangeInit.id())], + ancestors: Vec::default(), + constructor: Some(ctor_signature), + resolver: None, + loc: None, + } + } + + PrimDef::FunRangeInit => TopLevelDef::Function { + name: prim.name().into(), + simple_name: prim.simple_name().into(), + signature: make_ctor_signature(self.unifier), + var_id: Vec::default(), + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |ctx, obj, _, args, generator| { + let (zelf_ty, zelf) = obj.unwrap(); + let zelf = + zelf.to_basic_value_enum(ctx, generator, zelf_ty)?.into_pointer_value(); + let zelf = RangeValue::from_ptr_val(zelf, Some("range")); + + let mut start = None; + let mut stop = None; + let mut step = None; + let int32 = ctx.ctx.i32_type(); + let ty_i32 = ctx.primitives.int32; + for (i, arg) in args.iter().enumerate() { + if arg.0 == Some("start".into()) { + start = Some( + arg.1 + .clone() + .to_basic_value_enum(ctx, generator, ty_i32)? + .into_int_value(), + ); + } else if arg.0 == Some("stop".into()) { + stop = Some( + arg.1 + .clone() + .to_basic_value_enum(ctx, generator, ty_i32)? + .into_int_value(), + ); + } else if arg.0 == Some("step".into()) { + step = Some( + arg.1 + .clone() + .to_basic_value_enum(ctx, generator, ty_i32)? + .into_int_value(), + ); + } else if i == 0 { + start = Some( + arg.1 + .clone() + .to_basic_value_enum(ctx, generator, ty_i32)? + .into_int_value(), + ); + } else if i == 1 { + stop = Some( + arg.1 + .clone() + .to_basic_value_enum(ctx, generator, ty_i32)? + .into_int_value(), + ); + } else if i == 2 { + step = Some( + arg.1 + .clone() + .to_basic_value_enum(ctx, generator, ty_i32)? + .into_int_value(), + ); + } + } + let step = match step { + Some(step) => { + // assert step != 0, throw exception if not + let not_zero = ctx + .builder + .build_int_compare( + IntPredicate::NE, + step, + step.get_type().const_zero(), + "range_step_ne", + ) + .unwrap(); + ctx.make_assert( + generator, + not_zero, + "0:ValueError", + "range() step must not be zero", + [None, None, None], + ctx.current_loc, + ); + step + } + None => int32.const_int(1, false), + }; + let stop = stop.unwrap_or_else(|| { + let v = start.unwrap(); + start = None; + v + }); + let start = start.unwrap_or_else(|| int32.const_zero()); + + zelf.store_start(ctx, start); + zelf.store_end(ctx, stop); + zelf.store_step(ctx, step); + + Ok(Some(zelf.as_base_value().into())) + }, + )))), + loc: None, + }, + + _ => unreachable!(), + } + } + /// Build the class `Exception` and its associated methods. fn build_exception_class_related(&self, prim: PrimDef) -> TopLevelDef { // NOTE: currently only contains the class `Exception` @@ -1170,131 +1325,6 @@ impl<'a> BuiltinBuilder<'a> { } } - /// Build the `range()` function. - fn build_range_function(&mut self) -> TopLevelDef { - let prim = PrimDef::FunRange; - - let PrimitiveStore { int32, range, .. } = *self.primitives; - - TopLevelDef::Function { - name: prim.name().into(), - simple_name: prim.simple_name().into(), - signature: self.unifier.add_ty(TypeEnum::TFunc(FunSignature { - args: vec![ - FuncArg { name: "start".into(), ty: int32, default_value: None }, - FuncArg { - name: "stop".into(), - ty: int32, - // placeholder - default_value: Some(SymbolValue::I32(0)), - }, - FuncArg { - name: "step".into(), - ty: int32, - default_value: Some(SymbolValue::I32(1)), - }, - ], - ret: range, - vars: VarMap::default(), - })), - var_id: Vec::default(), - instance_to_symbol: HashMap::default(), - instance_to_stmt: HashMap::default(), - resolver: None, - codegen_callback: Some(Arc::new(GenCall::new(Box::new( - |ctx, _, _, args, generator| { - let mut start = None; - let mut stop = None; - let mut step = None; - let int32 = ctx.ctx.i32_type(); - let ty_i32 = ctx.primitives.int32; - for (i, arg) in args.iter().enumerate() { - if arg.0 == Some("start".into()) { - start = Some( - arg.1 - .clone() - .to_basic_value_enum(ctx, generator, ty_i32)? - .into_int_value(), - ); - } else if arg.0 == Some("stop".into()) { - stop = Some( - arg.1 - .clone() - .to_basic_value_enum(ctx, generator, ty_i32)? - .into_int_value(), - ); - } else if arg.0 == Some("step".into()) { - step = Some( - arg.1 - .clone() - .to_basic_value_enum(ctx, generator, ty_i32)? - .into_int_value(), - ); - } else if i == 0 { - start = Some( - arg.1 - .clone() - .to_basic_value_enum(ctx, generator, ty_i32)? - .into_int_value(), - ); - } else if i == 1 { - stop = Some( - arg.1 - .clone() - .to_basic_value_enum(ctx, generator, ty_i32)? - .into_int_value(), - ); - } else if i == 2 { - step = Some( - arg.1 - .clone() - .to_basic_value_enum(ctx, generator, ty_i32)? - .into_int_value(), - ); - } - } - let step = match step { - Some(step) => { - // assert step != 0, throw exception if not - let not_zero = ctx - .builder - .build_int_compare( - IntPredicate::NE, - step, - step.get_type().const_zero(), - "range_step_ne", - ) - .unwrap(); - ctx.make_assert( - generator, - not_zero, - "0:ValueError", - "range() step must not be zero", - [None, None, None], - ctx.current_loc, - ); - step - } - None => int32.const_int(1, false), - }; - let stop = stop.unwrap_or_else(|| { - let v = start.unwrap(); - start = None; - v - }); - let start = start.unwrap_or_else(|| int32.const_zero()); - - let ptr = RangeType::new(ctx.ctx).new_value(generator, ctx, Some("range")); - ptr.store_start(ctx, start); - ptr.store_end(ctx, stop); - ptr.store_step(ctx, step); - Ok(Some(ptr.as_base_value().into())) - }, - )))), - loc: None, - } - } - /// Build the `str()` function. fn build_str_function(&mut self) -> TopLevelDef { let prim = PrimDef::FunStr; diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 9fef74ebf..76d454cff 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -49,7 +49,7 @@ pub enum PrimDef { FunRound, FunRound64, FunNpRound, - FunRange, + FunRangeInit, FunStr, FunBool, FunFloor, @@ -203,7 +203,7 @@ impl PrimDef { PrimDef::FunRound => fun("round", None), PrimDef::FunRound64 => fun("round64", None), PrimDef::FunNpRound => fun("np_round", None), - PrimDef::FunRange => fun("range", None), + PrimDef::FunRangeInit => fun("range.__init__", Some("__init__")), PrimDef::FunStr => fun("str", None), PrimDef::FunBool => fun("bool", None), PrimDef::FunFloor => fun("floor", None),