Implement ndarray class, constructor and creation functions #371
|
@ -400,6 +400,9 @@ fn gen_rpc_tag(
|
|||
buffer.push(b'l');
|
||||
gen_rpc_tag(ctx, *ty, buffer)?;
|
||||
}
|
||||
TNDArray { .. } => {
|
||||
todo!()
|
||||
}
|
||||
_ => return Err(format!("Unsupported type: {:?}", ctx.unifier.stringify(ty))),
|
||||
}
|
||||
}
|
||||
|
@ -673,6 +676,14 @@ pub fn attributes_writeback(
|
|||
values.push((ty, inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap()));
|
||||
}
|
||||
},
|
||||
TypeEnum::TNDArray { ty: elem_ty, .. } => {
|
||||
if gen_rpc_tag(ctx, *elem_ty, &mut scratch_buffer).is_ok() {
|
||||
let pydict = PyDict::new(py);
|
||||
pydict.set_item("obj", val)?;
|
||||
host_attributes.append(pydict)?;
|
||||
values.push((ty, inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap()));
|
||||
}
|
||||
},
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -63,6 +63,17 @@ enum Isa {
|
|||
CortexA9,
|
||||
}
|
||||
|
||||
impl Isa {
|
||||
/// Returns the number of bits in `size_t` for the [`Isa`].
|
||||
fn get_size_type(&self) -> u32 {
|
||||
if self == &Isa::Host {
|
||||
64u32
|
||||
} else {
|
||||
32u32
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct PrimitivePythonId {
|
||||
int: u64,
|
||||
|
@ -74,6 +85,7 @@ pub struct PrimitivePythonId {
|
|||
float64: u64,
|
||||
bool: u64,
|
||||
list: u64,
|
||||
ndarray: u64,
|
||||
tuple: u64,
|
||||
typevar: u64,
|
||||
const_generic_marker: u64,
|
||||
|
@ -277,9 +289,11 @@ impl Nac3 {
|
|||
py: Python,
|
||||
link_fn: &dyn Fn(&Module) -> PyResult<T>,
|
||||
) -> PyResult<T> {
|
||||
let size_t = self.isa.get_size_type();
|
||||
let (mut composer, mut builtins_def, mut builtins_ty) = TopLevelComposer::new(
|
||||
self.builtins.clone(),
|
||||
ComposerConfig { kernel_ann: Some("Kernel"), kernel_invariant_ann: "KernelInvariant" },
|
||||
size_t,
|
||||
);
|
||||
|
||||
let builtins = PyModule::import(py, "builtins")?;
|
||||
|
@ -792,7 +806,7 @@ impl Nac3 {
|
|||
Isa::RiscV32IMA => &timeline::NOW_PINNING_TIME_FNS,
|
||||
Isa::CortexA9 | Isa::Host => &timeline::EXTERN_TIME_FNS,
|
||||
};
|
||||
let primitive: PrimitiveStore = TopLevelComposer::make_primitives().0;
|
||||
let primitive: PrimitiveStore = TopLevelComposer::make_primitives(isa.get_size_type()).0;
|
||||
let builtins = vec![
|
||||
(
|
||||
"now_mu".into(),
|
||||
|
@ -866,6 +880,7 @@ impl Nac3 {
|
|||
float: get_attr_id(builtins_mod, "float"),
|
||||
float64: get_attr_id(numpy_mod, "float64"),
|
||||
list: get_attr_id(builtins_mod, "list"),
|
||||
ndarray: get_attr_id(numpy_mod, "NDArray"),
|
||||
tuple: get_attr_id(builtins_mod, "tuple"),
|
||||
exception: get_attr_id(builtins_mod, "Exception"),
|
||||
option: get_id(artiq_builtins.get_item("Option").ok().flatten().unwrap()),
|
||||
|
|
|
@ -302,6 +302,12 @@ impl InnerResolver {
|
|||
let var = unifier.get_dummy_var().0;
|
||||
let list = unifier.add_ty(TypeEnum::TList { ty: var });
|
||||
Ok(Ok((list, false)))
|
||||
} else if ty_id == self.primitive_ids.ndarray {
|
||||
// do not handle type var param and concrete check here
|
||||
let var = unifier.get_dummy_var().0;
|
||||
let ndims = unifier.get_fresh_const_generic_var(primitives.usize(), None, None).0;
|
||||
let ndarray = unifier.add_ty(TypeEnum::TNDArray { ty: var, ndims });
|
||||
Ok(Ok((ndarray, false)))
|
||||
} else if ty_id == self.primitive_ids.tuple {
|
||||
// do not handle type var param and concrete check here
|
||||
Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: vec![] }), false)))
|
||||
|
@ -446,6 +452,16 @@ impl InnerResolver {
|
|||
)));
|
||||
}
|
||||
}
|
||||
TypeEnum::TNDArray { .. } => {
|
||||
if args.len() != 2 {
|
||||
return Ok(Err(format!(
|
||||
"type list needs exactly 2 type parameters, found {}",
|
||||
args.len()
|
||||
)));
|
||||
}
|
||||
|
||||
todo!()
|
||||
}
|
||||
TypeEnum::TTuple { .. } => {
|
||||
let args = match args
|
||||
.iter()
|
||||
|
@ -607,7 +623,7 @@ impl InnerResolver {
|
|||
Err(e) => return Ok(Err(e)),
|
||||
};
|
||||
match (&*unifier.get_ty(extracted_ty), inst_check) {
|
||||
// do the instantiation for these three types
|
||||
// do the instantiation for these four types
|
||||
(TypeEnum::TList { ty }, false) => {
|
||||
let len: usize = self.helper.len_fn.call1(py, (obj,))?.extract(py)?;
|
||||
if len == 0 {
|
||||
|
@ -632,6 +648,30 @@ impl InnerResolver {
|
|||
}
|
||||
}
|
||||
}
|
||||
(TypeEnum::TNDArray { ty, ndims }, false) => {
|
||||
let len: usize = self.helper.len_fn.call1(py, (obj,))?.extract(py)?;
|
||||
if len == 0 {
|
||||
assert!(matches!(
|
||||
&*unifier.get_ty(*ty),
|
||||
TypeEnum::TVar { fields: None, range, .. }
|
||||
if range.is_empty()
|
||||
));
|
||||
Ok(Ok(extracted_ty))
|
||||
} else {
|
||||
let actual_ty =
|
||||
self.get_list_elem_type(py, obj, len, unifier, defs, primitives)?;
|
||||
match actual_ty {
|
||||
Ok(t) => match unifier.unify(*ty, t) {
|
||||
Ok(_) => Ok(Ok(unifier.add_ty(TypeEnum::TNDArray { ty: *ty, ndims: *ndims }))),
|
||||
Err(e) => Ok(Err(format!(
|
||||
"type error ({}) for the ndarray",
|
||||
e.to_display(unifier).to_string()
|
||||
))),
|
||||
},
|
||||
Err(e) => Ok(Err(e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
(TypeEnum::TTuple { .. }, false) => {
|
||||
let elements: &PyTuple = obj.downcast()?;
|
||||
let types: Result<Result<Vec<_>, _>, _> = elements
|
||||
|
@ -898,6 +938,8 @@ impl InnerResolver {
|
|||
global.set_initializer(&val);
|
||||
|
||||
Ok(Some(global.as_pointer_value().into()))
|
||||
} else if ty_id == self.primitive_ids.ndarray {
|
||||
todo!()
|
||||
} else if ty_id == self.primitive_ids.tuple {
|
||||
let expected_ty_enum = ctx.unifier.get_ty_immutable(expected_ty);
|
||||
let TypeEnum::TTuple { ty } = expected_ty_enum.as_ref() else {
|
||||
|
|
|
@ -17,6 +17,7 @@ fn main() {
|
|||
const FLAG: &[&str] = &[
|
||||
"--target=wasm32",
|
||||
FILE,
|
||||
"-fno-discard-value-names",
|
||||
"-O3",
|
||||
"-emit-llvm",
|
||||
"-S",
|
||||
|
|
|
@ -47,6 +47,10 @@ pub enum ConcreteTypeEnum {
|
|||
TList {
|
||||
ty: ConcreteType,
|
||||
},
|
||||
TNDArray {
|
||||
ty: ConcreteType,
|
||||
ndims: ConcreteType,
|
||||
},
|
||||
TObj {
|
||||
obj_id: DefinitionId,
|
||||
fields: HashMap<StrRef, (ConcreteType, bool)>,
|
||||
|
@ -167,6 +171,10 @@ impl ConcreteTypeStore {
|
|||
TypeEnum::TList { ty } => ConcreteTypeEnum::TList {
|
||||
ty: self.from_unifier_type(unifier, primitives, *ty, cache),
|
||||
},
|
||||
TypeEnum::TNDArray { ty, ndims } => ConcreteTypeEnum::TNDArray {
|
||||
ty: self.from_unifier_type(unifier, primitives, *ty, cache),
|
||||
ndims: self.from_unifier_type(unifier, primitives, *ndims, cache),
|
||||
},
|
||||
TypeEnum::TObj { obj_id, fields, params } => ConcreteTypeEnum::TObj {
|
||||
obj_id: *obj_id,
|
||||
fields: fields
|
||||
|
@ -260,6 +268,12 @@ impl ConcreteTypeStore {
|
|||
ConcreteTypeEnum::TList { ty } => {
|
||||
TypeEnum::TList { ty: self.to_unifier_type(unifier, primitives, *ty, cache) }
|
||||
}
|
||||
ConcreteTypeEnum::TNDArray { ty, ndims } => {
|
||||
TypeEnum::TNDArray {
|
||||
ty: self.to_unifier_type(unifier, primitives, *ty, cache),
|
||||
ndims: self.to_unifier_type(unifier, primitives, *ndims, cache),
|
||||
}
|
||||
}
|
||||
ConcreteTypeEnum::TVirtual { ty } => {
|
||||
TypeEnum::TVirtual { ty: self.to_unifier_type(unifier, primitives, *ty, cache) }
|
||||
}
|
||||
|
|
|
@ -1846,6 +1846,9 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
|||
ctx.build_gep_and_load(arr_ptr, &[index], None).into()
|
||||
}
|
||||
}
|
||||
TypeEnum::TNDArray { .. } => {
|
||||
return Err(String::from("subscript operator for ndarray not implemented"))
|
||||
}
|
||||
TypeEnum::TTuple { .. } => {
|
||||
let index: u32 =
|
||||
if let ExprKind::Constant { value: Constant::Int(v), .. } = &slice.node {
|
||||
|
|
|
@ -92,6 +92,18 @@ pub trait CodeGenerator {
|
|||
gen_var(ctx, ty, name)
|
||||
}
|
||||
|
||||
/// Allocate memory for a variable and return a pointer pointing to it.
|
||||
/// The default implementation places the allocations at the start of the function.
|
||||
fn gen_array_var_alloc<'ctx, 'a>(
|
||||
&mut self,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
ty: BasicTypeEnum<'ctx>,
|
||||
size: IntValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> Result<PointerValue<'ctx>, String> {
|
||||
gen_array_var(ctx, ty, size, name)
|
||||
}
|
||||
|
||||
/// Return a pointer pointing to the target of the expression.
|
||||
fn gen_store_target<'ctx>(
|
||||
&mut self,
|
||||
|
|
|
@ -196,4 +196,76 @@ double __nac3_j0(double x) {
|
|||
}
|
||||
|
||||
return j0(x);
|
||||
}
|
||||
|
||||
uint32_t __nac3_ndarray_calc_size(
|
||||
const uint64_t *list_data,
|
||||
uint32_t list_len
|
||||
) {
|
||||
uint32_t num_elems = 1;
|
||||
for (uint32_t i = 0; i < list_len; ++i) {
|
||||
uint64_t val = list_data[i];
|
||||
__builtin_assume(val >= 0);
|
||||
num_elems *= list_data[i];
|
||||
}
|
||||
return num_elems;
|
||||
}
|
||||
|
||||
uint64_t __nac3_ndarray_calc_size64(
|
||||
const uint64_t *list_data,
|
||||
uint64_t list_len
|
||||
) {
|
||||
uint64_t num_elems = 1;
|
||||
for (uint64_t i = 0; i < list_len; ++i) {
|
||||
uint64_t val = list_data[i];
|
||||
__builtin_assume(val >= 0);
|
||||
num_elems *= list_data[i];
|
||||
}
|
||||
return num_elems;
|
||||
}
|
||||
|
||||
void __nac3_ndarray_init_dims(
|
||||
uint32_t *ndarray_dims,
|
||||
const int32_t *shape_data,
|
||||
uint32_t shape_len
|
||||
) {
|
||||
__builtin_memcpy(ndarray_dims, shape_data, shape_len * sizeof(int32_t));
|
||||
}
|
||||
|
||||
void __nac3_ndarray_init_dims64(
|
||||
uint64_t *ndarray_dims,
|
||||
const int32_t *shape_data,
|
||||
uint64_t shape_len
|
||||
) {
|
||||
for (uint64_t i = 0; i < shape_len; ++i) {
|
||||
ndarray_dims[i] = (uint64_t) shape_data[i];
|
||||
}
|
||||
}
|
||||
|
||||
void __nac3_ndarray_calc_nd_indices(
|
||||
uint32_t index,
|
||||
const uint32_t* dims,
|
||||
uint32_t num_dims,
|
||||
uint32_t* idxs
|
||||
) {
|
||||
uint32_t stride = 1;
|
||||
for (uint32_t dim = 0; dim < num_dims; dim++) {
|
||||
uint32_t i = num_dims - dim - 1;
|
||||
idxs[i] = (index / stride) % dims[i];
|
||||
stride *= dims[i];
|
||||
}
|
||||
}
|
||||
|
||||
void __nac3_ndarray_calc_nd_indices64(
|
||||
uint64_t index,
|
||||
const uint64_t* dims,
|
||||
uint64_t num_dims,
|
||||
uint64_t* idxs
|
||||
) {
|
||||
uint64_t stride = 1;
|
||||
for (uint64_t dim = 0; dim < num_dims; dim++) {
|
||||
uint64_t i = num_dims - dim - 1;
|
||||
idxs[i] = (index / stride) % dims[i];
|
||||
stride *= dims[i];
|
||||
}
|
||||
}
|
|
@ -1,6 +1,6 @@
|
|||
use crate::typecheck::typedef::Type;
|
||||
|
||||
use super::{CodeGenContext, CodeGenerator};
|
||||
use super::{assert_is_list, assert_is_ndarray, CodeGenContext, CodeGenerator};
|
||||
use inkwell::{
|
||||
attributes::{Attribute, AttributeLoc},
|
||||
context::Context,
|
||||
|
@ -546,3 +546,181 @@ pub fn call_j0<'ctx>(
|
|||
.unwrap_left()
|
||||
.into_float_value()
|
||||
}
|
||||
|
||||
/// Generates a call to `__nac3_ndarray_calc_size`. Returns an [IntValue] representing the
|
||||
/// calculated total size.
|
||||
///
|
||||
/// * `num_dims` - An [IntValue] containing the number of dimensions.
|
||||
/// * `dims` - A [PointerValue] to an array containing the size of each dimensions.
|
||||
pub fn call_ndarray_calc_size<'ctx, 'a>(
|
||||
generator: &mut dyn CodeGenerator,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
num_dims: IntValue<'ctx>,
|
||||
dims: PointerValue<'ctx>,
|
||||
) -> IntValue<'ctx> {
|
||||
let llvm_i64 = ctx.ctx.i64_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let llvm_pi64 = llvm_i64.ptr_type(AddressSpace::default());
|
||||
|
||||
let ndarray_calc_size_fn_name = match generator.get_size_type(ctx.ctx).get_bit_width() {
|
||||
32 => "__nac3_ndarray_calc_size",
|
||||
64 => "__nac3_ndarray_calc_size64",
|
||||
bw => unreachable!("Unsupported size type bit width: {}", bw)
|
||||
};
|
||||
let ndarray_calc_size_fn_t = llvm_usize.fn_type(
|
||||
&[
|
||||
llvm_pi64.into(),
|
||||
llvm_usize.into(),
|
||||
],
|
||||
false,
|
||||
);
|
||||
let ndarray_calc_size_fn = ctx.module.get_function(ndarray_calc_size_fn_name)
|
||||
.unwrap_or_else(|| {
|
||||
ctx.module.add_function(ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None)
|
||||
});
|
||||
|
||||
ctx.builder
|
||||
.build_call(
|
||||
ndarray_calc_size_fn,
|
||||
&[
|
||||
dims.into(),
|
||||
num_dims.into(),
|
||||
],
|
||||
"",
|
||||
)
|
||||
.try_as_basic_value()
|
||||
.unwrap_left()
|
||||
.into_int_value()
|
||||
}
|
||||
|
||||
/// Generates a call to `__nac3_ndarray_init_dims`.
|
||||
///
|
||||
/// * `ndarray` - LLVM pointer to the NDArray. This value must be the LLVM representation of an
|
||||
/// `NDArray`.
|
||||
/// * `shape` - LLVM pointer to the `shape` of the NDArray. This value must be the LLVM
|
||||
/// representation of a `list`.
|
||||
pub fn call_ndarray_init_dims<'ctx, 'a>(
|
||||
generator: &mut dyn CodeGenerator,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
ndarray: PointerValue<'ctx>,
|
||||
shape: PointerValue<'ctx>,
|
||||
) {
|
||||
assert_is_ndarray(ndarray);
|
||||
assert_is_list(shape);
|
||||
|
||||
let llvm_void = ctx.ctx.void_type();
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
|
||||
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
||||
|
||||
let ndarray_init_dims_fn_name = match generator.get_size_type(ctx.ctx).get_bit_width() {
|
||||
32 => "__nac3_ndarray_init_dims",
|
||||
64 => "__nac3_ndarray_init_dims64",
|
||||
bw => unreachable!("Unsupported size type bit width: {}", bw)
|
||||
};
|
||||
let ndarray_init_dims_fn = ctx.module.get_function(ndarray_init_dims_fn_name).unwrap_or_else(|| {
|
||||
let fn_type = llvm_void.fn_type(
|
||||
&[
|
||||
llvm_pusize.into(),
|
||||
llvm_pi32.into(),
|
||||
llvm_usize.into(),
|
||||
],
|
||||
false,
|
||||
);
|
||||
|
||||
ctx.module.add_function(ndarray_init_dims_fn_name, fn_type, None)
|
||||
});
|
||||
|
||||
let ndarray_dims = ctx.build_gep_and_load(
|
||||
ndarray,
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
|
||||
None,
|
||||
);
|
||||
let shape_data = ctx.build_gep_and_load(
|
||||
shape,
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
|
||||
None
|
||||
);
|
||||
let ndarray_num_dims = ctx.build_gep_and_load(
|
||||
ndarray,
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
|
||||
None,
|
||||
).into_int_value();
|
||||
|
||||
ctx.builder.build_call(
|
||||
ndarray_init_dims_fn,
|
||||
&[
|
||||
ndarray_dims.into(),
|
||||
shape_data.into(),
|
||||
ndarray_num_dims.into(),
|
||||
],
|
||||
"",
|
||||
);
|
||||
}
|
||||
|
||||
pub fn call_ndarray_calc_nd_indices<'ctx, 'a>(
|
||||
generator: &mut dyn CodeGenerator,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
index: IntValue<'ctx>,
|
||||
ndarray: PointerValue<'ctx>,
|
||||
) -> Result<PointerValue<'ctx>, String> {
|
||||
assert_is_ndarray(ndarray);
|
||||
|
||||
let llvm_void = ctx.ctx.void_type();
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
|
||||
|
||||
let ndarray_calc_nd_indices_dn_name = match generator.get_size_type(ctx.ctx).get_bit_width() {
|
||||
32 => "__nac3_ndarray_calc_nd_indices",
|
||||
64 => "__nac3_ndarray_calc_nd_indices64",
|
||||
bw => unreachable!("Unsupported size type bit width: {}", bw)
|
||||
};
|
||||
let ndarray_calc_nd_indices_fn = ctx.module.get_function(ndarray_calc_nd_indices_dn_name).unwrap_or_else(|| {
|
||||
let fn_type = llvm_void.fn_type(
|
||||
&[
|
||||
llvm_usize.into(),
|
||||
llvm_pusize.into(),
|
||||
llvm_usize.into(),
|
||||
llvm_pusize.into(),
|
||||
],
|
||||
false,
|
||||
);
|
||||
|
||||
ctx.module.add_function(ndarray_calc_nd_indices_dn_name, fn_type, None)
|
||||
});
|
||||
|
||||
let ndarray_num_dims = ctx.build_gep_and_load(
|
||||
ndarray,
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
|
||||
None,
|
||||
).into_int_value();
|
||||
let ndarray_dims = ctx.build_gep_and_load(
|
||||
ndarray,
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
|
||||
None,
|
||||
).into_pointer_value();
|
||||
|
||||
let indices = ctx.builder.build_array_alloca(
|
||||
llvm_usize,
|
||||
ndarray_num_dims,
|
||||
"",
|
||||
);
|
||||
|
||||
ctx.builder.build_call(
|
||||
ndarray_calc_nd_indices_fn,
|
||||
&[
|
||||
index.into(),
|
||||
ndarray_dims.into(),
|
||||
ndarray_num_dims.into(),
|
||||
indices.into(),
|
||||
],
|
||||
"",
|
||||
);
|
||||
|
||||
Ok(indices)
|
||||
}
|
|
@ -34,6 +34,9 @@ use std::sync::{
|
|||
};
|
||||
use std::thread;
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
use inkwell::types::AnyTypeEnum;
|
||||
|
||||
pub mod concrete_type;
|
||||
pub mod expr;
|
||||
mod generator;
|
||||
|
@ -236,7 +239,7 @@ pub struct WorkerRegistry {
|
|||
static_value_store: Arc<Mutex<StaticValueStore>>,
|
||||
|
||||
/// LLVM-related options for code generation.
|
||||
llvm_options: CodeGenLLVMOptions,
|
||||
pub llvm_options: CodeGenLLVMOptions,
|
||||
}
|
||||
|
||||
impl WorkerRegistry {
|
||||
|
@ -507,6 +510,24 @@ fn get_llvm_type<'ctx>(
|
|||
];
|
||||
ctx.struct_type(&fields, false).ptr_type(AddressSpace::default()).into()
|
||||
}
|
||||
TNDArray { ty, .. } => {
|
||||
let llvm_usize = generator.get_size_type(ctx);
|
||||
let element_type = get_llvm_type(
|
||||
ctx, module, generator, unifier, top_level, type_cache, primitives, *ty,
|
||||
);
|
||||
|
||||
// struct NDArray { num_dims: size_t, dims: size_t*, data: T* }
|
||||
//
|
||||
// * num_dims: Number of dimensions in the array
|
||||
// * dims: Pointer to an array containing the size of each dimension
|
||||
// * data: Pointer to an array containing the array data
|
||||
let fields = [
|
||||
llvm_usize.into(),
|
||||
llvm_usize.ptr_type(AddressSpace::default()).into(),
|
||||
element_type.ptr_type(AddressSpace::default()).into(),
|
||||
];
|
||||
ctx.struct_type(&fields, false).ptr_type(AddressSpace::default()).into()
|
||||
}
|
||||
TVirtual { .. } => unimplemented!(),
|
||||
_ => unreachable!("{}", ty_enum.get_type_name()),
|
||||
};
|
||||
|
@ -614,6 +635,7 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
|
|||
str: unifier.get_representative(primitives.str),
|
||||
exception: unifier.get_representative(primitives.exception),
|
||||
option: unifier.get_representative(primitives.option),
|
||||
..primitives
|
||||
};
|
||||
|
||||
let mut type_cache: HashMap<_, _> = [
|
||||
|
@ -976,3 +998,43 @@ fn gen_in_range_check<'ctx>(
|
|||
|
||||
ctx.builder.build_int_compare(IntPredicate::SLT, lo, hi, "cmp")
|
||||
}
|
||||
|
||||
/// Checks whether the pointer `value` refers to a `list` in LLVM.
|
||||
fn assert_is_list(value: PointerValue) -> PointerValue {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
let llvm_shape_ty = value.get_type().get_element_type();
|
||||
let AnyTypeEnum::StructType(llvm_shape_ty) = llvm_shape_ty else {
|
||||
panic!("Expected struct type for `list` type, but got {llvm_shape_ty}")
|
||||
};
|
||||
assert_eq!(llvm_shape_ty.count_fields(), 2);
|
||||
assert!(matches!(llvm_shape_ty.get_field_type_at_index(0), Some(BasicTypeEnum::PointerType(..))));
|
||||
assert!(matches!(llvm_shape_ty.get_field_type_at_index(1), Some(BasicTypeEnum::IntType(..))));
|
||||
}
|
||||
|
||||
value
|
||||
}
|
||||
|
||||
/// Checks whether the pointer `value` refers to an `NDArray` in LLVM.
|
||||
fn assert_is_ndarray(value: PointerValue) -> PointerValue {
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
let llvm_ndarray_ty = value.get_type().get_element_type();
|
||||
let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else {
|
||||
panic!("Expected struct type for `NDArray` type, but got {llvm_ndarray_ty}")
|
||||
};
|
||||
|
||||
assert_eq!(llvm_ndarray_ty.count_fields(), 3);
|
||||
assert!(matches!(llvm_ndarray_ty.get_field_type_at_index(0), Some(BasicTypeEnum::IntType(..))));
|
||||
let Some(ndarray_dims) = llvm_ndarray_ty.get_field_type_at_index(1) else {
|
||||
unreachable!()
|
||||
};
|
||||
let BasicTypeEnum::PointerType(dims) = ndarray_dims else {
|
||||
panic!("Expected pointer type for `list.1`, but got {ndarray_dims}")
|
||||
};
|
||||
assert!(matches!(dims.get_element_type(), AnyTypeEnum::IntType(..)));
|
||||
assert!(matches!(llvm_ndarray_ty.get_field_type_at_index(2), Some(BasicTypeEnum::PointerType(..))));
|
||||
}
|
||||
|
||||
value
|
||||
}
|
||||
|
|
|
@ -15,8 +15,8 @@ use crate::{
|
|||
use inkwell::{
|
||||
attributes::{Attribute, AttributeLoc},
|
||||
basic_block::BasicBlock,
|
||||
types::BasicTypeEnum,
|
||||
values::{BasicValue, BasicValueEnum, FunctionValue, PointerValue},
|
||||
types::{BasicType, BasicTypeEnum},
|
||||
values::{BasicValue, BasicValueEnum, FunctionValue, IntValue, PointerValue},
|
||||
IntPredicate,
|
||||
};
|
||||
use nac3parser::ast::{
|
||||
|
@ -54,6 +54,37 @@ pub fn gen_var<'ctx>(
|
|||
Ok(ptr)
|
||||
}
|
||||
|
||||
/// See [CodeGenerator::gen_array_var_alloc].
|
||||
pub fn gen_array_var<'ctx, 'a, T: BasicType<'ctx>>(
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
ty: T,
|
||||
size: IntValue<'ctx>,
|
||||
name: Option<&str>,
|
||||
) -> Result<PointerValue<'ctx>, String> {
|
||||
// Restore debug location
|
||||
let di_loc = ctx.debug_info.0.create_debug_location(
|
||||
ctx.ctx,
|
||||
ctx.current_loc.row as u32,
|
||||
ctx.current_loc.column as u32,
|
||||
ctx.debug_info.2,
|
||||
None,
|
||||
);
|
||||
|
||||
// put the alloca in init block
|
||||
let current = ctx.builder.get_insert_block().unwrap();
|
||||
|
||||
// position before the last branching instruction...
|
||||
ctx.builder.position_before(&ctx.init_bb.get_last_instruction().unwrap());
|
||||
ctx.builder.set_current_debug_location(di_loc);
|
||||
|
||||
let ptr = ctx.builder.build_array_alloca(ty, size, name.unwrap_or(""));
|
||||
|
||||
ctx.builder.position_at_end(current);
|
||||
ctx.builder.set_current_debug_location(di_loc);
|
||||
|
||||
Ok(ptr)
|
||||
}
|
||||
|
||||
/// See [`CodeGenerator::gen_store_target`].
|
||||
pub fn gen_store_target<'ctx, G: CodeGenerator>(
|
||||
generator: &mut G,
|
||||
|
@ -99,63 +130,69 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>(
|
|||
}
|
||||
}
|
||||
ExprKind::Subscript { value, slice, .. } => {
|
||||
assert!(matches!(
|
||||
ctx.unifier.get_ty_immutable(value.custom.unwrap()).as_ref(),
|
||||
TypeEnum::TList { .. },
|
||||
));
|
||||
let i32_type = ctx.ctx.i32_type();
|
||||
let zero = i32_type.const_zero();
|
||||
let v = if let Some(v) = generator.gen_expr(ctx, value)? {
|
||||
v.to_basic_value_enum(ctx, generator, value.custom.unwrap())?.into_pointer_value()
|
||||
} else {
|
||||
return Ok(None)
|
||||
};
|
||||
let len = ctx
|
||||
.build_gep_and_load(v, &[zero, i32_type.const_int(1, false)], Some("len"))
|
||||
.into_int_value();
|
||||
let raw_index = if let Some(v) = generator.gen_expr(ctx, slice)? {
|
||||
v.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value()
|
||||
} else {
|
||||
return Ok(None)
|
||||
};
|
||||
let raw_index = ctx.builder.build_int_s_extend(
|
||||
raw_index,
|
||||
generator.get_size_type(ctx.ctx),
|
||||
"sext",
|
||||
);
|
||||
// handle negative index
|
||||
let is_negative = ctx.builder.build_int_compare(
|
||||
IntPredicate::SLT,
|
||||
raw_index,
|
||||
generator.get_size_type(ctx.ctx).const_zero(),
|
||||
"is_neg",
|
||||
);
|
||||
let adjusted = ctx.builder.build_int_add(raw_index, len, "adjusted");
|
||||
let index = ctx
|
||||
.builder
|
||||
.build_select(is_negative, adjusted, raw_index, "index")
|
||||
.into_int_value();
|
||||
// unsigned less than is enough, because negative index after adjustment is
|
||||
// bigger than the length (for unsigned cmp)
|
||||
let bound_check = ctx.builder.build_int_compare(
|
||||
IntPredicate::ULT,
|
||||
index,
|
||||
len,
|
||||
"inbound",
|
||||
);
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
bound_check,
|
||||
"0:IndexError",
|
||||
"index {0} out of bounds 0:{1}",
|
||||
[Some(raw_index), Some(len), None],
|
||||
slice.location,
|
||||
);
|
||||
unsafe {
|
||||
let arr_ptr = ctx
|
||||
.build_gep_and_load(v, &[i32_type.const_zero(), i32_type.const_zero()], Some("arr.addr"))
|
||||
.into_pointer_value();
|
||||
ctx.builder.build_gep(arr_ptr, &[index], name.unwrap_or(""))
|
||||
match ctx.unifier.get_ty_immutable(value.custom.unwrap()).as_ref() {
|
||||
TypeEnum::TList { .. } => {
|
||||
let i32_type = ctx.ctx.i32_type();
|
||||
let zero = i32_type.const_zero();
|
||||
let v = generator
|
||||
.gen_expr(ctx, value)?
|
||||
.unwrap()
|
||||
.to_basic_value_enum(ctx, generator, value.custom.unwrap())?
|
||||
.into_pointer_value();
|
||||
let len = ctx
|
||||
.build_gep_and_load(v, &[zero, i32_type.const_int(1, false)], Some("len"))
|
||||
.into_int_value();
|
||||
let raw_index = generator
|
||||
.gen_expr(ctx, slice)?
|
||||
.unwrap()
|
||||
.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?
|
||||
.into_int_value();
|
||||
let raw_index = ctx.builder.build_int_s_extend(
|
||||
raw_index,
|
||||
generator.get_size_type(ctx.ctx),
|
||||
"sext",
|
||||
);
|
||||
// handle negative index
|
||||
let is_negative = ctx.builder.build_int_compare(
|
||||
IntPredicate::SLT,
|
||||
raw_index,
|
||||
generator.get_size_type(ctx.ctx).const_zero(),
|
||||
"is_neg",
|
||||
);
|
||||
let adjusted = ctx.builder.build_int_add(raw_index, len, "adjusted");
|
||||
let index = ctx
|
||||
.builder
|
||||
.build_select(is_negative, adjusted, raw_index, "index")
|
||||
.into_int_value();
|
||||
// unsigned less than is enough, because negative index after adjustment is
|
||||
// bigger than the length (for unsigned cmp)
|
||||
let bound_check = ctx.builder.build_int_compare(
|
||||
IntPredicate::ULT,
|
||||
index,
|
||||
len,
|
||||
"inbound",
|
||||
);
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
bound_check,
|
||||
"0:IndexError",
|
||||
"index {0} out of bounds 0:{1}",
|
||||
[Some(raw_index), Some(len), None],
|
||||
slice.location,
|
||||
);
|
||||
unsafe {
|
||||
let arr_ptr = ctx
|
||||
.build_gep_and_load(v, &[i32_type.const_zero(), i32_type.const_zero()], Some("arr.addr"))
|
||||
.into_pointer_value();
|
||||
ctx.builder.build_gep(arr_ptr, &[index], name.unwrap_or(""))
|
||||
}
|
||||
}
|
||||
|
||||
TypeEnum::TNDArray { .. } => {
|
||||
todo!()
|
||||
}
|
||||
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
_ => unreachable!(),
|
||||
|
@ -203,7 +240,7 @@ pub fn gen_assign<'ctx, G: CodeGenerator>(
|
|||
let value = value
|
||||
.to_basic_value_enum(ctx, generator, target.custom.unwrap())?
|
||||
.into_pointer_value();
|
||||
let TypeEnum::TList { ty } = &*ctx.unifier.get_ty(target.custom.unwrap()) else {
|
||||
let (TypeEnum::TList { ty } | TypeEnum::TNDArray { ty, .. }) = &*ctx.unifier.get_ty(target.custom.unwrap()) else {
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
|
@ -399,6 +436,80 @@ pub fn gen_for<G: CodeGenerator>(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
/// Generates a C-style `for` construct using lambdas, similar to the following C code:
|
||||
///
|
||||
/// ```c
|
||||
/// for (x... = init(); cond(x...); update(x...)) {
|
||||
/// body(x...);
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// * `init` - A lambda containing IR statements declaring and initializing loop variables. The
|
||||
/// return value is a [Clone] value which will be passed to the other lambdas.
|
||||
/// * `cond` - A lambda containing IR statements checking whether the loop should continue
|
||||
/// executing. The result value must be an `i1` indicating if the loop should continue.
|
||||
/// * `body` - A lambda containing IR statements within the loop body.
|
||||
/// * `update` - A lambda containing IR statements updating loop variables.
|
||||
pub fn gen_for_callback<'ctx, 'a, I, InitFn, CondFn, BodyFn, UpdateFn>(
|
||||
generator: &mut dyn CodeGenerator,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
init: InitFn,
|
||||
cond: CondFn,
|
||||
body: BodyFn,
|
||||
update: UpdateFn,
|
||||
) -> Result<(), String>
|
||||
where
|
||||
I: Clone,
|
||||
InitFn: FnOnce(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>) -> Result<I, String>,
|
||||
CondFn: FnOnce(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>, I) -> Result<IntValue<'ctx>, String>,
|
||||
BodyFn: FnOnce(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>,
|
||||
UpdateFn: FnOnce(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>,
|
||||
{
|
||||
let current = ctx.builder.get_insert_block().and_then(|bb| bb.get_parent()).unwrap();
|
||||
let init_bb = ctx.ctx.append_basic_block(current, "for.init");
|
||||
// The BB containing the loop condition check
|
||||
let cond_bb = ctx.ctx.append_basic_block(current, "for.cond");
|
||||
let body_bb = ctx.ctx.append_basic_block(current, "for.body");
|
||||
// The BB containing the increment expression
|
||||
let update_bb = ctx.ctx.append_basic_block(current, "for.update");
|
||||
let cont_bb = ctx.ctx.append_basic_block(current, "for.end");
|
||||
|
||||
// store loop bb information and restore it later
|
||||
let loop_bb = ctx.loop_target.replace((update_bb, cont_bb));
|
||||
|
||||
ctx.builder.build_unconditional_branch(init_bb);
|
||||
|
||||
let loop_var = {
|
||||
ctx.builder.position_at_end(init_bb);
|
||||
let result = init(generator, ctx)?;
|
||||
ctx.builder.build_unconditional_branch(cond_bb);
|
||||
|
||||
result
|
||||
};
|
||||
|
||||
ctx.builder.position_at_end(cond_bb);
|
||||
let cond = cond(generator, ctx, loop_var.clone())?;
|
||||
assert_eq!(cond.get_type().get_bit_width(), ctx.ctx.bool_type().get_bit_width());
|
||||
ctx.builder.build_conditional_branch(
|
||||
cond,
|
||||
body_bb,
|
||||
cont_bb
|
||||
);
|
||||
|
||||
ctx.builder.position_at_end(body_bb);
|
||||
body(generator, ctx, loop_var.clone())?;
|
||||
ctx.builder.build_unconditional_branch(update_bb);
|
||||
|
||||
ctx.builder.position_at_end(update_bb);
|
||||
update(generator, ctx, loop_var)?;
|
||||
ctx.builder.build_unconditional_branch(cond_bb);
|
||||
|
||||
ctx.builder.position_at_end(cont_bb);
|
||||
ctx.loop_target = loop_bb;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// See [`CodeGenerator::gen_while`].
|
||||
pub fn gen_while<G: CodeGenerator>(
|
||||
generator: &mut G,
|
||||
|
|
|
@ -354,13 +354,14 @@ pub trait SymbolResolver {
|
|||
}
|
||||
|
||||
thread_local! {
|
||||
static IDENTIFIER_ID: [StrRef; 11] = [
|
||||
static IDENTIFIER_ID: [StrRef; 12] = [
|
||||
"int32".into(),
|
||||
"int64".into(),
|
||||
"float".into(),
|
||||
"bool".into(),
|
||||
"virtual".into(),
|
||||
"list".into(),
|
||||
"ndarray".into(),
|
||||
"tuple".into(),
|
||||
"str".into(),
|
||||
"Exception".into(),
|
||||
|
@ -385,11 +386,12 @@ pub fn parse_type_annotation<T>(
|
|||
let bool_id = ids[3];
|
||||
let virtual_id = ids[4];
|
||||
let list_id = ids[5];
|
||||
let tuple_id = ids[6];
|
||||
let str_id = ids[7];
|
||||
let exn_id = ids[8];
|
||||
let uint32_id = ids[9];
|
||||
let uint64_id = ids[10];
|
||||
let ndarray_id = ids[6];
|
||||
let tuple_id = ids[7];
|
||||
let str_id = ids[8];
|
||||
let exn_id = ids[9];
|
||||
let uint32_id = ids[10];
|
||||
let uint64_id = ids[11];
|
||||
|
||||
let name_handling = |id: &StrRef, loc: Location, unifier: &mut Unifier| {
|
||||
if *id == int32_id {
|
||||
|
@ -460,6 +462,21 @@ pub fn parse_type_annotation<T>(
|
|||
} else if *id == list_id {
|
||||
let ty = parse_type_annotation(resolver, top_level_defs, unifier, primitives, slice)?;
|
||||
Ok(unifier.add_ty(TypeEnum::TList { ty }))
|
||||
} else if *id == ndarray_id {
|
||||
let Tuple { elts, .. } = &slice.node else {
|
||||
return Err(HashSet::from([
|
||||
String::from("Expected 2 type arguments for ndarray"),
|
||||
]))
|
||||
};
|
||||
if elts.len() < 2 {
|
||||
return Err(HashSet::from([
|
||||
String::from("Expected 2 type arguments for ndarray"),
|
||||
]))
|
||||
}
|
||||
|
||||
let ty = parse_type_annotation(resolver, top_level_defs, unifier, primitives, &elts[0])?;
|
||||
let ndims = parse_type_annotation(resolver, top_level_defs, unifier, primitives, &elts[1])?;
|
||||
Ok(unifier.add_ty(TypeEnum::TNDArray { ty, ndims }))
|
||||
} else if *id == tuple_id {
|
||||
if let Tuple { elts, .. } = &slice.node {
|
||||
let ty = elts
|
||||
|
|
|
@ -13,14 +13,22 @@ use crate::{
|
|||
stmt::exn_constructor,
|
||||
},
|
||||
symbol_resolver::SymbolValue,
|
||||
toplevel::numpy::{
|
||||
gen_ndarray_empty,
|
||||
gen_ndarray_eye,
|
||||
gen_ndarray_full,
|
||||
gen_ndarray_ones,
|
||||
gen_ndarray_zeros,
|
||||
},
|
||||
};
|
||||
use inkwell::{
|
||||
attributes::{Attribute, AttributeLoc},
|
||||
types::{BasicType, BasicMetadataTypeEnum},
|
||||
values::BasicMetadataValueEnum,
|
||||
values::{BasicValue, BasicMetadataValueEnum},
|
||||
FloatPredicate,
|
||||
IntPredicate
|
||||
};
|
||||
use crate::toplevel::numpy::gen_ndarray_identity;
|
||||
|
||||
type BuiltinInfo = Vec<(Arc<RwLock<TopLevelDef>>, Option<Stmt>)>;
|
||||
|
||||
|
@ -278,6 +286,31 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
|||
let boolean = primitives.0.bool;
|
||||
let range = primitives.0.range;
|
||||
let string = primitives.0.str;
|
||||
let ndarray = {
|
||||
let ndarray_ty = TypeEnum::ndarray(&mut primitives.1, None, None, &primitives.0);
|
||||
primitives.1.add_ty(ndarray_ty)
|
||||
};
|
||||
let ndarray_float = {
|
||||
let ndarray_ty_enum = TypeEnum::ndarray(&mut primitives.1, Some(float), None, &primitives.0);
|
||||
primitives.1.add_ty(ndarray_ty_enum)
|
||||
};
|
||||
let ndarray_float_2d = {
|
||||
let value = match primitives.0.size_t {
|
||||
64 => SymbolValue::U64(2u64),
|
||||
32 => SymbolValue::U32(2u32),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let ndims = primitives.1.add_ty(TypeEnum::TLiteral {
|
||||
values: vec![value],
|
||||
loc: None,
|
||||
});
|
||||
|
||||
primitives.1.add_ty(TypeEnum::TNDArray {
|
||||
ty: float,
|
||||
ndims,
|
||||
})
|
||||
};
|
||||
let list_int32 = primitives.1.add_ty(TypeEnum::TList { ty: int32 });
|
||||
let num_ty = primitives.1.get_fresh_var_with_range(
|
||||
&[int32, int64, float, boolean, uint32, uint64],
|
||||
Some("N".into()),
|
||||
|
@ -470,6 +503,22 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
|||
)))),
|
||||
loc: None,
|
||||
})),
|
||||
{
|
||||
let tvar = primitives.1.get_fresh_var(Some("T".into()), None);
|
||||
let ndims = primitives.1.get_fresh_const_generic_var(primitives.0.uint64, Some("N".into()), None);
|
||||
|
||||
Arc::new(RwLock::new(TopLevelDef::Class {
|
||||
name: "ndarray".into(),
|
||||
object_id: DefinitionId(14),
|
||||
type_vars: vec![tvar.0, ndims.0],
|
||||
fields: Vec::default(),
|
||||
methods: Vec::default(),
|
||||
ancestors: Vec::default(),
|
||||
constructor: None,
|
||||
resolver: None,
|
||||
loc: None,
|
||||
}))
|
||||
},
|
||||
Arc::new(RwLock::new(TopLevelDef::Function {
|
||||
name: "int32".into(),
|
||||
simple_name: "int32".into(),
|
||||
|
@ -821,6 +870,115 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
|||
)))),
|
||||
loc: None,
|
||||
})),
|
||||
create_fn_by_codegen(
|
||||
primitives,
|
||||
&var_map,
|
||||
"np_ndarray",
|
||||
ndarray_float,
|
||||
// We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a
|
||||
// type variable
|
||||
&[(list_int32, "shape")],
|
||||
Box::new(|ctx, obj, fun, args, generator| {
|
||||
gen_ndarray_empty(ctx, obj, fun, args, generator)
|
||||
.map(|val| Some(val.as_basic_value_enum()))
|
||||
}),
|
||||
),
|
||||
create_fn_by_codegen(
|
||||
primitives,
|
||||
&var_map,
|
||||
"np_empty",
|
||||
ndarray_float,
|
||||
// We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a
|
||||
// type variable
|
||||
&[(list_int32, "shape")],
|
||||
Box::new(|ctx, obj, fun, args, generator| {
|
||||
gen_ndarray_empty(ctx, obj, fun, args, generator)
|
||||
.map(|val| Some(val.as_basic_value_enum()))
|
||||
}),
|
||||
),
|
||||
create_fn_by_codegen(
|
||||
primitives,
|
||||
&var_map,
|
||||
"np_zeros",
|
||||
ndarray_float,
|
||||
// We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a
|
||||
// type variable
|
||||
&[(list_int32, "shape")],
|
||||
Box::new(|ctx, obj, fun, args, generator| {
|
||||
gen_ndarray_zeros(ctx, obj, fun, args, generator)
|
||||
.map(|val| Some(val.as_basic_value_enum()))
|
||||
}),
|
||||
),
|
||||
create_fn_by_codegen(
|
||||
primitives,
|
||||
&var_map,
|
||||
"np_ones",
|
||||
ndarray_float,
|
||||
// We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a
|
||||
// type variable
|
||||
&[(list_int32, "shape")],
|
||||
Box::new(|ctx, obj, fun, args, generator| {
|
||||
gen_ndarray_ones(ctx, obj, fun, args, generator)
|
||||
.map(|val| Some(val.as_basic_value_enum()))
|
||||
}),
|
||||
),
|
||||
{
|
||||
let tv = primitives.1.get_fresh_var(Some("T".into()), None).0;
|
||||
|
||||
create_fn_by_codegen(
|
||||
primitives,
|
||||
&var_map,
|
||||
"np_full",
|
||||
ndarray,
|
||||
// We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a
|
||||
// type variable
|
||||
&[(list_int32, "shape"), (tv, "fill_value")],
|
||||
Box::new(|ctx, obj, fun, args, generator| {
|
||||
gen_ndarray_full(ctx, obj, fun, args, generator)
|
||||
.map(|val| Some(val.as_basic_value_enum()))
|
||||
}),
|
||||
)
|
||||
},
|
||||
Arc::new(RwLock::new(TopLevelDef::Function {
|
||||
name: "np_eye".into(),
|
||||
simple_name: "np_eye".into(),
|
||||
signature: primitives.1.add_ty(TypeEnum::TFunc(FunSignature {
|
||||
args: vec![
|
||||
FuncArg { name: "N".into(), ty: int32, default_value: None },
|
||||
// TODO(Derppening): Default values current do not work?
|
||||
FuncArg {
|
||||
name: "M".into(),
|
||||
ty: int32,
|
||||
default_value: Some(SymbolValue::OptionNone)
|
||||
},
|
||||
FuncArg { name: "k".into(), ty: int32, default_value: Some(SymbolValue::I32(0)) },
|
||||
],
|
||||
ret: ndarray_float_2d,
|
||||
vars: var_map.clone(),
|
||||
})),
|
||||
var_id: Default::default(),
|
||||
instance_to_symbol: Default::default(),
|
||||
instance_to_stmt: Default::default(),
|
||||
resolver: None,
|
||||
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|
||||
|ctx, obj, fun, args, generator| {
|
||||
gen_ndarray_eye(ctx, obj, fun, args, generator)
|
||||
.map(|val| Some(val.as_basic_value_enum()))
|
||||
},
|
||||
)))),
|
||||
loc: None,
|
||||
})),
|
||||
create_fn_by_codegen(
|
||||
primitives,
|
||||
&var_map,
|
||||
"np_identity",
|
||||
ndarray_float_2d,
|
||||
&[(int32, "n")],
|
||||
Box::new(|ctx, obj, fun, args, generator| {
|
||||
gen_ndarray_identity(ctx, obj, fun, args, generator)
|
||||
.map(|val| Some(val.as_basic_value_enum()))
|
||||
}),
|
||||
),
|
||||
create_fn_by_codegen(
|
||||
primitives,
|
||||
&var_map,
|
||||
|
@ -1265,10 +1423,13 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
|||
}),
|
||||
),
|
||||
Arc::new(RwLock::new({
|
||||
let list_var = primitives.1.get_fresh_var(Some("L".into()), None);
|
||||
let list = primitives.1.add_ty(TypeEnum::TList { ty: list_var.0 });
|
||||
let tvar = primitives.1.get_fresh_var(Some("L".into()), None);
|
||||
let list = primitives.1.add_ty(TypeEnum::TList { ty: tvar.0 });
|
||||
let ndims = primitives.1.get_fresh_const_generic_var(primitives.0.uint64, Some("N".into()), None);
|
||||
let ndarray = primitives.1.add_ty(TypeEnum::TNDArray { ty: tvar.0, ndims: ndims.0 });
|
||||
|
||||
let arg_ty = primitives.1.get_fresh_var_with_range(
|
||||
&[list, primitives.0.range],
|
||||
&[list, ndarray, primitives.0.range],
|
||||
Some("I".into()),
|
||||
None,
|
||||
);
|
||||
|
@ -1278,7 +1439,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
|||
signature: primitives.1.add_ty(TypeEnum::TFunc(FunSignature {
|
||||
args: vec![FuncArg { name: "ls".into(), ty: arg_ty.0, default_value: None }],
|
||||
ret: int32,
|
||||
vars: vec![(list_var.1, list_var.0), (arg_ty.1, arg_ty.0)]
|
||||
vars: vec![(tvar.1, tvar.0), (arg_ty.1, arg_ty.0)]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
})),
|
||||
|
@ -1296,19 +1457,40 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
|||
let (start, end, step) = destructure_range(ctx, arg);
|
||||
Some(calculate_len_for_slice_range(generator, ctx, start, end, step).into())
|
||||
} else {
|
||||
let int32 = ctx.ctx.i32_type();
|
||||
let zero = int32.const_zero();
|
||||
let len = ctx
|
||||
.build_gep_and_load(
|
||||
arg.into_pointer_value(),
|
||||
&[zero, int32.const_int(1, false)],
|
||||
None,
|
||||
)
|
||||
.into_int_value();
|
||||
if len.get_type().get_bit_width() == 32 {
|
||||
Some(len.into())
|
||||
} else {
|
||||
Some(ctx.builder.build_int_truncate(len, int32, "len2i32").into())
|
||||
match &*ctx.unifier.get_ty_immutable(arg_ty) {
|
||||
TypeEnum::TList { .. } => {
|
||||
let int32 = ctx.ctx.i32_type();
|
||||
let zero = int32.const_zero();
|
||||
let len = ctx
|
||||
.build_gep_and_load(
|
||||
arg.into_pointer_value(),
|
||||
&[zero, int32.const_int(1, false)],
|
||||
None,
|
||||
)
|
||||
.into_int_value();
|
||||
if len.get_type().get_bit_width() == 32 {
|
||||
Some(len.into())
|
||||
} else {
|
||||
Some(ctx.builder.build_int_truncate(len, int32, "len2i32").into())
|
||||
}
|
||||
}
|
||||
TypeEnum::TNDArray { .. } => {
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
let i32_zero = llvm_i32.const_zero();
|
||||
|
||||
let len = ctx.build_gep_and_load(
|
||||
arg.into_pointer_value(),
|
||||
&[i32_zero, i32_zero],
|
||||
None,
|
||||
).into_int_value();
|
||||
|
||||
if len.get_type().get_bit_width() != 32 {
|
||||
Some(ctx.builder.build_int_truncate(len, llvm_i32, "len").into())
|
||||
} else {
|
||||
Some(len.into())
|
||||
}
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
})
|
||||
},
|
||||
|
|
|
@ -37,12 +37,8 @@ pub struct TopLevelComposer {
|
|||
// number of built-in function and classes in the definition list, later skip
|
||||
pub builtin_num: usize,
|
||||
pub core_config: ComposerConfig,
|
||||
}
|
||||
|
||||
impl Default for TopLevelComposer {
|
||||
fn default() -> Self {
|
||||
Self::new(vec![], ComposerConfig::default()).0
|
||||
}
|
||||
/// The size of a native word on the target platform.
|
||||
pub size_t: u32,
|
||||
}
|
||||
|
||||
impl TopLevelComposer {
|
||||
|
@ -52,8 +48,9 @@ impl TopLevelComposer {
|
|||
pub fn new(
|
||||
builtins: Vec<(StrRef, FunSignature, Arc<GenCall>)>,
|
||||
core_config: ComposerConfig,
|
||||
size_t: u32,
|
||||
) -> (Self, HashMap<StrRef, DefinitionId>, HashMap<StrRef, Type>) {
|
||||
let mut primitives = Self::make_primitives();
|
||||
let mut primitives = Self::make_primitives(size_t);
|
||||
let mut definition_ast_list = builtins::get_builtins(&mut primitives);
|
||||
let primitives_ty = primitives.0;
|
||||
let mut unifier = primitives.1;
|
||||
|
@ -146,6 +143,7 @@ impl TopLevelComposer {
|
|||
defined_names,
|
||||
method_class,
|
||||
core_config,
|
||||
size_t,
|
||||
},
|
||||
builtin_id,
|
||||
builtin_ty,
|
||||
|
|
|
@ -44,7 +44,7 @@ impl TopLevelDef {
|
|||
|
||||
impl TopLevelComposer {
|
||||
#[must_use]
|
||||
pub fn make_primitives() -> (PrimitiveStore, Unifier) {
|
||||
pub fn make_primitives(size_t: u32) -> (PrimitiveStore, Unifier) {
|
||||
let mut unifier = Unifier::new();
|
||||
let int32 = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: DefinitionId(0),
|
||||
|
@ -144,6 +144,7 @@ impl TopLevelComposer {
|
|||
str,
|
||||
exception,
|
||||
option,
|
||||
size_t,
|
||||
};
|
||||
unifier.put_primitive_store(&primitives);
|
||||
crate::typecheck::magic_methods::set_primitives_magic_methods(&primitives, &mut unifier);
|
||||
|
|
|
@ -25,6 +25,7 @@ pub struct DefinitionId(pub usize);
|
|||
pub mod builtins;
|
||||
pub mod composer;
|
||||
pub mod helper;
|
||||
pub mod numpy;
|
||||
pub mod type_annotation;
|
||||
use composer::*;
|
||||
use type_annotation::*;
|
||||
|
|
|
@ -0,0 +1,883 @@
|
|||
use inkwell::{AddressSpace, IntPredicate, types::BasicType, values::{BasicValueEnum, PointerValue}};
|
||||
use inkwell::values::{ArrayValue, IntValue};
|
||||
use nac3parser::ast::StrRef;
|
||||
use crate::{
|
||||
codegen::{
|
||||
CodeGenContext,
|
||||
CodeGenerator,
|
||||
irrt::{
|
||||
call_ndarray_calc_nd_indices,
|
||||
call_ndarray_calc_size,
|
||||
call_ndarray_init_dims,
|
||||
},
|
||||
stmt::gen_for_callback
|
||||
},
|
||||
symbol_resolver::ValueEnum,
|
||||
toplevel::DefinitionId,
|
||||
typecheck::typedef::{FunSignature, Type, TypeEnum},
|
||||
};
|
||||
|
||||
/// Creates an `NDArray` instance from a constant shape.
|
||||
///
|
||||
/// * `elem_ty` - The element type of the `NDArray`.
|
||||
/// * `shape` - The shape of the `NDArray`, represented as an LLVM [ArrayValue].
|
||||
fn create_ndarray_const_shape<'ctx, 'a>(
|
||||
generator: &mut dyn CodeGenerator,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
elem_ty: Type,
|
||||
shape: ArrayValue<'ctx>
|
||||
) -> Result<PointerValue<'ctx>, String> {
|
||||
let ndarray_ty_enum = TypeEnum::ndarray(&mut ctx.unifier, Some(elem_ty), None, &ctx.primitives);
|
||||
let ndarray_ty = ctx.unifier.add_ty(ndarray_ty_enum);
|
||||
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type();
|
||||
let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type();
|
||||
let llvm_ndarray_data_t = ctx.get_llvm_type(generator, elem_ty).as_basic_type_enum();
|
||||
assert!(llvm_ndarray_data_t.is_sized());
|
||||
|
||||
for i in 0..shape.get_type().len() {
|
||||
let shape_dim = ctx.builder.build_extract_value(
|
||||
shape,
|
||||
i,
|
||||
"",
|
||||
).unwrap();
|
||||
|
||||
let shape_dim_gez = ctx.builder.build_int_compare(
|
||||
IntPredicate::SGE,
|
||||
shape_dim.into_int_value(),
|
||||
llvm_usize.const_zero(),
|
||||
""
|
||||
);
|
||||
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
shape_dim_gez,
|
||||
"0:ValueError",
|
||||
"negative dimensions not supported",
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
}
|
||||
|
||||
let ndarray = generator.gen_var_alloc(
|
||||
ctx,
|
||||
llvm_ndarray_t.into(),
|
||||
None,
|
||||
)?;
|
||||
|
||||
let num_dims = llvm_usize.const_int(shape.get_type().len() as u64, false);
|
||||
|
||||
let ndarray_num_dims = unsafe {
|
||||
ctx.builder.build_in_bounds_gep(
|
||||
ndarray,
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
|
||||
"",
|
||||
)
|
||||
};
|
||||
ctx.builder.build_store(ndarray_num_dims, num_dims);
|
||||
|
||||
let ndarray_dims = unsafe {
|
||||
ctx.builder.build_in_bounds_gep(
|
||||
ndarray,
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
|
||||
"",
|
||||
)
|
||||
};
|
||||
|
||||
let ndarray_num_dims = ctx.build_gep_and_load(
|
||||
ndarray,
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
|
||||
None,
|
||||
).into_int_value();
|
||||
|
||||
ctx.builder.build_store(
|
||||
ndarray_dims,
|
||||
ctx.builder.build_array_alloca(
|
||||
llvm_usize,
|
||||
ndarray_num_dims,
|
||||
"",
|
||||
),
|
||||
);
|
||||
|
||||
for i in 0..shape.get_type().len() {
|
||||
let ndarray_dim = ctx.build_gep_and_load(
|
||||
ndarray,
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
|
||||
None,
|
||||
).into_pointer_value();
|
||||
let ndarray_dim = unsafe {
|
||||
ctx.builder.build_in_bounds_gep(
|
||||
ndarray_dim,
|
||||
&[llvm_i32.const_int(i as u64, true)],
|
||||
"",
|
||||
)
|
||||
};
|
||||
let shape_dim = ctx.builder.build_extract_value(shape, i, "")
|
||||
.map(|val| val.into_int_value())
|
||||
.unwrap();
|
||||
|
||||
ctx.builder.build_store(ndarray_dim, shape_dim);
|
||||
}
|
||||
|
||||
let (ndarray_num_dims, ndarray_dims) = unsafe {
|
||||
(
|
||||
ctx.builder.build_in_bounds_gep(
|
||||
ndarray,
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
|
||||
""
|
||||
),
|
||||
ctx.builder.build_in_bounds_gep(
|
||||
ndarray,
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
|
||||
""
|
||||
),
|
||||
)
|
||||
};
|
||||
let ndarray_num_elems = call_ndarray_calc_size(
|
||||
generator,
|
||||
ctx,
|
||||
ctx.builder.build_load(ndarray_num_dims, "").into_int_value(),
|
||||
ctx.builder.build_load(ndarray_dims, "").into_pointer_value(),
|
||||
);
|
||||
|
||||
let ndarray_data = unsafe {
|
||||
ctx.builder.build_in_bounds_gep(
|
||||
ndarray,
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_int(2, true)],
|
||||
"",
|
||||
)
|
||||
};
|
||||
ctx.builder.build_store(
|
||||
ndarray_data,
|
||||
ctx.builder.build_array_alloca(
|
||||
llvm_ndarray_data_t,
|
||||
ndarray_num_elems,
|
||||
""
|
||||
),
|
||||
);
|
||||
|
||||
Ok(ndarray)
|
||||
}
|
||||
|
||||
fn ndarray_zero_value<'ctx, 'a>(
|
||||
generator: &mut dyn CodeGenerator,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
elem_ty: Type,
|
||||
) -> BasicValueEnum<'ctx> {
|
||||
if [ctx.primitives.int32, ctx.primitives.uint32].iter().any(|ty| ctx.unifier.unioned(elem_ty, *ty)) {
|
||||
ctx.ctx.i32_type().const_zero().into()
|
||||
} else if [ctx.primitives.int64, ctx.primitives.uint64].iter().any(|ty| ctx.unifier.unioned(elem_ty, *ty)) {
|
||||
ctx.ctx.i64_type().const_zero().into()
|
||||
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.float) {
|
||||
ctx.ctx.f64_type().const_zero().into()
|
||||
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) {
|
||||
ctx.ctx.bool_type().const_zero().into()
|
||||
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) {
|
||||
ctx.gen_string(generator, "").into()
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
|
||||
fn ndarray_one_value<'ctx, 'a>(
|
||||
generator: &mut dyn CodeGenerator,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
elem_ty: Type,
|
||||
) -> BasicValueEnum<'ctx> {
|
||||
if [ctx.primitives.int32, ctx.primitives.uint32].iter().any(|ty| ctx.unifier.unioned(elem_ty, *ty)) {
|
||||
let is_signed = ctx.unifier.unioned(elem_ty, ctx.primitives.int32);
|
||||
ctx.ctx.i32_type().const_int(1, is_signed).into()
|
||||
} else if [ctx.primitives.int64, ctx.primitives.uint64].iter().any(|ty| ctx.unifier.unioned(elem_ty, *ty)) {
|
||||
let is_signed = ctx.unifier.unioned(elem_ty, ctx.primitives.int64);
|
||||
ctx.ctx.i64_type().const_int(1, is_signed).into()
|
||||
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.float) {
|
||||
ctx.ctx.f64_type().const_float(1.0).into()
|
||||
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) {
|
||||
ctx.ctx.bool_type().const_int(1, false).into()
|
||||
} else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) {
|
||||
ctx.gen_string(generator, "1").into()
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
|
||||
/// LLVM-typed implementation for generating the implementation for constructing an `NDArray`.
|
||||
///
|
||||
/// * `elem_ty` - The element type of the NDArray.
|
||||
/// * `shape` - The `shape` parameter used to construct the NDArray.
|
||||
fn call_ndarray_empty_impl<'ctx, 'a>(
|
||||
generator: &mut dyn CodeGenerator,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
elem_ty: Type,
|
||||
shape: PointerValue<'ctx>,
|
||||
) -> Result<PointerValue<'ctx>, String> {
|
||||
let ndarray_ty_enum = TypeEnum::ndarray(&mut ctx.unifier, Some(elem_ty), None, &ctx.primitives);
|
||||
let ndarray_ty = ctx.unifier.add_ty(ndarray_ty_enum);
|
||||
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type();
|
||||
let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type();
|
||||
let llvm_ndarray_data_t = ctx.get_llvm_type(generator, elem_ty).as_basic_type_enum();
|
||||
assert!(llvm_ndarray_data_t.is_sized());
|
||||
|
||||
// Assert that all dimensions are non-negative
|
||||
gen_for_callback(
|
||||
generator,
|
||||
ctx,
|
||||
|generator, ctx| {
|
||||
let i = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
||||
ctx.builder.build_store(i, llvm_usize.const_zero());
|
||||
|
||||
Ok(i)
|
||||
},
|
||||
|_, ctx, i_addr| {
|
||||
let i = ctx.builder
|
||||
.build_load(i_addr, "")
|
||||
.into_int_value();
|
||||
let shape_len = ctx.build_gep_and_load(
|
||||
shape,
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
|
||||
None,
|
||||
).into_int_value();
|
||||
|
||||
Ok(ctx.builder.build_int_compare(IntPredicate::ULE, i, shape_len, ""))
|
||||
},
|
||||
|generator, ctx, i_addr| {
|
||||
let shape_elems = ctx.build_gep_and_load(
|
||||
shape,
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
|
||||
None
|
||||
).into_pointer_value();
|
||||
|
||||
let i = ctx.builder
|
||||
.build_load(i_addr, "")
|
||||
.into_int_value();
|
||||
let shape_dim = ctx.build_gep_and_load(
|
||||
shape_elems,
|
||||
&[i],
|
||||
None
|
||||
).into_int_value();
|
||||
|
||||
let shape_dim_gez = ctx.builder.build_int_compare(
|
||||
IntPredicate::SGE,
|
||||
shape_dim,
|
||||
llvm_i32.const_zero(),
|
||||
""
|
||||
);
|
||||
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
shape_dim_gez,
|
||||
"0:ValueError",
|
||||
"negative dimensions not supported",
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
|
||||
Ok(())
|
||||
},
|
||||
|_, ctx, i_addr| {
|
||||
let i = ctx.builder
|
||||
.build_load(i_addr, "")
|
||||
.into_int_value();
|
||||
let i = ctx.builder.build_int_add(i, llvm_usize.const_int(1, true), "");
|
||||
ctx.builder.build_store(i_addr, i);
|
||||
|
||||
Ok(())
|
||||
},
|
||||
)?;
|
||||
|
||||
let ndarray = generator.gen_var_alloc(
|
||||
ctx,
|
||||
llvm_ndarray_t.into(),
|
||||
None,
|
||||
)?;
|
||||
|
||||
let num_dims = ctx.build_gep_and_load(
|
||||
shape,
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
|
||||
None
|
||||
).into_int_value();
|
||||
|
||||
let ndarray_num_dims = unsafe {
|
||||
ctx.builder.build_in_bounds_gep(
|
||||
ndarray,
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
|
||||
"",
|
||||
)
|
||||
};
|
||||
ctx.builder.build_store(ndarray_num_dims, num_dims);
|
||||
|
||||
let ndarray_dims = unsafe {
|
||||
ctx.builder.build_in_bounds_gep(
|
||||
ndarray,
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
|
||||
"",
|
||||
)
|
||||
};
|
||||
|
||||
let ndarray_num_dims = ctx.build_gep_and_load(
|
||||
ndarray,
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
|
||||
None,
|
||||
).into_int_value();
|
||||
|
||||
ctx.builder.build_store(
|
||||
ndarray_dims,
|
||||
ctx.builder.build_array_alloca(
|
||||
llvm_usize,
|
||||
ndarray_num_dims,
|
||||
"",
|
||||
),
|
||||
);
|
||||
|
||||
call_ndarray_init_dims(generator, ctx, ndarray, shape);
|
||||
|
||||
let (ndarray_num_dims, ndarray_dims) = unsafe {
|
||||
(
|
||||
ctx.builder.build_in_bounds_gep(
|
||||
ndarray,
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
|
||||
""
|
||||
),
|
||||
ctx.builder.build_in_bounds_gep(
|
||||
ndarray,
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
|
||||
""
|
||||
),
|
||||
)
|
||||
};
|
||||
let ndarray_num_elems = call_ndarray_calc_size(
|
||||
generator,
|
||||
ctx,
|
||||
ctx.builder.build_load(ndarray_num_dims, "").into_int_value(),
|
||||
ctx.builder.build_load(ndarray_dims, "").into_pointer_value(),
|
||||
);
|
||||
|
||||
let ndarray_data = unsafe {
|
||||
ctx.builder.build_in_bounds_gep(
|
||||
ndarray,
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_int(2, true)],
|
||||
"",
|
||||
)
|
||||
};
|
||||
ctx.builder.build_store(
|
||||
ndarray_data,
|
||||
ctx.builder.build_array_alloca(
|
||||
llvm_ndarray_data_t,
|
||||
ndarray_num_elems,
|
||||
"",
|
||||
),
|
||||
);
|
||||
|
||||
Ok(ndarray)
|
||||
}
|
||||
|
||||
/// Generates LLVM IR for populating the entire `NDArray` using a lambda with its flattened index as
|
||||
/// its input.
|
||||
///
|
||||
/// Note that this differs from `ndarray.fill`, which instead replaces all first-dimension elements
|
||||
/// with the given value (as opposed to all elements within the array).
|
||||
fn ndarray_fill_flattened<'ctx, 'a, ValueFn>(
|
||||
generator: &mut dyn CodeGenerator,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
ndarray: PointerValue<'ctx>,
|
||||
value_fn: ValueFn,
|
||||
) -> Result<(), String>
|
||||
where
|
||||
ValueFn: Fn(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>, IntValue<'ctx>) -> Result<BasicValueEnum<'ctx>, String>,
|
||||
{
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let (num_dims, dims) = unsafe {
|
||||
(
|
||||
ctx.builder.build_in_bounds_gep(
|
||||
ndarray,
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
|
||||
""
|
||||
),
|
||||
ctx.builder.build_in_bounds_gep(
|
||||
ndarray,
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
|
||||
""
|
||||
),
|
||||
)
|
||||
};
|
||||
|
||||
let ndarray_num_elems = call_ndarray_calc_size(
|
||||
generator,
|
||||
ctx,
|
||||
ctx.builder.build_load(num_dims, "").into_int_value(),
|
||||
ctx.builder.build_load(dims, "").into_pointer_value(),
|
||||
);
|
||||
|
||||
gen_for_callback(
|
||||
generator,
|
||||
ctx,
|
||||
|generator, ctx| {
|
||||
let i = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?;
|
||||
ctx.builder.build_store(i, llvm_usize.const_zero());
|
||||
|
||||
Ok(i)
|
||||
},
|
||||
|_, ctx, i_addr| {
|
||||
let i = ctx.builder
|
||||
.build_load(i_addr, "")
|
||||
.into_int_value();
|
||||
|
||||
Ok(ctx.builder.build_int_compare(IntPredicate::ULT, i, ndarray_num_elems, ""))
|
||||
},
|
||||
|generator, ctx, i_addr| {
|
||||
let ndarray_data = ctx.build_gep_and_load(
|
||||
ndarray,
|
||||
&[llvm_i32.const_zero(), llvm_i32.const_int(2, true)],
|
||||
None
|
||||
).into_pointer_value();
|
||||
|
||||
let i = ctx.builder
|
||||
.build_load(i_addr, "")
|
||||
.into_int_value();
|
||||
let elem = unsafe {
|
||||
ctx.builder.build_in_bounds_gep(
|
||||
ndarray_data,
|
||||
&[i],
|
||||
""
|
||||
)
|
||||
};
|
||||
|
||||
let value = value_fn(generator, ctx, i)?;
|
||||
ctx.builder.build_store(elem, value);
|
||||
|
||||
Ok(())
|
||||
},
|
||||
|_, ctx, i_addr| {
|
||||
let i = ctx.builder
|
||||
.build_load(i_addr, "")
|
||||
.into_int_value();
|
||||
let i = ctx.builder.build_int_add(i, llvm_usize.const_int(1, true), "");
|
||||
ctx.builder.build_store(i_addr, i);
|
||||
|
||||
Ok(())
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
/// Generates LLVM IR for populating the entire `NDArray` using a lambda with the dimension-indices
|
||||
/// as its input
|
||||
///
|
||||
/// Note that this differs from `ndarray.fill`, which instead replaces all first-dimension elements
|
||||
/// with the given value (as opposed to all elements within the array).
|
||||
fn ndarray_fill_indexed<'ctx, 'a, ValueFn>(
|
||||
generator: &mut dyn CodeGenerator,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
ndarray: PointerValue<'ctx>,
|
||||
value_fn: ValueFn,
|
||||
) -> Result<(), String>
|
||||
where
|
||||
ValueFn: Fn(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>, PointerValue<'ctx>) -> Result<BasicValueEnum<'ctx>, String>,
|
||||
{
|
||||
ndarray_fill_flattened(
|
||||
generator,
|
||||
ctx,
|
||||
ndarray,
|
||||
|generator, ctx, idx| {
|
||||
let indices = call_ndarray_calc_nd_indices(
|
||||
generator,
|
||||
ctx,
|
||||
idx,
|
||||
ndarray,
|
||||
)?;
|
||||
|
||||
value_fn(generator, ctx, indices)
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
/// LLVM-typed implementation for generating the implementation for `ndarray.zeros`.
|
||||
///
|
||||
/// * `elem_ty` - The element type of the NDArray.
|
||||
/// * `shape` - The `shape` parameter used to construct the NDArray.
|
||||
fn call_ndarray_zeros_impl<'ctx, 'a>(
|
||||
generator: &mut dyn CodeGenerator,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
elem_ty: Type,
|
||||
shape: PointerValue<'ctx>,
|
||||
) -> Result<PointerValue<'ctx>, String> {
|
||||
let supported_types = [
|
||||
ctx.primitives.int32,
|
||||
ctx.primitives.int64,
|
||||
ctx.primitives.uint32,
|
||||
ctx.primitives.uint64,
|
||||
ctx.primitives.float,
|
||||
ctx.primitives.bool,
|
||||
ctx.primitives.str,
|
||||
];
|
||||
assert!(supported_types.iter().any(|supported_ty| ctx.unifier.unioned(*supported_ty, elem_ty)));
|
||||
|
||||
let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?;
|
||||
ndarray_fill_flattened(
|
||||
generator,
|
||||
ctx,
|
||||
ndarray,
|
||||
|generator, ctx, _| {
|
||||
let value = ndarray_zero_value(generator, ctx, elem_ty);
|
||||
|
||||
Ok(value)
|
||||
}
|
||||
)?;
|
||||
|
||||
Ok(ndarray)
|
||||
}
|
||||
|
||||
/// LLVM-typed implementation for generating the implementation for `ndarray.ones`.
|
||||
///
|
||||
/// * `elem_ty` - The element type of the NDArray.
|
||||
/// * `shape` - The `shape` parameter used to construct the NDArray.
|
||||
fn call_ndarray_ones_impl<'ctx, 'a>(
|
||||
generator: &mut dyn CodeGenerator,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
elem_ty: Type,
|
||||
shape: PointerValue<'ctx>,
|
||||
) -> Result<PointerValue<'ctx>, String> {
|
||||
let supported_types = [
|
||||
ctx.primitives.int32,
|
||||
ctx.primitives.int64,
|
||||
ctx.primitives.uint32,
|
||||
ctx.primitives.uint64,
|
||||
ctx.primitives.float,
|
||||
ctx.primitives.bool,
|
||||
ctx.primitives.str,
|
||||
];
|
||||
assert!(supported_types.iter().any(|supported_ty| ctx.unifier.unioned(*supported_ty, elem_ty)));
|
||||
|
||||
let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?;
|
||||
ndarray_fill_flattened(
|
||||
generator,
|
||||
ctx,
|
||||
ndarray,
|
||||
|generator, ctx, _| {
|
||||
let value = ndarray_one_value(generator, ctx, elem_ty);
|
||||
|
||||
Ok(value)
|
||||
}
|
||||
)?;
|
||||
|
||||
Ok(ndarray)
|
||||
}
|
||||
|
||||
/// LLVM-typed implementation for generating the implementation for `ndarray.ones`.
|
||||
///
|
||||
/// * `elem_ty` - The element type of the NDArray.
|
||||
/// * `shape` - The `shape` parameter used to construct the NDArray.
|
||||
fn call_ndarray_full_impl<'ctx, 'a>(
|
||||
generator: &mut dyn CodeGenerator,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
elem_ty: Type,
|
||||
shape: PointerValue<'ctx>,
|
||||
fill_value: BasicValueEnum<'ctx>,
|
||||
) -> Result<PointerValue<'ctx>, String> {
|
||||
let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?;
|
||||
ndarray_fill_flattened(
|
||||
generator,
|
||||
ctx,
|
||||
ndarray,
|
||||
|generator, ctx, _| {
|
||||
let value = if fill_value.is_pointer_value() {
|
||||
let llvm_void = ctx.ctx.void_type();
|
||||
let llvm_i1 = ctx.ctx.bool_type();
|
||||
let llvm_i8 = ctx.ctx.i8_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default());
|
||||
|
||||
let copy = generator.gen_var_alloc(ctx, fill_value.get_type(), None)?;
|
||||
|
||||
let memcpy_fn_name = format!(
|
||||
"llvm.memcpy.p0i8.p0i8.i{}",
|
||||
generator.get_size_type(ctx.ctx).get_bit_width(),
|
||||
);
|
||||
let memcpy_fn = ctx.module.get_function(memcpy_fn_name.as_str()).unwrap_or_else(|| {
|
||||
let fn_type = llvm_void.fn_type(
|
||||
&[
|
||||
llvm_pi8.into(),
|
||||
llvm_pi8.into(),
|
||||
llvm_usize.into(),
|
||||
llvm_i1.into(),
|
||||
],
|
||||
false,
|
||||
);
|
||||
|
||||
ctx.module.add_function(memcpy_fn_name.as_str(), fn_type, None)
|
||||
});
|
||||
|
||||
ctx.builder.build_call(
|
||||
memcpy_fn,
|
||||
&[
|
||||
copy.into(),
|
||||
fill_value.into(),
|
||||
fill_value.get_type().size_of().unwrap().into(),
|
||||
llvm_i1.const_zero().into(),
|
||||
],
|
||||
"",
|
||||
);
|
||||
|
||||
copy.into()
|
||||
} else if fill_value.is_int_value() || fill_value.is_float_value() {
|
||||
fill_value.into()
|
||||
} else {
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
Ok(value)
|
||||
}
|
||||
)?;
|
||||
|
||||
Ok(ndarray)
|
||||
}
|
||||
|
||||
/// LLVM-typed implementation for generating the implementation for `ndarray.eye`.
|
||||
///
|
||||
/// * `elem_ty` - The element type of the NDArray.
|
||||
fn call_ndarray_eye_impl<'ctx, 'a>(
|
||||
generator: &mut dyn CodeGenerator,
|
||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||
elem_ty: Type,
|
||||
nrows: IntValue<'ctx>,
|
||||
ncols: IntValue<'ctx>,
|
||||
offset: IntValue<'ctx>,
|
||||
) -> Result<PointerValue<'ctx>, String> {
|
||||
let llvm_i32 = ctx.ctx.i32_type();
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let llvm_usize_2 = llvm_usize.array_type(2);
|
||||
|
||||
let shape_addr = generator.gen_var_alloc(ctx, llvm_usize_2.into(), None)?;
|
||||
|
||||
let shape = ctx.builder.build_load(shape_addr, "")
|
||||
.into_array_value();
|
||||
|
||||
let nrows = ctx.builder.build_int_z_extend_or_bit_cast(nrows, llvm_usize, "");
|
||||
let shape = ctx.builder
|
||||
.build_insert_value(shape, nrows, 0, "")
|
||||
.map(|val| val.into_array_value())
|
||||
.unwrap();
|
||||
|
||||
let ncols = ctx.builder.build_int_z_extend_or_bit_cast(ncols, llvm_usize, "");
|
||||
let shape = ctx.builder
|
||||
.build_insert_value(shape, ncols, 1, "")
|
||||
.map(|val| val.into_array_value())
|
||||
.unwrap();
|
||||
|
||||
let ndarray = create_ndarray_const_shape(generator, ctx, elem_ty, shape)?;
|
||||
|
||||
ndarray_fill_indexed(
|
||||
generator,
|
||||
ctx,
|
||||
ndarray,
|
||||
|generator, ctx, indices| {
|
||||
let row = ctx.build_gep_and_load(
|
||||
indices,
|
||||
&[llvm_i32.const_zero()],
|
||||
None,
|
||||
).into_int_value();
|
||||
let col = ctx.build_gep_and_load(
|
||||
indices,
|
||||
&[llvm_i32.const_int(1, true)],
|
||||
None,
|
||||
).into_int_value();
|
||||
|
||||
let col_with_offset = ctx.builder.build_int_add(
|
||||
col,
|
||||
ctx.builder.build_int_z_extend_or_bit_cast(offset, llvm_usize, ""),
|
||||
""
|
||||
);
|
||||
let is_on_diag = ctx.builder.build_int_compare(
|
||||
IntPredicate::EQ,
|
||||
row,
|
||||
col_with_offset,
|
||||
""
|
||||
);
|
||||
|
||||
let zero = ndarray_zero_value(generator, ctx, elem_ty);
|
||||
let one = ndarray_one_value(generator, ctx, elem_ty);
|
||||
|
||||
let value = ctx.builder.build_select(is_on_diag, one, zero, "");
|
||||
|
||||
Ok(value)
|
||||
},
|
||||
)?;
|
||||
|
||||
Ok(ndarray)
|
||||
}
|
||||
|
||||
/// Generates LLVM IR for `ndarray.empty`.
|
||||
pub fn gen_ndarray_empty<'ctx, 'a>(
|
||||
context: &mut CodeGenContext<'ctx, 'a>,
|
||||
obj: Option<(Type, ValueEnum<'ctx>)>,
|
||||
fun: (&FunSignature, DefinitionId),
|
||||
args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
|
||||
generator: &mut dyn CodeGenerator,
|
||||
) -> Result<PointerValue<'ctx>, String> {
|
||||
assert!(obj.is_none());
|
||||
assert_eq!(args.len(), 1);
|
||||
|
||||
let shape_ty = fun.0.args[0].ty;
|
||||
let shape_arg = args[0].1.clone()
|
||||
.to_basic_value_enum(context, generator, shape_ty)?;
|
||||
|
||||
call_ndarray_empty_impl(
|
||||
generator,
|
||||
context,
|
||||
context.primitives.float,
|
||||
shape_arg.into_pointer_value(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Generates LLVM IR for `ndarray.zeros`.
|
||||
pub fn gen_ndarray_zeros<'ctx, 'a>(
|
||||
context: &mut CodeGenContext<'ctx, 'a>,
|
||||
obj: Option<(Type, ValueEnum<'ctx>)>,
|
||||
fun: (&FunSignature, DefinitionId),
|
||||
args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
|
||||
generator: &mut dyn CodeGenerator,
|
||||
) -> Result<PointerValue<'ctx>, String> {
|
||||
assert!(obj.is_none());
|
||||
assert_eq!(args.len(), 1);
|
||||
|
||||
let shape_ty = fun.0.args[0].ty;
|
||||
let shape_arg = args[0].1.clone()
|
||||
.to_basic_value_enum(context, generator, shape_ty)?;
|
||||
|
||||
call_ndarray_zeros_impl(
|
||||
generator,
|
||||
context,
|
||||
context.primitives.float,
|
||||
shape_arg.into_pointer_value(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Generates LLVM IR for `ndarray.ones`.
|
||||
pub fn gen_ndarray_ones<'ctx, 'a>(
|
||||
context: &mut CodeGenContext<'ctx, 'a>,
|
||||
obj: Option<(Type, ValueEnum<'ctx>)>,
|
||||
fun: (&FunSignature, DefinitionId),
|
||||
args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
|
||||
generator: &mut dyn CodeGenerator,
|
||||
) -> Result<PointerValue<'ctx>, String> {
|
||||
assert!(obj.is_none());
|
||||
assert_eq!(args.len(), 1);
|
||||
|
||||
let shape_ty = fun.0.args[0].ty;
|
||||
let shape_arg = args[0].1.clone()
|
||||
.to_basic_value_enum(context, generator, shape_ty)?;
|
||||
|
||||
call_ndarray_ones_impl(
|
||||
generator,
|
||||
context,
|
||||
context.primitives.float,
|
||||
shape_arg.into_pointer_value(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Generates LLVM IR for `ndarray.full`.
|
||||
pub fn gen_ndarray_full<'ctx, 'a>(
|
||||
context: &mut CodeGenContext<'ctx, 'a>,
|
||||
obj: Option<(Type, ValueEnum<'ctx>)>,
|
||||
fun: (&FunSignature, DefinitionId),
|
||||
args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
|
||||
generator: &mut dyn CodeGenerator,
|
||||
) -> Result<PointerValue<'ctx>, String> {
|
||||
assert!(obj.is_none());
|
||||
assert_eq!(args.len(), 2);
|
||||
|
||||
let shape_ty = fun.0.args[0].ty;
|
||||
let shape_arg = args[0].1.clone()
|
||||
.to_basic_value_enum(context, generator, shape_ty)?;
|
||||
let fill_value_ty = fun.0.args[1].ty;
|
||||
let fill_value_arg = args[1].1.clone()
|
||||
.to_basic_value_enum(context, generator, fill_value_ty)?;
|
||||
|
||||
call_ndarray_full_impl(
|
||||
generator,
|
||||
context,
|
||||
fill_value_ty,
|
||||
shape_arg.into_pointer_value(),
|
||||
fill_value_arg,
|
||||
)
|
||||
}
|
||||
|
||||
/// Generates LLVM IR for `ndarray.eye`.
|
||||
pub fn gen_ndarray_eye<'ctx, 'a>(
|
||||
context: &mut CodeGenContext<'ctx, 'a>,
|
||||
obj: Option<(Type, ValueEnum<'ctx>)>,
|
||||
fun: (&FunSignature, DefinitionId),
|
||||
args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
|
||||
generator: &mut dyn CodeGenerator,
|
||||
) -> Result<PointerValue<'ctx>, String> {
|
||||
assert!(obj.is_none());
|
||||
assert!(matches!(args.len(), 1..=3));
|
||||
|
||||
let nrows_ty = fun.0.args[0].ty;
|
||||
let nrows_arg = args[0].1.clone()
|
||||
.to_basic_value_enum(context, generator, nrows_ty)?;
|
||||
|
||||
let ncols_ty = fun.0.args[1].ty;
|
||||
let ncols_arg = args.iter()
|
||||
.find(|arg| arg.0.map(|name| name == fun.0.args[1].name).unwrap_or(false))
|
||||
.map(|arg| arg.1.clone().to_basic_value_enum(context, generator, ncols_ty))
|
||||
.unwrap_or_else(|| {
|
||||
args[0].1.clone().to_basic_value_enum(context, generator, nrows_ty)
|
||||
})?;
|
||||
|
||||
let offset_ty = fun.0.args[2].ty;
|
||||
let offset_arg = args.iter()
|
||||
.find(|arg| arg.0.map(|name| name == fun.0.args[2].name).unwrap_or(false))
|
||||
.map(|arg| arg.1.clone().to_basic_value_enum(context, generator, offset_ty))
|
||||
.unwrap_or_else(|| {
|
||||
Ok(context.gen_symbol_val(
|
||||
generator,
|
||||
fun.0.args[2].default_value.as_ref().unwrap(),
|
||||
offset_ty
|
||||
))
|
||||
})?;
|
||||
|
||||
call_ndarray_eye_impl(
|
||||
generator,
|
||||
context,
|
||||
context.primitives.float,
|
||||
nrows_arg.into_int_value(),
|
||||
ncols_arg.into_int_value(),
|
||||
offset_arg.into_int_value(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Generates LLVM IR for `ndarray.identity`.
|
||||
pub fn gen_ndarray_identity<'ctx, 'a>(
|
||||
context: &mut CodeGenContext<'ctx, 'a>,
|
||||
obj: Option<(Type, ValueEnum<'ctx>)>,
|
||||
fun: (&FunSignature, DefinitionId),
|
||||
args: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
|
||||
generator: &mut dyn CodeGenerator,
|
||||
) -> Result<PointerValue<'ctx>, String> {
|
||||
assert!(obj.is_none());
|
||||
assert_eq!(args.len(), 1);
|
||||
|
||||
let llvm_usize = generator.get_size_type(context.ctx);
|
||||
|
||||
let n_ty = fun.0.args[0].ty;
|
||||
let n_arg = args[0].1.clone()
|
||||
.to_basic_value_enum(context, generator, n_ty)?;
|
||||
|
||||
call_ndarray_eye_impl(
|
||||
generator,
|
||||
context,
|
||||
context.primitives.float,
|
||||
n_arg.into_int_value(),
|
||||
n_arg.into_int_value(),
|
||||
llvm_usize.const_zero(),
|
||||
)
|
||||
}
|
|
@ -491,11 +491,24 @@ pub fn get_type_from_type_annotation_kinds(
|
|||
(*name, (subst_ty, *mutability))
|
||||
}));
|
||||
let need_subst = !subst.is_empty();
|
||||
let ty = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: *obj_id,
|
||||
fields: tobj_fields,
|
||||
params: subst,
|
||||
});
|
||||
let ty = if obj_id == &DefinitionId(14) {
|
||||
assert_eq!(subst.len(), 2);
|
||||
let tv_tys = subst.iter()
|
||||
.sorted_by_key(|(k, _)| *k)
|
||||
.map(|(_, v)| v)
|
||||
.collect_vec();
|
||||
|
||||
unifier.add_ty(TypeEnum::TNDArray {
|
||||
ty: *tv_tys[0],
|
||||
ndims: *tv_tys[1],
|
||||
})
|
||||
} else {
|
||||
unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: *obj_id,
|
||||
fields: tobj_fields,
|
||||
params: subst,
|
||||
})
|
||||
};
|
||||
if need_subst {
|
||||
if let Some(wl) = subst_list.as_mut() {
|
||||
wl.push(ty);
|
||||
|
|
|
@ -5,12 +5,18 @@ use std::{cell::RefCell, sync::Arc};
|
|||
|
||||
use super::typedef::{Call, FunSignature, FuncArg, RecordField, Type, TypeEnum, Unifier};
|
||||
use super::{magic_methods::*, typedef::CallId};
|
||||
use crate::{symbol_resolver::SymbolResolver, toplevel::TopLevelContext};
|
||||
use crate::{symbol_resolver::{SymbolResolver, SymbolValue}, toplevel::TopLevelContext};
|
||||
use itertools::izip;
|
||||
use nac3parser::ast::{
|
||||
self,
|
||||
fold::{self, Fold},
|
||||
Arguments, Comprehension, ExprContext, ExprKind, Located, Location, StrRef,
|
||||
Arguments,
|
||||
Comprehension,
|
||||
ExprContext,
|
||||
ExprKind,
|
||||
Located,
|
||||
Location,
|
||||
StrRef
|
||||
};
|
||||
|
||||
#[cfg(test)]
|
||||
|
@ -41,6 +47,18 @@ pub struct PrimitiveStore {
|
|||
pub str: Type,
|
||||
pub exception: Type,
|
||||
pub option: Type,
|
||||
pub size_t: u32,
|
||||
}
|
||||
|
||||
impl PrimitiveStore {
|
||||
/// Returns a [Type] representing `size_t`.
|
||||
pub fn usize(&self) -> Type {
|
||||
match self.size_t {
|
||||
32 => self.uint32,
|
||||
64 => self.uint64,
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct FunctionData {
|
||||
|
@ -205,8 +223,12 @@ impl<'a> Fold<()> for Inferencer<'a> {
|
|||
if self.unifier.unioned(iter.custom.unwrap(), self.primitives.range) {
|
||||
self.unify(self.primitives.int32, target.custom.unwrap(), &target.location)?;
|
||||
} else {
|
||||
let list = self.unifier.add_ty(TypeEnum::TList { ty: target.custom.unwrap() });
|
||||
self.unify(list, iter.custom.unwrap(), &iter.location)?;
|
||||
let list_like_ty = match &*self.unifier.get_ty(iter.custom.unwrap()) {
|
||||
TypeEnum::TList { .. } => self.unifier.add_ty(TypeEnum::TList { ty: target.custom.unwrap() }),
|
||||
TypeEnum::TNDArray { .. } => todo!(),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
self.unify(list_like_ty, iter.custom.unwrap(), &iter.location)?;
|
||||
}
|
||||
let body =
|
||||
body.into_iter().map(|b| self.fold_stmt(b)).collect::<Result<Vec<_>, _>>()?;
|
||||
|
@ -761,6 +783,228 @@ impl<'a> Inferencer<'a> {
|
|||
})
|
||||
}
|
||||
|
||||
/// Tries to fold a special call. Returns [`Some`] if the call expression `func` is a special call, otherwise
|
||||
/// returns [`None`].
|
||||
fn try_fold_special_call(
|
||||
&mut self,
|
||||
location: Location,
|
||||
func: &ast::Expr<()>,
|
||||
args: &mut Vec<ast::Expr<()>>,
|
||||
keywords: &Vec<Located<ast::KeywordData>>,
|
||||
) -> Result<Option<ast::Expr<Option<Type>>>, HashSet<String>> {
|
||||
let Located { location: func_location, node: ExprKind::Name { id, ctx }, .. } = func else {
|
||||
return Ok(None)
|
||||
};
|
||||
|
||||
// handle special functions that cannot be typed in the usual way...
|
||||
if id == &"virtual".into() {
|
||||
if args.is_empty() || args.len() > 2 || !keywords.is_empty() {
|
||||
return report_error(
|
||||
"`virtual` can only accept 1/2 positional arguments",
|
||||
*func_location,
|
||||
)
|
||||
}
|
||||
let arg0 = self.fold_expr(args.remove(0))?;
|
||||
let ty = if let Some(arg) = args.pop() {
|
||||
let top_level_defs = self.top_level.definitions.read();
|
||||
self.function_data.resolver.parse_type_annotation(
|
||||
top_level_defs.as_slice(),
|
||||
self.unifier,
|
||||
self.primitives,
|
||||
&arg,
|
||||
)?
|
||||
} else {
|
||||
self.unifier.get_dummy_var().0
|
||||
};
|
||||
self.virtual_checks.push((arg0.custom.unwrap(), ty, *func_location));
|
||||
let custom = Some(self.unifier.add_ty(TypeEnum::TVirtual { ty }));
|
||||
return Ok(Some(Located {
|
||||
location,
|
||||
custom,
|
||||
node: ExprKind::Call {
|
||||
func: Box::new(Located {
|
||||
custom: None,
|
||||
location: func.location,
|
||||
node: ExprKind::Name { id: *id, ctx: ctx.clone() },
|
||||
}),
|
||||
args: vec![arg0],
|
||||
keywords: vec![],
|
||||
},
|
||||
}))
|
||||
}
|
||||
// int64 is special because its argument can be a constant larger than int32
|
||||
if id == &"int64".into() && args.len() == 1 {
|
||||
if let ExprKind::Constant { value: ast::Constant::Int(val), kind } =
|
||||
&args[0].node
|
||||
{
|
||||
let custom = Some(self.primitives.int64);
|
||||
let v: Result<i64, _> = (*val).try_into();
|
||||
return if v.is_ok() {
|
||||
Ok(Some(Located {
|
||||
location: args[0].location,
|
||||
custom,
|
||||
node: ExprKind::Constant {
|
||||
value: ast::Constant::Int(*val),
|
||||
kind: kind.clone(),
|
||||
},
|
||||
}))
|
||||
} else {
|
||||
report_error("Integer out of bound", args[0].location)
|
||||
}
|
||||
}
|
||||
}
|
||||
if id == &"uint32".into() && args.len() == 1 {
|
||||
if let ExprKind::Constant { value: ast::Constant::Int(val), kind } =
|
||||
&args[0].node
|
||||
{
|
||||
let custom = Some(self.primitives.uint32);
|
||||
let v: Result<u32, _> = (*val).try_into();
|
||||
return if v.is_ok() {
|
||||
Ok(Some(Located {
|
||||
location: args[0].location,
|
||||
custom,
|
||||
node: ExprKind::Constant {
|
||||
value: ast::Constant::Int(*val),
|
||||
kind: kind.clone(),
|
||||
},
|
||||
}))
|
||||
} else {
|
||||
report_error("Integer out of bound", args[0].location)
|
||||
}
|
||||
}
|
||||
}
|
||||
if id == &"uint64".into() && args.len() == 1 {
|
||||
if let ExprKind::Constant { value: ast::Constant::Int(val), kind } =
|
||||
&args[0].node
|
||||
{
|
||||
let custom = Some(self.primitives.uint64);
|
||||
let v: Result<u64, _> = (*val).try_into();
|
||||
return if v.is_ok() {
|
||||
Ok(Some(Located {
|
||||
location: args[0].location,
|
||||
custom,
|
||||
node: ExprKind::Constant {
|
||||
value: ast::Constant::Int(*val),
|
||||
kind: kind.clone(),
|
||||
},
|
||||
}))
|
||||
} else {
|
||||
report_error("Integer out of bound", args[0].location)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 1-argument ndarray n-dimensional creation functions
|
||||
if [
|
||||
"np_ndarray".into(),
|
||||
"np_empty".into(),
|
||||
"np_zeros".into(),
|
||||
"np_ones".into(),
|
||||
].contains(id) && args.len() == 1 {
|
||||
let ExprKind::List { elts, .. } = &args[0].node else {
|
||||
return report_error(
|
||||
format!("Expected List literal for first argument of {id}, got {}", args[0].node.name()).as_str(),
|
||||
args[0].location
|
||||
)
|
||||
};
|
||||
|
||||
let ndims = elts.len() as u64;
|
||||
|
||||
let arg0 = self.fold_expr(args.remove(0))?;
|
||||
let ndims = self.unifier.get_fresh_literal(
|
||||
vec![SymbolValue::U64(ndims)],
|
||||
None,
|
||||
);
|
||||
let ret = self.unifier.add_ty(TypeEnum::TNDArray {
|
||||
ty: self.primitives.float,
|
||||
ndims
|
||||
});
|
||||
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||
args: vec![
|
||||
FuncArg {
|
||||
name: "shape".into(),
|
||||
ty: arg0.custom.unwrap(),
|
||||
default_value: None,
|
||||
},
|
||||
],
|
||||
ret,
|
||||
vars: HashMap::new(),
|
||||
}));
|
||||
|
||||
return Ok(Some(Located {
|
||||
location,
|
||||
custom: Some(ret),
|
||||
node: ExprKind::Call {
|
||||
func: Box::new(Located {
|
||||
custom: Some(custom),
|
||||
location: func.location,
|
||||
node: ExprKind::Name { id: *id, ctx: ctx.clone() },
|
||||
}),
|
||||
args: vec![arg0],
|
||||
keywords: vec![],
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
// 2-argument ndarray n-dimensional creation functions
|
||||
if id == &"np_full".into() && args.len() == 2 {
|
||||
let ExprKind::List { elts, .. } = &args[0].node else {
|
||||
return report_error(
|
||||
format!("Expected List literal for first argument of {id}, got {}", args[0].node.name()).as_str(),
|
||||
args[0].location
|
||||
)
|
||||
};
|
||||
|
||||
let ndims = elts.len() as u64;
|
||||
|
||||
let arg0 = self.fold_expr(args.remove(0))?;
|
||||
let arg1 = self.fold_expr(args.remove(0))?;
|
||||
|
||||
let ty = arg1.custom.unwrap();
|
||||
let ndims = self.unifier.get_fresh_literal(
|
||||
vec![SymbolValue::U64(ndims)],
|
||||
None,
|
||||
);
|
||||
|
||||
let ret = self.unifier.add_ty(TypeEnum::TNDArray {
|
||||
ty,
|
||||
ndims
|
||||
});
|
||||
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||
args: vec![
|
||||
FuncArg {
|
||||
name: "shape".into(),
|
||||
ty: arg0.custom.unwrap(),
|
||||
default_value: None,
|
||||
},
|
||||
FuncArg {
|
||||
name: "fill_value".into(),
|
||||
ty: arg1.custom.unwrap(),
|
||||
default_value: None,
|
||||
},
|
||||
],
|
||||
ret,
|
||||
vars: HashMap::new(),
|
||||
}));
|
||||
|
||||
return Ok(Some(Located {
|
||||
location,
|
||||
custom: Some(ret),
|
||||
node: ExprKind::Call {
|
||||
func: Box::new(Located {
|
||||
custom: Some(custom),
|
||||
location: func.location,
|
||||
node: ExprKind::Name { id: *id, ctx: ctx.clone() },
|
||||
}),
|
||||
args: vec![arg0, arg1],
|
||||
keywords: vec![],
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
fn fold_call(
|
||||
&mut self,
|
||||
location: Location,
|
||||
|
@ -768,111 +1012,11 @@ impl<'a> Inferencer<'a> {
|
|||
mut args: Vec<ast::Expr<()>>,
|
||||
keywords: Vec<Located<ast::KeywordData>>,
|
||||
) -> Result<ast::Expr<Option<Type>>, HashSet<String>> {
|
||||
let func =
|
||||
if let Located { location: func_location, custom, node: ExprKind::Name { id, ctx } } =
|
||||
func
|
||||
{
|
||||
// handle special functions that cannot be typed in the usual way...
|
||||
if id == "virtual".into() {
|
||||
if args.is_empty() || args.len() > 2 || !keywords.is_empty() {
|
||||
return report_error(
|
||||
"`virtual` can only accept 1/2 positional arguments",
|
||||
func_location,
|
||||
);
|
||||
}
|
||||
let arg0 = self.fold_expr(args.remove(0))?;
|
||||
let ty = if let Some(arg) = args.pop() {
|
||||
let top_level_defs = self.top_level.definitions.read();
|
||||
self.function_data.resolver.parse_type_annotation(
|
||||
top_level_defs.as_slice(),
|
||||
self.unifier,
|
||||
self.primitives,
|
||||
&arg,
|
||||
)?
|
||||
} else {
|
||||
self.unifier.get_dummy_var().0
|
||||
};
|
||||
self.virtual_checks.push((arg0.custom.unwrap(), ty, func_location));
|
||||
let custom = Some(self.unifier.add_ty(TypeEnum::TVirtual { ty }));
|
||||
return Ok(Located {
|
||||
location,
|
||||
custom,
|
||||
node: ExprKind::Call {
|
||||
func: Box::new(Located {
|
||||
custom: None,
|
||||
location: func.location,
|
||||
node: ExprKind::Name { id, ctx },
|
||||
}),
|
||||
args: vec![arg0],
|
||||
keywords: vec![],
|
||||
},
|
||||
});
|
||||
}
|
||||
// int64 is special because its argument can be a constant larger than int32
|
||||
if id == "int64".into() && args.len() == 1 {
|
||||
if let ExprKind::Constant { value: ast::Constant::Int(val), kind } =
|
||||
&args[0].node
|
||||
{
|
||||
let custom = Some(self.primitives.int64);
|
||||
let v: Result<i64, _> = (*val).try_into();
|
||||
return if v.is_ok() {
|
||||
Ok(Located {
|
||||
location: args[0].location,
|
||||
custom,
|
||||
node: ExprKind::Constant {
|
||||
value: ast::Constant::Int(*val),
|
||||
kind: kind.clone(),
|
||||
},
|
||||
})
|
||||
} else {
|
||||
report_error("Integer out of bound", args[0].location)
|
||||
}
|
||||
}
|
||||
}
|
||||
if id == "uint32".into() && args.len() == 1 {
|
||||
if let ExprKind::Constant { value: ast::Constant::Int(val), kind } =
|
||||
&args[0].node
|
||||
{
|
||||
let custom = Some(self.primitives.uint32);
|
||||
let v: Result<u32, _> = (*val).try_into();
|
||||
return if v.is_ok() {
|
||||
Ok(Located {
|
||||
location: args[0].location,
|
||||
custom,
|
||||
node: ExprKind::Constant {
|
||||
value: ast::Constant::Int(*val),
|
||||
kind: kind.clone(),
|
||||
},
|
||||
})
|
||||
} else {
|
||||
report_error("Integer out of bound", args[0].location)
|
||||
}
|
||||
}
|
||||
}
|
||||
if id == "uint64".into() && args.len() == 1 {
|
||||
if let ExprKind::Constant { value: ast::Constant::Int(val), kind } =
|
||||
&args[0].node
|
||||
{
|
||||
let custom = Some(self.primitives.uint64);
|
||||
let v: Result<u64, _> = (*val).try_into();
|
||||
return if v.is_ok() {
|
||||
Ok(Located {
|
||||
location: args[0].location,
|
||||
custom,
|
||||
node: ExprKind::Constant {
|
||||
value: ast::Constant::Int(*val),
|
||||
kind: kind.clone(),
|
||||
},
|
||||
})
|
||||
} else {
|
||||
report_error("Integer out of bound", args[0].location)
|
||||
}
|
||||
}
|
||||
}
|
||||
Located { location: func_location, custom, node: ExprKind::Name { id, ctx } }
|
||||
} else {
|
||||
func
|
||||
};
|
||||
let func = if let Some(spec_call_func) = self.try_fold_special_call(location, &func, &mut args, &keywords)? {
|
||||
return Ok(spec_call_func)
|
||||
} else {
|
||||
func
|
||||
};
|
||||
let func = Box::new(self.fold_expr(func)?);
|
||||
let args = args.into_iter().map(|v| self.fold_expr(v)).collect::<Result<Vec<_>, _>>()?;
|
||||
let keywords = keywords
|
||||
|
@ -1105,9 +1249,13 @@ impl<'a> Inferencer<'a> {
|
|||
for v in [lower.as_ref(), upper.as_ref(), step.as_ref()].iter().flatten() {
|
||||
self.constrain(v.custom.unwrap(), self.primitives.int32, &v.location)?;
|
||||
}
|
||||
let list = self.unifier.add_ty(TypeEnum::TList { ty });
|
||||
self.constrain(value.custom.unwrap(), list, &value.location)?;
|
||||
Ok(list)
|
||||
let list_like_ty = match &*self.unifier.get_ty(value.custom.unwrap()) {
|
||||
TypeEnum::TList { .. } => self.unifier.add_ty(TypeEnum::TList { ty }),
|
||||
TypeEnum::TNDArray { ndims, .. } => self.unifier.add_ty(TypeEnum::TNDArray { ty, ndims: *ndims }),
|
||||
_ => unreachable!()
|
||||
};
|
||||
self.constrain(value.custom.unwrap(), list_like_ty, &value.location)?;
|
||||
Ok(list_like_ty)
|
||||
}
|
||||
ExprKind::Constant { value: ast::Constant::Int(val), .. } => {
|
||||
// the index is a constant, so value can be a sequence.
|
||||
|
@ -1127,10 +1275,15 @@ impl<'a> Inferencer<'a> {
|
|||
{
|
||||
return report_error("Tuple index must be a constant (KernelInvariant is also not supported)", slice.location)
|
||||
}
|
||||
|
||||
// the index is not a constant, so value can only be a list
|
||||
self.constrain(slice.custom.unwrap(), self.primitives.int32, &slice.location)?;
|
||||
let list = self.unifier.add_ty(TypeEnum::TList { ty });
|
||||
self.constrain(value.custom.unwrap(), list, &value.location)?;
|
||||
let list_like_ty = match &*self.unifier.get_ty(value.custom.unwrap()) {
|
||||
TypeEnum::TList { .. } => self.unifier.add_ty(TypeEnum::TList { ty }),
|
||||
TypeEnum::TNDArray { .. } => todo!(),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
self.constrain(value.custom.unwrap(), list_like_ty, &value.location)?;
|
||||
Ok(ty)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -159,6 +159,11 @@ pub enum TypeEnum {
|
|||
ty: Type,
|
||||
},
|
||||
|
||||
TNDArray {
|
||||
ty: Type,
|
||||
ndims: Type,
|
||||
},
|
||||
|
||||
/// An object type.
|
||||
TObj {
|
||||
/// The [DefintionId] of this object type.
|
||||
|
@ -193,12 +198,34 @@ impl TypeEnum {
|
|||
TypeEnum::TLiteral { .. } => "TConstant",
|
||||
TypeEnum::TTuple { .. } => "TTuple",
|
||||
TypeEnum::TList { .. } => "TList",
|
||||
TypeEnum::TNDArray { .. } => "TNDArray",
|
||||
TypeEnum::TObj { .. } => "TObj",
|
||||
TypeEnum::TVirtual { .. } => "TVirtual",
|
||||
TypeEnum::TCall { .. } => "TCall",
|
||||
TypeEnum::TFunc { .. } => "TFunc",
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a [TypeEnum] representing a generic `ndarray` type.
|
||||
///
|
||||
/// * `dtype` - The datatype of the `ndarray`, or `None` if the datatype is generic.
|
||||
/// * `ndims` - The number of dimensions of the `ndarray`, or `None` if the number of dimensions is generic.
|
||||
#[must_use]
|
||||
pub fn ndarray(
|
||||
unifier: &mut Unifier,
|
||||
dtype: Option<Type>,
|
||||
ndims: Option<Type>,
|
||||
primitives: &PrimitiveStore
|
||||
) -> TypeEnum {
|
||||
let dtype = dtype.unwrap_or_else(|| unifier.get_fresh_var(Some("T".into()), None).0);
|
||||
let ndims = ndims
|
||||
.unwrap_or_else(|| unifier.get_fresh_const_generic_var(primitives.usize(), Some("N".into()), None).0);
|
||||
|
||||
TypeEnum::TNDArray {
|
||||
ty: dtype,
|
||||
ndims,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub type SharedUnifier = Arc<Mutex<(UnificationTable<TypeEnum>, u32, Vec<Call>)>>;
|
||||
|
@ -418,6 +445,9 @@ impl Unifier {
|
|||
TypeEnum::TList { ty } => self
|
||||
.get_instantiations(*ty)
|
||||
.map(|ty| ty.iter().map(|&ty| self.add_ty(TypeEnum::TList { ty })).collect_vec()),
|
||||
TypeEnum::TNDArray { ty, ndims } => self
|
||||
.get_instantiations(*ty)
|
||||
.map(|ty| ty.iter().map(|&ty| self.add_ty(TypeEnum::TNDArray { ty, ndims: *ndims })).collect_vec()),
|
||||
TypeEnum::TVirtual { ty } => self.get_instantiations(*ty).map(|ty| {
|
||||
ty.iter().map(|&ty| self.add_ty(TypeEnum::TVirtual { ty })).collect_vec()
|
||||
}),
|
||||
|
@ -470,6 +500,7 @@ impl Unifier {
|
|||
TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)),
|
||||
TCall { .. } => false,
|
||||
TList { ty } | TVirtual { ty } => self.is_concrete(*ty, allowed_typevars),
|
||||
TNDArray { ty, .. } => self.is_concrete(*ty, allowed_typevars),
|
||||
TTuple { ty } => ty.iter().all(|ty| self.is_concrete(*ty, allowed_typevars)),
|
||||
TObj { params: vars, .. } => {
|
||||
vars.values().all(|ty| self.is_concrete(*ty, allowed_typevars))
|
||||
|
@ -717,7 +748,8 @@ impl Unifier {
|
|||
self.unify_impl(x, b, false)?;
|
||||
self.set_a_to_b(a, x);
|
||||
}
|
||||
(TVar { fields: Some(fields), range, is_const_generic: false, .. }, TList { ty }) => {
|
||||
(TVar { fields: Some(fields), range, is_const_generic: false, .. }, TList { ty }) |
|
||||
(TVar { fields: Some(fields), range, is_const_generic: false, .. }, TNDArray { ty, .. }) => {
|
||||
for (k, v) in fields {
|
||||
match *k {
|
||||
RecordKey::Int(_) => {
|
||||
|
@ -789,7 +821,23 @@ impl Unifier {
|
|||
(TLiteral { values: val1, .. }, TLiteral { values: val2, .. }) => {
|
||||
for (v1, v2) in zip(val1, val2) {
|
||||
if v1 != v2 {
|
||||
return self.incompatible_types(a, b)
|
||||
let symbol_value_to_int = |value: &SymbolValue| -> Option<i128> {
|
||||
match value {
|
||||
SymbolValue::I32(v) => Some(*v as i128),
|
||||
SymbolValue::I64(v) => Some(*v as i128),
|
||||
SymbolValue::U32(v) => Some(*v as i128),
|
||||
SymbolValue::U64(v) => Some(*v as i128),
|
||||
_ => None,
|
||||
}
|
||||
};
|
||||
|
||||
// Try performing integer promotion on literals
|
||||
let v1i = symbol_value_to_int(v1);
|
||||
let v2i = symbol_value_to_int(v2);
|
||||
|
||||
if v1i != v2i {
|
||||
return self.incompatible_types(a, b)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -813,6 +861,15 @@ impl Unifier {
|
|||
}
|
||||
self.set_a_to_b(a, b);
|
||||
}
|
||||
(TNDArray { ty: ty1, ndims: ndims1 }, TNDArray { ty: ty2, ndims: ndims2 }) => {
|
||||
if self.unify_impl(*ty1, *ty2, false).is_err() {
|
||||
return self.incompatible_types(a, b)
|
||||
}
|
||||
if self.unify_impl(*ndims1, *ndims2, false).is_err() {
|
||||
return self.incompatible_types(a, b)
|
||||
}
|
||||
self.set_a_to_b(a, b);
|
||||
}
|
||||
(TVar { fields: Some(map), range, .. }, TObj { fields, .. }) => {
|
||||
for (k, field) in map {
|
||||
match *k {
|
||||
|
@ -1060,6 +1117,13 @@ impl Unifier {
|
|||
TypeEnum::TList { ty } => {
|
||||
format!("list[{}]", self.internal_stringify(*ty, obj_to_name, var_to_name, notes))
|
||||
}
|
||||
TypeEnum::TNDArray { ty, ndims } => {
|
||||
format!(
|
||||
"ndarray[{}, {}]",
|
||||
self.internal_stringify(*ty, obj_to_name, var_to_name, notes),
|
||||
self.internal_stringify(*ndims, obj_to_name, var_to_name, notes),
|
||||
)
|
||||
}
|
||||
TypeEnum::TVirtual { ty } => {
|
||||
format!(
|
||||
"virtual[{}]",
|
||||
|
@ -1179,7 +1243,7 @@ impl Unifier {
|
|||
// variables, i.e. things like TRecord, TCall should not occur, and we
|
||||
// should be safe to not implement the substitution for those variants.
|
||||
match &*ty {
|
||||
TypeEnum::TRigidVar { .. } => None,
|
||||
TypeEnum::TRigidVar { .. } | TypeEnum::TLiteral { .. } => None,
|
||||
TypeEnum::TVar { id, .. } => mapping.get(id).copied(),
|
||||
TypeEnum::TTuple { ty } => {
|
||||
let mut new_ty = Cow::from(ty);
|
||||
|
@ -1197,6 +1261,19 @@ impl Unifier {
|
|||
TypeEnum::TList { ty } => {
|
||||
self.subst_impl(*ty, mapping, cache).map(|t| self.add_ty(TypeEnum::TList { ty: t }))
|
||||
}
|
||||
TypeEnum::TNDArray { ty, ndims } => {
|
||||
let new_ty = self.subst_impl(*ty, mapping, cache);
|
||||
let new_ndims = self.subst_impl(*ndims, mapping, cache);
|
||||
|
||||
if new_ty.is_some() || new_ndims.is_some() {
|
||||
Some(self.add_ty(TypeEnum::TNDArray {
|
||||
ty: new_ty.unwrap_or(*ty),
|
||||
ndims: new_ndims.unwrap_or(*ndims)
|
||||
}))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
TypeEnum::TVirtual { ty } => self
|
||||
.subst_impl(*ty, mapping, cache)
|
||||
.map(|t| self.add_ty(TypeEnum::TVirtual { ty: t })),
|
||||
|
@ -1367,6 +1444,19 @@ impl Unifier {
|
|||
(TList { ty: ty1 }, TList { ty: ty2 }) => {
|
||||
Ok(self.get_intersection(*ty1, *ty2)?.map(|ty| self.add_ty(TList { ty })))
|
||||
}
|
||||
(TNDArray { ty: ty1, ndims: ndims1 }, TNDArray { ty: ty2, ndims: ndims2 }) => {
|
||||
let ty = self.get_intersection(*ty1, *ty2)?;
|
||||
let ndims = self.get_intersection(*ndims1, *ndims2)?;
|
||||
|
||||
Ok(if ty.is_some() || ndims.is_some() {
|
||||
Some(self.add_ty(TNDArray {
|
||||
ty: ty.unwrap_or(*ty1),
|
||||
ndims: ndims.unwrap_or(*ndims1),
|
||||
}))
|
||||
} else {
|
||||
None
|
||||
})
|
||||
}
|
||||
(TVirtual { ty: ty1 }, TVirtual { ty: ty2 }) => {
|
||||
Ok(self.get_intersection(*ty1, *ty2)?.map(|ty| self.add_ty(TVirtual { ty })))
|
||||
}
|
||||
|
|
|
@ -33,6 +33,7 @@ impl Unifier {
|
|||
&& ty1.iter().zip(ty2.iter()).all(|(t1, t2)| self.eq(*t1, *t2))
|
||||
}
|
||||
(TypeEnum::TList { ty: ty1 }, TypeEnum::TList { ty: ty2 })
|
||||
| (TypeEnum::TNDArray { ty: ty1 }, TypeEnum::TNDArray { ty: ty2 })
|
||||
| (TypeEnum::TVirtual { ty: ty1 }, TypeEnum::TVirtual { ty: ty2 }) => {
|
||||
self.eq(*ty1, *ty2)
|
||||
}
|
||||
|
|
|
@ -94,13 +94,13 @@ uint64_t dbg_stack_address(__attribute__((unused)) struct cslice *slice) {
|
|||
}
|
||||
|
||||
uint32_t __nac3_personality(uint32_t state, uint32_t exception_object, uint32_t context) {
|
||||
printf("__nac3_personality(state: %u, exception_object: %u, context: %u\n", state, exception_object, context);
|
||||
printf("__nac3_personality(state: %u, exception_object: %u, context: %u)\n", state, exception_object, context);
|
||||
exit(101);
|
||||
__builtin_unreachable();
|
||||
}
|
||||
|
||||
uint32_t __nac3_raise(uint32_t state, uint32_t exception_object, uint32_t context) {
|
||||
printf("__nac3_raise(state: %u, exception_object: %u, context: %u\n", state, exception_object, context);
|
||||
printf("__nac3_raise(state: %u, exception_object: %u, context: %u)\n", state, exception_object, context);
|
||||
exit(101);
|
||||
__builtin_unreachable();
|
||||
}
|
||||
|
|
|
@ -5,11 +5,12 @@ import importlib.util
|
|||
import importlib.machinery
|
||||
import math
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import pathlib
|
||||
|
||||
from numpy import int32, int64, uint32, uint64
|
||||
from scipy import special
|
||||
from typing import TypeVar, Generic, Literal
|
||||
from typing import TypeVar, Generic, Literal, Union
|
||||
|
||||
T = TypeVar('T')
|
||||
class Option(Generic[T]):
|
||||
|
@ -50,6 +51,13 @@ class _ConstGenericMarker:
|
|||
def ConstGeneric(name, constraint):
|
||||
return TypeVar(name, _ConstGenericMarker, constraint)
|
||||
|
||||
N = TypeVar("N", bound=np.uint64)
|
||||
class _NDArrayDummy(Generic[T, N]):
|
||||
pass
|
||||
|
||||
# https://stackoverflow.com/questions/67803260/how-to-create-a-type-alias-with-a-throw-away-generic
|
||||
NDArray = Union[npt.NDArray[T], _NDArrayDummy[T, N]]
|
||||
|
||||
def round_away_zero(x):
|
||||
if x >= 0.0:
|
||||
return math.floor(x + 0.5)
|
||||
|
@ -124,6 +132,16 @@ def patch(module):
|
|||
module.ceil64 = math.ceil
|
||||
module.np_ceil = np.ceil
|
||||
|
||||
# NumPy ndarray functions
|
||||
module.ndarray = NDArray
|
||||
module.np_ndarray = np.ndarray
|
||||
module.np_empty = np.empty
|
||||
module.np_zeros = np.zeros
|
||||
module.np_ones = np.ones
|
||||
module.np_full = np.full
|
||||
module.np_eye = np.eye
|
||||
module.np_identity = np.identity
|
||||
|
||||
# NumPy Math functions
|
||||
module.np_isnan = np.isnan
|
||||
module.np_isinf = np.isinf
|
||||
|
@ -166,6 +184,14 @@ def patch(module):
|
|||
module.sp_spec_j0 = special.j0
|
||||
module.sp_spec_j1 = special.j1
|
||||
|
||||
# NumPy NDArray Functions
|
||||
module.np_ndarray = np.ndarray
|
||||
module.np_empty = np.empty
|
||||
module.np_zeros = np.zeros
|
||||
module.np_ones = np.ones
|
||||
module.np_full = np.full
|
||||
module.np_eye = np.eye
|
||||
module.np_identity = np.identity
|
||||
|
||||
def file_import(filename, prefix="file_import_"):
|
||||
filename = pathlib.Path(filename)
|
||||
|
|
|
@ -0,0 +1,55 @@
|
|||
def consume_ndarray_1(n: ndarray[float, Literal[1]]):
|
||||
pass
|
||||
|
||||
def consume_ndarray_i32_1(n: ndarray[int32, Literal[1]]):
|
||||
pass
|
||||
|
||||
def consume_ndarray_2(n: ndarray[float, Literal[2]]):
|
||||
pass
|
||||
|
||||
def consume_ndarray_i32_1(n: ndarray[int32, 1]):
|
||||
pass
|
||||
|
||||
def consume_ndarray_2(n: ndarray[float, 2]):
|
||||
pass
|
||||
|
||||
def test_ndarray_ctor():
|
||||
n = np_ndarray([1])
|
||||
consume_ndarray_1(n)
|
||||
|
||||
def test_ndarray_empty():
|
||||
n = np_empty([1])
|
||||
consume_ndarray_1(n)
|
||||
|
||||
def test_ndarray_zeros():
|
||||
n = np_zeros([1])
|
||||
consume_ndarray_1(n)
|
||||
|
||||
def test_ndarray_ones():
|
||||
n = np_ones([1])
|
||||
consume_ndarray_1(n)
|
||||
|
||||
def test_ndarray_full():
|
||||
n_float = np_full([1], 2.0)
|
||||
consume_ndarray_1(n_float)
|
||||
n_i32 = np_full([1], 2)
|
||||
consume_ndarray_i32_1(n_i32)
|
||||
|
||||
def test_ndarray_eye():
|
||||
n = np_eye(2)
|
||||
consume_ndarray_2(n)
|
||||
|
||||
def test_ndarray_identity():
|
||||
n = np_identity(2)
|
||||
consume_ndarray_2(n)
|
||||
|
||||
def run() -> int32:
|
||||
test_ndarray_ctor()
|
||||
test_ndarray_empty()
|
||||
test_ndarray_zeros()
|
||||
test_ndarray_ones()
|
||||
test_ndarray_full()
|
||||
test_ndarray_eye()
|
||||
test_ndarray_identity()
|
||||
|
||||
return 0
|
|
@ -286,6 +286,7 @@ fn main() {
|
|||
// The default behavior for -O<n> where n>3 defaults to O3 for both Clang and GCC
|
||||
_ => OptimizationLevel::Aggressive,
|
||||
};
|
||||
const SIZE_T: u32 = 64;
|
||||
|
||||
let program = match fs::read_to_string(file_name.clone()) {
|
||||
Ok(program) => program,
|
||||
|
@ -295,9 +296,9 @@ fn main() {
|
|||
}
|
||||
};
|
||||
|
||||
let primitive: PrimitiveStore = TopLevelComposer::make_primitives().0;
|
||||
let primitive: PrimitiveStore = TopLevelComposer::make_primitives(SIZE_T).0;
|
||||
let (mut composer, builtins_def, builtins_ty) =
|
||||
TopLevelComposer::new(vec![], ComposerConfig::default());
|
||||
TopLevelComposer::new(vec![], ComposerConfig::default(), SIZE_T);
|
||||
|
||||
let internal_resolver: Arc<ResolverInternal> = ResolverInternal {
|
||||
id_to_type: builtins_ty.into(),
|
||||
|
@ -400,7 +401,7 @@ fn main() {
|
|||
membuffer.lock().push(buffer);
|
||||
})));
|
||||
let threads = (0..threads)
|
||||
.map(|i| Box::new(DefaultCodeGenerator::new(format!("module{i}"), 64)))
|
||||
.map(|i| Box::new(DefaultCodeGenerator::new(format!("module{i}"), SIZE_T)))
|
||||
.collect();
|
||||
let (registry, handles) = WorkerRegistry::create_workers(threads, top_level, &llvm_options, &f);
|
||||
registry.add_task(task);
|
||||
|
|
Loading…
Reference in New Issue