diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index 5608ec71..fe350b09 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -2250,7 +2250,7 @@ pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>( ctx, generator, &llvm_usize.const_zero(), - n2.as_basic_value_enum(), + n2.as_basic_value_enum(), ); }; let n2_array = n2_array.as_base_value().as_basic_value_enum(); @@ -2401,3 +2401,47 @@ pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>( unsupported_type(ctx, FN_NAME, &[x1_ty]) } } + +/// Invokes the ``np_arange`` function +pub fn call_np_arange<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + start: (Type, BasicValueEnum<'ctx>), + stop: (Type, BasicValueEnum<'ctx>), + step: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + const FN_NAME: &str = "np_arange"; + + let (start_ty, start) = start; + let (stop_ty, stop) = stop; + let (step_ty, step) = step; + + // verify start/stop are the same type (or cast to float) + // step can be float + // stop is not included in range + // generate an array based on it, ez + + // return type is int only if both start and step args are int + let mut actual_step_ty = match (start, step) { + BasicValueEnum::IntType(_), BasicValueEnum::IntType(_) | + BasicValueEnum::FloatType(_), BasicValueEnum::IntType(_) => start_ty, + BasicValueEnum::IntType(_), BasicValueEnum::FloatType(_) | + BasicValueEnum::FloatType(_), BasicValueEnum::FloatType(_) => step_ty, + _ => unsupported_type(ctx, FN_NAME, &[start]) + }; + + let mut out = numpy::create_ndarray_const_shape(generator, ctx, actual_step_ty, &[]) + .unwrap() + .as_base_value() + .as_basic_value_enum(); + + gen_for_callback( + generator, + ctx, + Some(FN_NAME), + |init|, + |cond|, + |body|, + |update|, + ) +} \ No newline at end of file diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index e2325ebb..feb221f5 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -580,6 +580,8 @@ impl<'a> BuiltinBuilder<'a> { self.build_np_sp_ndarray_function(prim) } + PrimDef::FunNpArange => self.build_arange_method(), + PrimDef::FunNpDot | PrimDef::FunNpLinalgCholesky | PrimDef::FunNpLinalgQr @@ -1923,6 +1925,118 @@ impl<'a> BuiltinBuilder<'a> { } } + /// Build `np_arange` function + fn build_np_arange_function(&mut self, prim: PrimDef) -> TopLevelDef { + // returns ndarray with either ints or floats, depending on args + let ndarray_float = + make_ndarray_ty(self.unifier, self.primitives, Some(float), Some(common_ndim.ty)); + TopLevelDef::Function { + name: prim.name().into(), + simple_name: prim.simple_name().into(), + signature: make_ctor_signature(self.unifier), + var_id: Vec::default(), + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |ctx, obj, _, args, generator| { + let (zelf_ty, zelf) = obj.unwrap(); + let zelf = + zelf.to_basic_value_enum(ctx, generator, zelf_ty)?.into_pointer_value(); + let zelf = RangeValue::from_ptr_val(zelf, Some("range")); + + let mut start = None; + let mut stop = None; + let mut step = None; + let int32 = ctx.ctx.i32_type(); + let ty_i32 = ctx.primitives.int32; + for (i, arg) in args.iter().enumerate() { + if arg.0 == Some("start".into()) { + start = Some( + arg.1 + .clone() + .to_basic_value_enum(ctx, generator, ty_i32)? + .into_int_value(), + ); + } else if arg.0 == Some("stop".into()) { + stop = Some( + arg.1 + .clone() + .to_basic_value_enum(ctx, generator, ty_i32)? + .into_int_value(), + ); + } else if arg.0 == Some("step".into()) { + step = Some( + arg.1 + .clone() + .to_basic_value_enum(ctx, generator, ty_i32)? + .into_int_value(), + ); + } else if i == 0 { + start = Some( + arg.1 + .clone() + .to_basic_value_enum(ctx, generator, ty_i32)? + .into_int_value(), + ); + } else if i == 1 { + stop = Some( + arg.1 + .clone() + .to_basic_value_enum(ctx, generator, ty_i32)? + .into_int_value(), + ); + } else if i == 2 { + step = Some( + arg.1 + .clone() + .to_basic_value_enum(ctx, generator, ty_i32)? + .into_int_value(), + ); + } + } + let step = match step { + Some(step) => { + // assert step != 0, throw exception if not + let not_zero = ctx + .builder + .build_int_compare( + IntPredicate::NE, + step, + step.get_type().const_zero(), + "range_step_ne", + ) + .unwrap(); + ctx.make_assert( + generator, + not_zero, + "0:ValueError", + "range() step must not be zero", + [None, None, None], + ctx.current_loc, + ); + step + } + None => int32.const_int(1, false), + }; + let stop = stop.unwrap_or_else(|| { + let v = start.unwrap(); + start = None; + v + }); + let start = start.unwrap_or_else(|| int32.const_zero()); + + zelf.store_start(ctx, start); + zelf.store_end(ctx, stop); + zelf.store_step(ctx, step); + + Ok(Some(builtin_fns::call_np_arange(generator, ctx, (start_ty, start), (stop_ty, stop), (step_ty, step)))) + }, + )))), + loc: None, + } + } + /// Build `np_linalg` and `sp_linalg` functions /// /// The input to these functions must be floating point `NDArray` diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 29a662c5..fe6c761a 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -102,6 +102,7 @@ pub enum PrimDef { FunNpNextAfter, FunNpTranspose, FunNpReshape, + FunNpArange, // Linalg functions FunNpDot, @@ -288,6 +289,7 @@ impl PrimDef { PrimDef::FunNpNextAfter => fun("np_nextafter", None), PrimDef::FunNpTranspose => fun("np_transpose", None), PrimDef::FunNpReshape => fun("np_reshape", None), + PrimDef::FunNpArange => fun("np_arange", None) // Linalg functions PrimDef::FunNpDot => fun("np_dot", None),