Compare commits
1 Commits
Author | SHA1 | Date |
---|---|---|
mwojcik | 1c97d9dd11 |
|
@ -2401,3 +2401,47 @@ pub fn call_sp_linalg_hessenberg<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
unsupported_type(ctx, FN_NAME, &[x1_ty])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Invokes the ``np_arange`` function
|
||||||
|
pub fn call_np_arange<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
start: (Type, BasicValueEnum<'ctx>),
|
||||||
|
stop: (Type, BasicValueEnum<'ctx>),
|
||||||
|
step: (Type, BasicValueEnum<'ctx>),
|
||||||
|
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||||
|
const FN_NAME: &str = "np_arange";
|
||||||
|
|
||||||
|
let (start_ty, start) = start;
|
||||||
|
let (stop_ty, stop) = stop;
|
||||||
|
let (step_ty, step) = step;
|
||||||
|
|
||||||
|
// verify start/stop are the same type (or cast to float)
|
||||||
|
// step can be float
|
||||||
|
// stop is not included in range
|
||||||
|
// generate an array based on it, ez
|
||||||
|
|
||||||
|
// return type is int only if both start and step args are int
|
||||||
|
let mut actual_step_ty = match (start, step) {
|
||||||
|
BasicValueEnum::IntType(_), BasicValueEnum::IntType(_) |
|
||||||
|
BasicValueEnum::FloatType(_), BasicValueEnum::IntType(_) => start_ty,
|
||||||
|
BasicValueEnum::IntType(_), BasicValueEnum::FloatType(_) |
|
||||||
|
BasicValueEnum::FloatType(_), BasicValueEnum::FloatType(_) => step_ty,
|
||||||
|
_ => unsupported_type(ctx, FN_NAME, &[start])
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut out = numpy::create_ndarray_const_shape(generator, ctx, actual_step_ty, &[])
|
||||||
|
.unwrap()
|
||||||
|
.as_base_value()
|
||||||
|
.as_basic_value_enum();
|
||||||
|
|
||||||
|
gen_for_callback(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
Some(FN_NAME),
|
||||||
|
|init|,
|
||||||
|
|cond|,
|
||||||
|
|body|,
|
||||||
|
|update|,
|
||||||
|
)
|
||||||
|
}
|
|
@ -580,6 +580,8 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
self.build_np_sp_ndarray_function(prim)
|
self.build_np_sp_ndarray_function(prim)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
PrimDef::FunNpArange => self.build_arange_method(),
|
||||||
|
|
||||||
PrimDef::FunNpDot
|
PrimDef::FunNpDot
|
||||||
| PrimDef::FunNpLinalgCholesky
|
| PrimDef::FunNpLinalgCholesky
|
||||||
| PrimDef::FunNpLinalgQr
|
| PrimDef::FunNpLinalgQr
|
||||||
|
@ -1923,6 +1925,118 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Build `np_arange` function
|
||||||
|
fn build_np_arange_function(&mut self, prim: PrimDef) -> TopLevelDef {
|
||||||
|
// returns ndarray with either ints or floats, depending on args
|
||||||
|
let ndarray_float =
|
||||||
|
make_ndarray_ty(self.unifier, self.primitives, Some(float), Some(common_ndim.ty));
|
||||||
|
TopLevelDef::Function {
|
||||||
|
name: prim.name().into(),
|
||||||
|
simple_name: prim.simple_name().into(),
|
||||||
|
signature: make_ctor_signature(self.unifier),
|
||||||
|
var_id: Vec::default(),
|
||||||
|
instance_to_symbol: HashMap::default(),
|
||||||
|
instance_to_stmt: HashMap::default(),
|
||||||
|
resolver: None,
|
||||||
|
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|
||||||
|
|ctx, obj, _, args, generator| {
|
||||||
|
let (zelf_ty, zelf) = obj.unwrap();
|
||||||
|
let zelf =
|
||||||
|
zelf.to_basic_value_enum(ctx, generator, zelf_ty)?.into_pointer_value();
|
||||||
|
let zelf = RangeValue::from_ptr_val(zelf, Some("range"));
|
||||||
|
|
||||||
|
let mut start = None;
|
||||||
|
let mut stop = None;
|
||||||
|
let mut step = None;
|
||||||
|
let int32 = ctx.ctx.i32_type();
|
||||||
|
let ty_i32 = ctx.primitives.int32;
|
||||||
|
for (i, arg) in args.iter().enumerate() {
|
||||||
|
if arg.0 == Some("start".into()) {
|
||||||
|
start = Some(
|
||||||
|
arg.1
|
||||||
|
.clone()
|
||||||
|
.to_basic_value_enum(ctx, generator, ty_i32)?
|
||||||
|
.into_int_value(),
|
||||||
|
);
|
||||||
|
} else if arg.0 == Some("stop".into()) {
|
||||||
|
stop = Some(
|
||||||
|
arg.1
|
||||||
|
.clone()
|
||||||
|
.to_basic_value_enum(ctx, generator, ty_i32)?
|
||||||
|
.into_int_value(),
|
||||||
|
);
|
||||||
|
} else if arg.0 == Some("step".into()) {
|
||||||
|
step = Some(
|
||||||
|
arg.1
|
||||||
|
.clone()
|
||||||
|
.to_basic_value_enum(ctx, generator, ty_i32)?
|
||||||
|
.into_int_value(),
|
||||||
|
);
|
||||||
|
} else if i == 0 {
|
||||||
|
start = Some(
|
||||||
|
arg.1
|
||||||
|
.clone()
|
||||||
|
.to_basic_value_enum(ctx, generator, ty_i32)?
|
||||||
|
.into_int_value(),
|
||||||
|
);
|
||||||
|
} else if i == 1 {
|
||||||
|
stop = Some(
|
||||||
|
arg.1
|
||||||
|
.clone()
|
||||||
|
.to_basic_value_enum(ctx, generator, ty_i32)?
|
||||||
|
.into_int_value(),
|
||||||
|
);
|
||||||
|
} else if i == 2 {
|
||||||
|
step = Some(
|
||||||
|
arg.1
|
||||||
|
.clone()
|
||||||
|
.to_basic_value_enum(ctx, generator, ty_i32)?
|
||||||
|
.into_int_value(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let step = match step {
|
||||||
|
Some(step) => {
|
||||||
|
// assert step != 0, throw exception if not
|
||||||
|
let not_zero = ctx
|
||||||
|
.builder
|
||||||
|
.build_int_compare(
|
||||||
|
IntPredicate::NE,
|
||||||
|
step,
|
||||||
|
step.get_type().const_zero(),
|
||||||
|
"range_step_ne",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
ctx.make_assert(
|
||||||
|
generator,
|
||||||
|
not_zero,
|
||||||
|
"0:ValueError",
|
||||||
|
"range() step must not be zero",
|
||||||
|
[None, None, None],
|
||||||
|
ctx.current_loc,
|
||||||
|
);
|
||||||
|
step
|
||||||
|
}
|
||||||
|
None => int32.const_int(1, false),
|
||||||
|
};
|
||||||
|
let stop = stop.unwrap_or_else(|| {
|
||||||
|
let v = start.unwrap();
|
||||||
|
start = None;
|
||||||
|
v
|
||||||
|
});
|
||||||
|
let start = start.unwrap_or_else(|| int32.const_zero());
|
||||||
|
|
||||||
|
zelf.store_start(ctx, start);
|
||||||
|
zelf.store_end(ctx, stop);
|
||||||
|
zelf.store_step(ctx, step);
|
||||||
|
|
||||||
|
Ok(Some(builtin_fns::call_np_arange(generator, ctx, (start_ty, start), (stop_ty, stop), (step_ty, step))))
|
||||||
|
},
|
||||||
|
)))),
|
||||||
|
loc: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Build `np_linalg` and `sp_linalg` functions
|
/// Build `np_linalg` and `sp_linalg` functions
|
||||||
///
|
///
|
||||||
/// The input to these functions must be floating point `NDArray`
|
/// The input to these functions must be floating point `NDArray`
|
||||||
|
|
|
@ -102,6 +102,7 @@ pub enum PrimDef {
|
||||||
FunNpNextAfter,
|
FunNpNextAfter,
|
||||||
FunNpTranspose,
|
FunNpTranspose,
|
||||||
FunNpReshape,
|
FunNpReshape,
|
||||||
|
FunNpArange,
|
||||||
|
|
||||||
// Linalg functions
|
// Linalg functions
|
||||||
FunNpDot,
|
FunNpDot,
|
||||||
|
@ -288,6 +289,7 @@ impl PrimDef {
|
||||||
PrimDef::FunNpNextAfter => fun("np_nextafter", None),
|
PrimDef::FunNpNextAfter => fun("np_nextafter", None),
|
||||||
PrimDef::FunNpTranspose => fun("np_transpose", None),
|
PrimDef::FunNpTranspose => fun("np_transpose", None),
|
||||||
PrimDef::FunNpReshape => fun("np_reshape", None),
|
PrimDef::FunNpReshape => fun("np_reshape", None),
|
||||||
|
PrimDef::FunNpArange => fun("np_arange", None)
|
||||||
|
|
||||||
// Linalg functions
|
// Linalg functions
|
||||||
PrimDef::FunNpDot => fun("np_dot", None),
|
PrimDef::FunNpDot => fun("np_dot", None),
|
||||||
|
|
Loading…
Reference in New Issue