Compare commits
1 Commits
Author | SHA1 | Date |
---|---|---|
mwojcik | 1c97d9dd11 |
|
@ -2250,7 +2250,7 @@ pub fn call_np_linalg_matrix_power<'ctx, G: CodeGenerator + ?Sized>(
|
|||
ctx,
|
||||
generator,
|
||||
&llvm_usize.const_zero(),
|
||||
n2.as_basic_value_enum(),
|
||||
n2.as_basic_value_enum(),
|
||||
);
|
||||
};
|
||||
let n2_array = n2_array.as_base_value().as_basic_value_enum();
|
||||
|
@ -2401,3 +2401,47 @@ pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>(
|
|||
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
||||
}
|
||||
}
|
||||
|
||||
/// Invokes the ``np_arange`` function
|
||||
pub fn call_np_arange<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
start: (Type, BasicValueEnum<'ctx>),
|
||||
stop: (Type, BasicValueEnum<'ctx>),
|
||||
step: (Type, BasicValueEnum<'ctx>),
|
||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||
const FN_NAME: &str = "np_arange";
|
||||
|
||||
let (start_ty, start) = start;
|
||||
let (stop_ty, stop) = stop;
|
||||
let (step_ty, step) = step;
|
||||
|
||||
// verify start/stop are the same type (or cast to float)
|
||||
// step can be float
|
||||
// stop is not included in range
|
||||
// generate an array based on it, ez
|
||||
|
||||
// return type is int only if both start and step args are int
|
||||
let mut actual_step_ty = match (start, step) {
|
||||
BasicValueEnum::IntType(_), BasicValueEnum::IntType(_) |
|
||||
BasicValueEnum::FloatType(_), BasicValueEnum::IntType(_) => start_ty,
|
||||
BasicValueEnum::IntType(_), BasicValueEnum::FloatType(_) |
|
||||
BasicValueEnum::FloatType(_), BasicValueEnum::FloatType(_) => step_ty,
|
||||
_ => unsupported_type(ctx, FN_NAME, &[start])
|
||||
};
|
||||
|
||||
let mut out = numpy::create_ndarray_const_shape(generator, ctx, actual_step_ty, &[])
|
||||
.unwrap()
|
||||
.as_base_value()
|
||||
.as_basic_value_enum();
|
||||
|
||||
gen_for_callback(
|
||||
generator,
|
||||
ctx,
|
||||
Some(FN_NAME),
|
||||
|init|,
|
||||
|cond|,
|
||||
|body|,
|
||||
|update|,
|
||||
)
|
||||
}
|
|
@ -580,6 +580,8 @@ impl<'a> BuiltinBuilder<'a> {
|
|||
self.build_np_sp_ndarray_function(prim)
|
||||
}
|
||||
|
||||
PrimDef::FunNpArange => self.build_arange_method(),
|
||||
|
||||
PrimDef::FunNpDot
|
||||
| PrimDef::FunNpLinalgCholesky
|
||||
| PrimDef::FunNpLinalgQr
|
||||
|
@ -1923,6 +1925,118 @@ impl<'a> BuiltinBuilder<'a> {
|
|||
}
|
||||
}
|
||||
|
||||
/// Build `np_arange` function
|
||||
fn build_np_arange_function(&mut self, prim: PrimDef) -> TopLevelDef {
|
||||
// returns ndarray with either ints or floats, depending on args
|
||||
let ndarray_float =
|
||||
make_ndarray_ty(self.unifier, self.primitives, Some(float), Some(common_ndim.ty));
|
||||
TopLevelDef::Function {
|
||||
name: prim.name().into(),
|
||||
simple_name: prim.simple_name().into(),
|
||||
signature: make_ctor_signature(self.unifier),
|
||||
var_id: Vec::default(),
|
||||
instance_to_symbol: HashMap::default(),
|
||||
instance_to_stmt: HashMap::default(),
|
||||
resolver: None,
|
||||
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|
||||
|ctx, obj, _, args, generator| {
|
||||
let (zelf_ty, zelf) = obj.unwrap();
|
||||
let zelf =
|
||||
zelf.to_basic_value_enum(ctx, generator, zelf_ty)?.into_pointer_value();
|
||||
let zelf = RangeValue::from_ptr_val(zelf, Some("range"));
|
||||
|
||||
let mut start = None;
|
||||
let mut stop = None;
|
||||
let mut step = None;
|
||||
let int32 = ctx.ctx.i32_type();
|
||||
let ty_i32 = ctx.primitives.int32;
|
||||
for (i, arg) in args.iter().enumerate() {
|
||||
if arg.0 == Some("start".into()) {
|
||||
start = Some(
|
||||
arg.1
|
||||
.clone()
|
||||
.to_basic_value_enum(ctx, generator, ty_i32)?
|
||||
.into_int_value(),
|
||||
);
|
||||
} else if arg.0 == Some("stop".into()) {
|
||||
stop = Some(
|
||||
arg.1
|
||||
.clone()
|
||||
.to_basic_value_enum(ctx, generator, ty_i32)?
|
||||
.into_int_value(),
|
||||
);
|
||||
} else if arg.0 == Some("step".into()) {
|
||||
step = Some(
|
||||
arg.1
|
||||
.clone()
|
||||
.to_basic_value_enum(ctx, generator, ty_i32)?
|
||||
.into_int_value(),
|
||||
);
|
||||
} else if i == 0 {
|
||||
start = Some(
|
||||
arg.1
|
||||
.clone()
|
||||
.to_basic_value_enum(ctx, generator, ty_i32)?
|
||||
.into_int_value(),
|
||||
);
|
||||
} else if i == 1 {
|
||||
stop = Some(
|
||||
arg.1
|
||||
.clone()
|
||||
.to_basic_value_enum(ctx, generator, ty_i32)?
|
||||
.into_int_value(),
|
||||
);
|
||||
} else if i == 2 {
|
||||
step = Some(
|
||||
arg.1
|
||||
.clone()
|
||||
.to_basic_value_enum(ctx, generator, ty_i32)?
|
||||
.into_int_value(),
|
||||
);
|
||||
}
|
||||
}
|
||||
let step = match step {
|
||||
Some(step) => {
|
||||
// assert step != 0, throw exception if not
|
||||
let not_zero = ctx
|
||||
.builder
|
||||
.build_int_compare(
|
||||
IntPredicate::NE,
|
||||
step,
|
||||
step.get_type().const_zero(),
|
||||
"range_step_ne",
|
||||
)
|
||||
.unwrap();
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
not_zero,
|
||||
"0:ValueError",
|
||||
"range() step must not be zero",
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
step
|
||||
}
|
||||
None => int32.const_int(1, false),
|
||||
};
|
||||
let stop = stop.unwrap_or_else(|| {
|
||||
let v = start.unwrap();
|
||||
start = None;
|
||||
v
|
||||
});
|
||||
let start = start.unwrap_or_else(|| int32.const_zero());
|
||||
|
||||
zelf.store_start(ctx, start);
|
||||
zelf.store_end(ctx, stop);
|
||||
zelf.store_step(ctx, step);
|
||||
|
||||
Ok(Some(builtin_fns::call_np_arange(generator, ctx, (start_ty, start), (stop_ty, stop), (step_ty, step))))
|
||||
},
|
||||
)))),
|
||||
loc: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Build `np_linalg` and `sp_linalg` functions
|
||||
///
|
||||
/// The input to these functions must be floating point `NDArray`
|
||||
|
|
|
@ -102,6 +102,7 @@ pub enum PrimDef {
|
|||
FunNpNextAfter,
|
||||
FunNpTranspose,
|
||||
FunNpReshape,
|
||||
FunNpArange,
|
||||
|
||||
// Linalg functions
|
||||
FunNpDot,
|
||||
|
@ -288,6 +289,7 @@ impl PrimDef {
|
|||
PrimDef::FunNpNextAfter => fun("np_nextafter", None),
|
||||
PrimDef::FunNpTranspose => fun("np_transpose", None),
|
||||
PrimDef::FunNpReshape => fun("np_reshape", None),
|
||||
PrimDef::FunNpArange => fun("np_arange", None)
|
||||
|
||||
// Linalg functions
|
||||
PrimDef::FunNpDot => fun("np_dot", None),
|
||||
|
|
Loading…
Reference in New Issue