forked from M-Labs/nac3
core: Use TObj for NDArray
This commit is contained in:
parent
c3db6297d9
commit
234a6bde2a
|
@ -397,9 +397,6 @@ fn gen_rpc_tag(
|
||||||
buffer.push(b'l');
|
buffer.push(b'l');
|
||||||
gen_rpc_tag(ctx, *ty, buffer)?;
|
gen_rpc_tag(ctx, *ty, buffer)?;
|
||||||
}
|
}
|
||||||
TNDArray { .. } => {
|
|
||||||
todo!()
|
|
||||||
}
|
|
||||||
_ => return Err(format!("Unsupported type: {:?}", ctx.unifier.stringify(ty))),
|
_ => return Err(format!("Unsupported type: {:?}", ctx.unifier.stringify(ty))),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -660,14 +657,6 @@ pub fn attributes_writeback(
|
||||||
values.push((ty, inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap()));
|
values.push((ty, inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap()));
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
TypeEnum::TNDArray { ty: elem_ty, .. } => {
|
|
||||||
if gen_rpc_tag(ctx, *elem_ty, &mut scratch_buffer).is_ok() {
|
|
||||||
let pydict = PyDict::new(py);
|
|
||||||
pydict.set_item("obj", val)?;
|
|
||||||
host_attributes.append(pydict)?;
|
|
||||||
values.push((ty, inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap()));
|
|
||||||
}
|
|
||||||
},
|
|
||||||
_ => {}
|
_ => {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,7 +2,12 @@ use inkwell::{types::BasicType, values::BasicValueEnum, AddressSpace};
|
||||||
use nac3core::{
|
use nac3core::{
|
||||||
codegen::{CodeGenContext, CodeGenerator},
|
codegen::{CodeGenContext, CodeGenerator},
|
||||||
symbol_resolver::{StaticValue, SymbolResolver, SymbolValue, ValueEnum},
|
symbol_resolver::{StaticValue, SymbolResolver, SymbolValue, ValueEnum},
|
||||||
toplevel::{DefinitionId, TopLevelDef},
|
toplevel::{
|
||||||
|
DefinitionId,
|
||||||
|
helper::PRIMITIVE_DEF_IDS,
|
||||||
|
numpy::{make_ndarray_ty, unpack_ndarray_tvars},
|
||||||
|
TopLevelDef,
|
||||||
|
},
|
||||||
typecheck::{
|
typecheck::{
|
||||||
type_inferencer::PrimitiveStore,
|
type_inferencer::PrimitiveStore,
|
||||||
typedef::{Type, TypeEnum, Unifier},
|
typedef::{Type, TypeEnum, Unifier},
|
||||||
|
@ -306,7 +311,7 @@ impl InnerResolver {
|
||||||
// do not handle type var param and concrete check here
|
// do not handle type var param and concrete check here
|
||||||
let var = unifier.get_dummy_var().0;
|
let var = unifier.get_dummy_var().0;
|
||||||
let ndims = unifier.get_fresh_const_generic_var(primitives.usize(), None, None).0;
|
let ndims = unifier.get_fresh_const_generic_var(primitives.usize(), None, None).0;
|
||||||
let ndarray = unifier.add_ty(TypeEnum::TNDArray { ty: var, ndims });
|
let ndarray = make_ndarray_ty(unifier, primitives, Some(var), Some(ndims));
|
||||||
Ok(Ok((ndarray, false)))
|
Ok(Ok((ndarray, false)))
|
||||||
} else if ty_id == self.primitive_ids.tuple {
|
} else if ty_id == self.primitive_ids.tuple {
|
||||||
// do not handle type var param and concrete check here
|
// do not handle type var param and concrete check here
|
||||||
|
@ -452,7 +457,7 @@ impl InnerResolver {
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
TypeEnum::TNDArray { .. } => {
|
TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
|
||||||
if args.len() != 2 {
|
if args.len() != 2 {
|
||||||
return Ok(Err(format!(
|
return Ok(Err(format!(
|
||||||
"type list needs exactly 2 type parameters, found {}",
|
"type list needs exactly 2 type parameters, found {}",
|
||||||
|
@ -648,11 +653,12 @@ impl InnerResolver {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
(TypeEnum::TNDArray { ty, ndims }, false) => {
|
(TypeEnum::TObj { obj_id, .. }, false) if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
|
||||||
|
let (ty, ndims) = unpack_ndarray_tvars(unifier, extracted_ty);
|
||||||
let len: usize = self.helper.len_fn.call1(py, (obj,))?.extract(py)?;
|
let len: usize = self.helper.len_fn.call1(py, (obj,))?.extract(py)?;
|
||||||
if len == 0 {
|
if len == 0 {
|
||||||
assert!(matches!(
|
assert!(matches!(
|
||||||
&*unifier.get_ty(*ty),
|
&*unifier.get_ty(ty),
|
||||||
TypeEnum::TVar { fields: None, range, .. }
|
TypeEnum::TVar { fields: None, range, .. }
|
||||||
if range.is_empty()
|
if range.is_empty()
|
||||||
));
|
));
|
||||||
|
@ -661,8 +667,17 @@ impl InnerResolver {
|
||||||
let actual_ty =
|
let actual_ty =
|
||||||
self.get_list_elem_type(py, obj, len, unifier, defs, primitives)?;
|
self.get_list_elem_type(py, obj, len, unifier, defs, primitives)?;
|
||||||
match actual_ty {
|
match actual_ty {
|
||||||
Ok(t) => match unifier.unify(*ty, t) {
|
Ok(t) => match unifier.unify(ty, t) {
|
||||||
Ok(()) => Ok(Ok(unifier.add_ty(TypeEnum::TNDArray { ty: *ty, ndims: *ndims }))),
|
Ok(()) => {
|
||||||
|
let ndarray_ty = make_ndarray_ty(
|
||||||
|
unifier,
|
||||||
|
primitives,
|
||||||
|
Some(ty),
|
||||||
|
Some(ndims),
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(Ok(ndarray_ty))
|
||||||
|
}
|
||||||
Err(e) => Ok(Err(format!(
|
Err(e) => Ok(Err(format!(
|
||||||
"type error ({}) for the ndarray",
|
"type error ({}) for the ndarray",
|
||||||
e.to_display(unifier),
|
e.to_display(unifier),
|
||||||
|
|
|
@ -47,10 +47,6 @@ pub enum ConcreteTypeEnum {
|
||||||
TList {
|
TList {
|
||||||
ty: ConcreteType,
|
ty: ConcreteType,
|
||||||
},
|
},
|
||||||
TNDArray {
|
|
||||||
ty: ConcreteType,
|
|
||||||
ndims: ConcreteType,
|
|
||||||
},
|
|
||||||
TObj {
|
TObj {
|
||||||
obj_id: DefinitionId,
|
obj_id: DefinitionId,
|
||||||
fields: HashMap<StrRef, (ConcreteType, bool)>,
|
fields: HashMap<StrRef, (ConcreteType, bool)>,
|
||||||
|
@ -171,10 +167,6 @@ impl ConcreteTypeStore {
|
||||||
TypeEnum::TList { ty } => ConcreteTypeEnum::TList {
|
TypeEnum::TList { ty } => ConcreteTypeEnum::TList {
|
||||||
ty: self.from_unifier_type(unifier, primitives, *ty, cache),
|
ty: self.from_unifier_type(unifier, primitives, *ty, cache),
|
||||||
},
|
},
|
||||||
TypeEnum::TNDArray { ty, ndims } => ConcreteTypeEnum::TNDArray {
|
|
||||||
ty: self.from_unifier_type(unifier, primitives, *ty, cache),
|
|
||||||
ndims: self.from_unifier_type(unifier, primitives, *ndims, cache),
|
|
||||||
},
|
|
||||||
TypeEnum::TObj { obj_id, fields, params } => ConcreteTypeEnum::TObj {
|
TypeEnum::TObj { obj_id, fields, params } => ConcreteTypeEnum::TObj {
|
||||||
obj_id: *obj_id,
|
obj_id: *obj_id,
|
||||||
fields: fields
|
fields: fields
|
||||||
|
@ -268,12 +260,6 @@ impl ConcreteTypeStore {
|
||||||
ConcreteTypeEnum::TList { ty } => {
|
ConcreteTypeEnum::TList { ty } => {
|
||||||
TypeEnum::TList { ty: self.to_unifier_type(unifier, primitives, *ty, cache) }
|
TypeEnum::TList { ty: self.to_unifier_type(unifier, primitives, *ty, cache) }
|
||||||
}
|
}
|
||||||
ConcreteTypeEnum::TNDArray { ty, ndims } => {
|
|
||||||
TypeEnum::TNDArray {
|
|
||||||
ty: self.to_unifier_type(unifier, primitives, *ty, cache),
|
|
||||||
ndims: self.to_unifier_type(unifier, primitives, *ndims, cache),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ConcreteTypeEnum::TVirtual { ty } => {
|
ConcreteTypeEnum::TVirtual { ty } => {
|
||||||
TypeEnum::TVirtual { ty: self.to_unifier_type(unifier, primitives, *ty, cache) }
|
TypeEnum::TVirtual { ty: self.to_unifier_type(unifier, primitives, *ty, cache) }
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,7 +13,12 @@ use crate::{
|
||||||
CodeGenContext, CodeGenTask,
|
CodeGenContext, CodeGenTask,
|
||||||
},
|
},
|
||||||
symbol_resolver::{SymbolValue, ValueEnum},
|
symbol_resolver::{SymbolValue, ValueEnum},
|
||||||
toplevel::{DefinitionId, TopLevelDef},
|
toplevel::{
|
||||||
|
DefinitionId,
|
||||||
|
helper::PRIMITIVE_DEF_IDS,
|
||||||
|
numpy::make_ndarray_ty,
|
||||||
|
TopLevelDef,
|
||||||
|
},
|
||||||
typecheck::{
|
typecheck::{
|
||||||
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier},
|
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier},
|
||||||
magic_methods::{binop_name, binop_assign_name},
|
magic_methods::{binop_name, binop_assign_name},
|
||||||
|
@ -181,7 +186,6 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
||||||
&mut self.unifier,
|
&mut self.unifier,
|
||||||
self.top_level,
|
self.top_level,
|
||||||
&mut self.type_cache,
|
&mut self.type_cache,
|
||||||
&self.primitives,
|
|
||||||
ty,
|
ty,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -1204,23 +1208,25 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
|
||||||
SymbolValue::U64(v) => Ok(v),
|
SymbolValue::U64(v) => Ok(v),
|
||||||
SymbolValue::U32(v) => Ok(v as u64),
|
SymbolValue::U32(v) => Ok(v as u64),
|
||||||
SymbolValue::I32(v) => u64::try_from(v)
|
SymbolValue::I32(v) => u64::try_from(v)
|
||||||
.map_err(|_| format!("Expected non-negative literal for TNDArray.ndims, got {v}")),
|
.map_err(|_| format!("Expected non-negative literal for ndarray.ndims, got {v}")),
|
||||||
SymbolValue::I64(v) => u64::try_from(v)
|
SymbolValue::I64(v) => u64::try_from(v)
|
||||||
.map_err(|_| format!("Expected non-negative literal for TNDArray.ndims, got {v}")),
|
.map_err(|_| format!("Expected non-negative literal for ndarray.ndims, got {v}")),
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
})
|
})
|
||||||
.collect::<Result<Vec<_>, _>>()?;
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
|
|
||||||
assert!(!ndims.is_empty());
|
assert!(!ndims.is_empty());
|
||||||
|
|
||||||
let ndarray_ty_enum = TypeEnum::TNDArray {
|
let ndarray_ndims_ty = ctx.unifier.get_fresh_literal(
|
||||||
ty,
|
|
||||||
ndims: ctx.unifier.get_fresh_literal(
|
|
||||||
ndims.iter().map(|v| SymbolValue::U64(v - 1)).collect(),
|
ndims.iter().map(|v| SymbolValue::U64(v - 1)).collect(),
|
||||||
None,
|
None,
|
||||||
),
|
);
|
||||||
};
|
let ndarray_ty = make_ndarray_ty(
|
||||||
let ndarray_ty = ctx.unifier.add_ty(ndarray_ty_enum);
|
&mut ctx.unifier,
|
||||||
|
&ctx.primitives,
|
||||||
|
Some(ty),
|
||||||
|
Some(ndarray_ndims_ty),
|
||||||
|
);
|
||||||
let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type();
|
let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type();
|
||||||
let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type();
|
let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type();
|
||||||
let llvm_ndarray_data_t = ctx.get_llvm_type(generator, ty).as_basic_type_enum();
|
let llvm_ndarray_data_t = ctx.get_llvm_type(generator, ty).as_basic_type_enum();
|
||||||
|
@ -1963,7 +1969,13 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
||||||
v.get_data().get(ctx, generator, index, None).into()
|
v.get_data().get(ctx, generator, index, None).into()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
TypeEnum::TNDArray { ty, ndims } => {
|
TypeEnum::TObj { obj_id, params, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
|
||||||
|
let (ty, ndims) = params.iter()
|
||||||
|
.sorted_by_key(|(var_id, _)| *var_id)
|
||||||
|
.map(|(_, ty)| ty)
|
||||||
|
.collect_tuple()
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let v = if let Some(v) = generator.gen_expr(ctx, value)? {
|
let v = if let Some(v) = generator.gen_expr(ctx, value)? {
|
||||||
v.to_basic_value_enum(ctx, generator, value.custom.unwrap())?.into_pointer_value()
|
v.to_basic_value_enum(ctx, generator, value.custom.unwrap())?.into_pointer_value()
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -1,6 +1,11 @@
|
||||||
use crate::{
|
use crate::{
|
||||||
symbol_resolver::{StaticValue, SymbolResolver},
|
symbol_resolver::{StaticValue, SymbolResolver},
|
||||||
toplevel::{TopLevelContext, TopLevelDef},
|
toplevel::{
|
||||||
|
helper::PRIMITIVE_DEF_IDS,
|
||||||
|
numpy::unpack_ndarray_tvars,
|
||||||
|
TopLevelContext,
|
||||||
|
TopLevelDef,
|
||||||
|
},
|
||||||
typecheck::{
|
typecheck::{
|
||||||
type_inferencer::{CodeLocation, PrimitiveStore},
|
type_inferencer::{CodeLocation, PrimitiveStore},
|
||||||
typedef::{CallId, FuncArg, Type, TypeEnum, Unifier},
|
typedef::{CallId, FuncArg, Type, TypeEnum, Unifier},
|
||||||
|
@ -417,7 +422,6 @@ fn get_llvm_type<'ctx>(
|
||||||
unifier: &mut Unifier,
|
unifier: &mut Unifier,
|
||||||
top_level: &TopLevelContext,
|
top_level: &TopLevelContext,
|
||||||
type_cache: &mut HashMap<Type, BasicTypeEnum<'ctx>>,
|
type_cache: &mut HashMap<Type, BasicTypeEnum<'ctx>>,
|
||||||
primitives: &PrimitiveStore,
|
|
||||||
ty: Type,
|
ty: Type,
|
||||||
) -> BasicTypeEnum<'ctx> {
|
) -> BasicTypeEnum<'ctx> {
|
||||||
use TypeEnum::*;
|
use TypeEnum::*;
|
||||||
|
@ -427,28 +431,50 @@ fn get_llvm_type<'ctx>(
|
||||||
let ty_enum = unifier.get_ty(ty);
|
let ty_enum = unifier.get_ty(ty);
|
||||||
let result = match &*ty_enum {
|
let result = match &*ty_enum {
|
||||||
TObj { obj_id, fields, .. } => {
|
TObj { obj_id, fields, .. } => {
|
||||||
// check to avoid treating primitives other than Option as classes
|
// check to avoid treating non-class primitives as classes
|
||||||
if obj_id.0 <= 10 {
|
if obj_id.0 <= PRIMITIVE_DEF_IDS.max_id().0 {
|
||||||
match (unifier.get_ty(ty).as_ref(), unifier.get_ty(primitives.option).as_ref())
|
return match &*unifier.get_ty_immutable(ty) {
|
||||||
{
|
TObj { obj_id, params, .. } if *obj_id == PRIMITIVE_DEF_IDS.option => {
|
||||||
(
|
get_llvm_type(
|
||||||
TObj { obj_id, params, .. },
|
|
||||||
TObj { obj_id: opt_id, .. },
|
|
||||||
) if *obj_id == *opt_id => {
|
|
||||||
return get_llvm_type(
|
|
||||||
ctx,
|
ctx,
|
||||||
module,
|
module,
|
||||||
generator,
|
generator,
|
||||||
unifier,
|
unifier,
|
||||||
top_level,
|
top_level,
|
||||||
type_cache,
|
type_cache,
|
||||||
primitives,
|
|
||||||
*params.iter().next().unwrap().1,
|
*params.iter().next().unwrap().1,
|
||||||
)
|
)
|
||||||
.ptr_type(AddressSpace::default())
|
.ptr_type(AddressSpace::default())
|
||||||
.into();
|
.into()
|
||||||
}
|
}
|
||||||
_ => unreachable!("must be option type"),
|
|
||||||
|
TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
|
||||||
|
let llvm_usize = generator.get_size_type(ctx);
|
||||||
|
let (dtype, _) = unpack_ndarray_tvars(unifier, ty);
|
||||||
|
let element_type = get_llvm_type(
|
||||||
|
ctx,
|
||||||
|
module,
|
||||||
|
generator,
|
||||||
|
unifier,
|
||||||
|
top_level,
|
||||||
|
type_cache,
|
||||||
|
dtype,
|
||||||
|
);
|
||||||
|
|
||||||
|
// struct NDArray { num_dims: size_t, dims: size_t*, data: T* }
|
||||||
|
//
|
||||||
|
// * num_dims: Number of dimensions in the array
|
||||||
|
// * dims: Pointer to an array containing the size of each dimension
|
||||||
|
// * data: Pointer to an array containing the array data
|
||||||
|
let fields = [
|
||||||
|
llvm_usize.into(),
|
||||||
|
llvm_usize.ptr_type(AddressSpace::default()).into(),
|
||||||
|
element_type.ptr_type(AddressSpace::default()).into(),
|
||||||
|
];
|
||||||
|
ctx.struct_type(&fields, false).ptr_type(AddressSpace::default()).into()
|
||||||
|
}
|
||||||
|
|
||||||
|
_ => unreachable!("LLVM type for primitive {} is missing", unifier.stringify(ty)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// a struct with fields in the order of declaration
|
// a struct with fields in the order of declaration
|
||||||
|
@ -477,7 +503,6 @@ fn get_llvm_type<'ctx>(
|
||||||
unifier,
|
unifier,
|
||||||
top_level,
|
top_level,
|
||||||
type_cache,
|
type_cache,
|
||||||
primitives,
|
|
||||||
fields[&f.0].0,
|
fields[&f.0].0,
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
@ -493,7 +518,7 @@ fn get_llvm_type<'ctx>(
|
||||||
.iter()
|
.iter()
|
||||||
.map(|ty| {
|
.map(|ty| {
|
||||||
get_llvm_type(
|
get_llvm_type(
|
||||||
ctx, module, generator, unifier, top_level, type_cache, primitives, *ty,
|
ctx, module, generator, unifier, top_level, type_cache, *ty,
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
.collect_vec();
|
.collect_vec();
|
||||||
|
@ -502,7 +527,7 @@ fn get_llvm_type<'ctx>(
|
||||||
TList { ty } => {
|
TList { ty } => {
|
||||||
// a struct with an integer and a pointer to an array
|
// a struct with an integer and a pointer to an array
|
||||||
let element_type = get_llvm_type(
|
let element_type = get_llvm_type(
|
||||||
ctx, module, generator, unifier, top_level, type_cache, primitives, *ty,
|
ctx, module, generator, unifier, top_level, type_cache, *ty,
|
||||||
);
|
);
|
||||||
let fields = [
|
let fields = [
|
||||||
element_type.ptr_type(AddressSpace::default()).into(),
|
element_type.ptr_type(AddressSpace::default()).into(),
|
||||||
|
@ -510,24 +535,6 @@ fn get_llvm_type<'ctx>(
|
||||||
];
|
];
|
||||||
ctx.struct_type(&fields, false).ptr_type(AddressSpace::default()).into()
|
ctx.struct_type(&fields, false).ptr_type(AddressSpace::default()).into()
|
||||||
}
|
}
|
||||||
TNDArray { ty, .. } => {
|
|
||||||
let llvm_usize = generator.get_size_type(ctx);
|
|
||||||
let element_type = get_llvm_type(
|
|
||||||
ctx, module, generator, unifier, top_level, type_cache, primitives, *ty,
|
|
||||||
);
|
|
||||||
|
|
||||||
// struct NDArray { num_dims: size_t, dims: size_t*, data: T* }
|
|
||||||
//
|
|
||||||
// * num_dims: Number of dimensions in the array
|
|
||||||
// * dims: Pointer to an array containing the size of each dimension
|
|
||||||
// * data: Pointer to an array containing the array data
|
|
||||||
let fields = [
|
|
||||||
llvm_usize.into(),
|
|
||||||
llvm_usize.ptr_type(AddressSpace::default()).into(),
|
|
||||||
element_type.ptr_type(AddressSpace::default()).into(),
|
|
||||||
];
|
|
||||||
ctx.struct_type(&fields, false).ptr_type(AddressSpace::default()).into()
|
|
||||||
}
|
|
||||||
TVirtual { .. } => unimplemented!(),
|
TVirtual { .. } => unimplemented!(),
|
||||||
_ => unreachable!("{}", ty_enum.get_type_name()),
|
_ => unreachable!("{}", ty_enum.get_type_name()),
|
||||||
};
|
};
|
||||||
|
@ -561,7 +568,7 @@ fn get_llvm_abi_type<'ctx>(
|
||||||
return if unifier.unioned(ty, primitives.bool) {
|
return if unifier.unioned(ty, primitives.bool) {
|
||||||
ctx.bool_type().into()
|
ctx.bool_type().into()
|
||||||
} else {
|
} else {
|
||||||
get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, primitives, ty)
|
get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, ty)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -763,7 +770,6 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
|
||||||
&mut unifier,
|
&mut unifier,
|
||||||
top_level_ctx.as_ref(),
|
top_level_ctx.as_ref(),
|
||||||
&mut type_cache,
|
&mut type_cache,
|
||||||
&primitives,
|
|
||||||
arg.ty,
|
arg.ty,
|
||||||
);
|
);
|
||||||
let alloca = builder
|
let alloca = builder
|
||||||
|
|
|
@ -10,7 +10,12 @@ use crate::{
|
||||||
expr::gen_binop_expr,
|
expr::gen_binop_expr,
|
||||||
gen_in_range_check,
|
gen_in_range_check,
|
||||||
},
|
},
|
||||||
toplevel::{DefinitionId, TopLevelDef},
|
toplevel::{
|
||||||
|
DefinitionId,
|
||||||
|
helper::PRIMITIVE_DEF_IDS,
|
||||||
|
numpy::unpack_ndarray_tvars,
|
||||||
|
TopLevelDef,
|
||||||
|
},
|
||||||
typecheck::typedef::{FunSignature, Type, TypeEnum},
|
typecheck::typedef::{FunSignature, Type, TypeEnum},
|
||||||
};
|
};
|
||||||
use inkwell::{
|
use inkwell::{
|
||||||
|
@ -186,7 +191,7 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>(
|
||||||
v.get_data().ptr_offset(ctx, generator, index, name)
|
v.get_data().ptr_offset(ctx, generator, index, name)
|
||||||
}
|
}
|
||||||
|
|
||||||
TypeEnum::TNDArray { .. } => {
|
TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -242,11 +247,15 @@ pub fn gen_assign<'ctx, G: CodeGenerator>(
|
||||||
.to_basic_value_enum(ctx, generator, target.custom.unwrap())?
|
.to_basic_value_enum(ctx, generator, target.custom.unwrap())?
|
||||||
.into_pointer_value();
|
.into_pointer_value();
|
||||||
let value = ListValue::from_ptr_val(value, llvm_usize, None);
|
let value = ListValue::from_ptr_val(value, llvm_usize, None);
|
||||||
let (TypeEnum::TList { ty } | TypeEnum::TNDArray { ty, .. }) = &*ctx.unifier.get_ty(target.custom.unwrap()) else {
|
let ty = match &*ctx.unifier.get_ty_immutable(target.custom.unwrap()) {
|
||||||
unreachable!()
|
TypeEnum::TList { ty } => *ty,
|
||||||
|
TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
|
||||||
|
unpack_ndarray_tvars(&mut ctx.unifier, target.custom.unwrap()).0
|
||||||
|
}
|
||||||
|
_ => unreachable!(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let ty = ctx.get_llvm_type(generator, *ty);
|
let ty = ctx.get_llvm_type(generator, ty);
|
||||||
let Some(src_ind) = handle_slice_indices(&None, &None, &None, ctx, generator, value)? else {
|
let Some(src_ind) = handle_slice_indices(&None, &None, &None, ctx, generator, value)? else {
|
||||||
return Ok(())
|
return Ok(())
|
||||||
};
|
};
|
||||||
|
|
|
@ -3,16 +3,12 @@ use std::sync::Arc;
|
||||||
use std::{collections::HashMap, collections::HashSet, fmt::Display};
|
use std::{collections::HashMap, collections::HashSet, fmt::Display};
|
||||||
use std::rc::Rc;
|
use std::rc::Rc;
|
||||||
|
|
||||||
use crate::typecheck::typedef::TypeEnum;
|
|
||||||
use crate::{
|
use crate::{
|
||||||
codegen::CodeGenContext,
|
codegen::{CodeGenContext, CodeGenerator},
|
||||||
toplevel::{DefinitionId, TopLevelDef, type_annotation::TypeAnnotation},
|
toplevel::{DefinitionId, TopLevelDef, type_annotation::TypeAnnotation},
|
||||||
};
|
|
||||||
use crate::{
|
|
||||||
codegen::CodeGenerator,
|
|
||||||
typecheck::{
|
typecheck::{
|
||||||
type_inferencer::PrimitiveStore,
|
type_inferencer::PrimitiveStore,
|
||||||
typedef::{Type, Unifier},
|
typedef::{Type, TypeEnum, Unifier},
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
use inkwell::values::{BasicValueEnum, FloatValue, IntValue, PointerValue, StructValue};
|
use inkwell::values::{BasicValueEnum, FloatValue, IntValue, PointerValue, StructValue};
|
||||||
|
@ -353,14 +349,13 @@ pub trait SymbolResolver {
|
||||||
}
|
}
|
||||||
|
|
||||||
thread_local! {
|
thread_local! {
|
||||||
static IDENTIFIER_ID: [StrRef; 13] = [
|
static IDENTIFIER_ID: [StrRef; 12] = [
|
||||||
"int32".into(),
|
"int32".into(),
|
||||||
"int64".into(),
|
"int64".into(),
|
||||||
"float".into(),
|
"float".into(),
|
||||||
"bool".into(),
|
"bool".into(),
|
||||||
"virtual".into(),
|
"virtual".into(),
|
||||||
"list".into(),
|
"list".into(),
|
||||||
"ndarray".into(),
|
|
||||||
"tuple".into(),
|
"tuple".into(),
|
||||||
"str".into(),
|
"str".into(),
|
||||||
"Exception".into(),
|
"Exception".into(),
|
||||||
|
@ -386,13 +381,12 @@ pub fn parse_type_annotation<T>(
|
||||||
let bool_id = ids[3];
|
let bool_id = ids[3];
|
||||||
let virtual_id = ids[4];
|
let virtual_id = ids[4];
|
||||||
let list_id = ids[5];
|
let list_id = ids[5];
|
||||||
let ndarray_id = ids[6];
|
let tuple_id = ids[6];
|
||||||
let tuple_id = ids[7];
|
let str_id = ids[7];
|
||||||
let str_id = ids[8];
|
let exn_id = ids[8];
|
||||||
let exn_id = ids[9];
|
let uint32_id = ids[9];
|
||||||
let uint32_id = ids[10];
|
let uint64_id = ids[10];
|
||||||
let uint64_id = ids[11];
|
let literal_id = ids[11];
|
||||||
let literal_id = ids[12];
|
|
||||||
|
|
||||||
let name_handling = |id: &StrRef, loc: Location, unifier: &mut Unifier| {
|
let name_handling = |id: &StrRef, loc: Location, unifier: &mut Unifier| {
|
||||||
if *id == int32_id {
|
if *id == int32_id {
|
||||||
|
@ -463,21 +457,6 @@ pub fn parse_type_annotation<T>(
|
||||||
} else if *id == list_id {
|
} else if *id == list_id {
|
||||||
let ty = parse_type_annotation(resolver, top_level_defs, unifier, primitives, slice)?;
|
let ty = parse_type_annotation(resolver, top_level_defs, unifier, primitives, slice)?;
|
||||||
Ok(unifier.add_ty(TypeEnum::TList { ty }))
|
Ok(unifier.add_ty(TypeEnum::TList { ty }))
|
||||||
} else if *id == ndarray_id {
|
|
||||||
let Tuple { elts, .. } = &slice.node else {
|
|
||||||
return Err(HashSet::from([
|
|
||||||
String::from("Expected 2 type arguments for ndarray"),
|
|
||||||
]))
|
|
||||||
};
|
|
||||||
if elts.len() < 2 {
|
|
||||||
return Err(HashSet::from([
|
|
||||||
String::from("Expected 2 type arguments for ndarray"),
|
|
||||||
]))
|
|
||||||
}
|
|
||||||
|
|
||||||
let ty = parse_type_annotation(resolver, top_level_defs, unifier, primitives, &elts[0])?;
|
|
||||||
let ndims = parse_type_annotation(resolver, top_level_defs, unifier, primitives, &elts[1])?;
|
|
||||||
Ok(unifier.add_ty(TypeEnum::TNDArray { ty, ndims }))
|
|
||||||
} else if *id == tuple_id {
|
} else if *id == tuple_id {
|
||||||
if let Tuple { elts, .. } = &slice.node {
|
if let Tuple { elts, .. } = &slice.node {
|
||||||
let ty = elts
|
let ty = elts
|
||||||
|
|
|
@ -274,14 +274,8 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
||||||
let boolean = primitives.0.bool;
|
let boolean = primitives.0.bool;
|
||||||
let range = primitives.0.range;
|
let range = primitives.0.range;
|
||||||
let string = primitives.0.str;
|
let string = primitives.0.str;
|
||||||
let ndarray = {
|
let ndarray = primitives.0.ndarray;
|
||||||
let ndarray_ty = TypeEnum::ndarray(&mut primitives.1, None, None, &primitives.0);
|
let ndarray_float = make_ndarray_ty(&mut primitives.1, &primitives.0, Some(float), None);
|
||||||
primitives.1.add_ty(ndarray_ty)
|
|
||||||
};
|
|
||||||
let ndarray_float = {
|
|
||||||
let ndarray_ty_enum = TypeEnum::ndarray(&mut primitives.1, Some(float), None, &primitives.0);
|
|
||||||
primitives.1.add_ty(ndarray_ty_enum)
|
|
||||||
};
|
|
||||||
let ndarray_float_2d = {
|
let ndarray_float_2d = {
|
||||||
let value = match primitives.0.size_t {
|
let value = match primitives.0.size_t {
|
||||||
64 => SymbolValue::U64(2u64),
|
64 => SymbolValue::U64(2u64),
|
||||||
|
@ -293,10 +287,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
||||||
loc: None,
|
loc: None,
|
||||||
});
|
});
|
||||||
|
|
||||||
primitives.1.add_ty(TypeEnum::TNDArray {
|
make_ndarray_ty(&mut primitives.1, &primitives.0, Some(float), Some(ndims))
|
||||||
ty: float,
|
|
||||||
ndims,
|
|
||||||
})
|
|
||||||
};
|
};
|
||||||
let list_int32 = primitives.1.add_ty(TypeEnum::TList { ty: int32 });
|
let list_int32 = primitives.1.add_ty(TypeEnum::TList { ty: int32 });
|
||||||
let num_ty = primitives.1.get_fresh_var_with_range(
|
let num_ty = primitives.1.get_fresh_var_with_range(
|
||||||
|
@ -1352,7 +1343,12 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
||||||
let tvar = primitives.1.get_fresh_var(Some("L".into()), None);
|
let tvar = primitives.1.get_fresh_var(Some("L".into()), None);
|
||||||
let list = primitives.1.add_ty(TypeEnum::TList { ty: tvar.0 });
|
let list = primitives.1.add_ty(TypeEnum::TList { ty: tvar.0 });
|
||||||
let ndims = primitives.1.get_fresh_const_generic_var(primitives.0.uint64, Some("N".into()), None);
|
let ndims = primitives.1.get_fresh_const_generic_var(primitives.0.uint64, Some("N".into()), None);
|
||||||
let ndarray = primitives.1.add_ty(TypeEnum::TNDArray { ty: tvar.0, ndims: ndims.0 });
|
let ndarray = make_ndarray_ty(
|
||||||
|
&mut primitives.1,
|
||||||
|
&primitives.0,
|
||||||
|
Some(tvar.0),
|
||||||
|
Some(ndims.0),
|
||||||
|
);
|
||||||
|
|
||||||
let arg_ty = primitives.1.get_fresh_var_with_range(
|
let arg_ty = primitives.1.get_fresh_var_with_range(
|
||||||
&[list, ndarray, primitives.0.range],
|
&[list, ndarray, primitives.0.range],
|
||||||
|
@ -1404,7 +1400,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
TypeEnum::TNDArray { .. } => {
|
TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
|
||||||
let llvm_i32 = ctx.ctx.i32_type();
|
let llvm_i32 = ctx.ctx.i32_type();
|
||||||
let i32_zero = llvm_i32.const_zero();
|
let i32_zero = llvm_i32.const_zero();
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
use std::convert::TryInto;
|
use std::convert::TryInto;
|
||||||
|
|
||||||
use crate::symbol_resolver::SymbolValue;
|
use crate::symbol_resolver::SymbolValue;
|
||||||
|
use crate::typecheck::typedef::Mapping;
|
||||||
use nac3parser::ast::{Constant, Location};
|
use nac3parser::ast::{Constant, Location};
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
@ -194,6 +195,23 @@ impl TopLevelComposer {
|
||||||
params: HashMap::from([(option_type_var.1, option_type_var.0)]),
|
params: HashMap::from([(option_type_var.1, option_type_var.0)]),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
let size_t_ty = match size_t {
|
||||||
|
32 => uint32,
|
||||||
|
64 => uint64,
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let ndarray_dtype_tvar = unifier.get_fresh_var(Some("ndarray_dtype".into()), None);
|
||||||
|
let ndarray_ndims_tvar = unifier.get_fresh_const_generic_var(size_t_ty, Some("ndarray_ndims".into()), None);
|
||||||
|
let ndarray = unifier.add_ty(TypeEnum::TObj {
|
||||||
|
obj_id: PRIMITIVE_DEF_IDS.ndarray,
|
||||||
|
fields: Mapping::new(),
|
||||||
|
params: Mapping::from([
|
||||||
|
(ndarray_dtype_tvar.1, ndarray_dtype_tvar.0),
|
||||||
|
(ndarray_ndims_tvar.1, ndarray_ndims_tvar.0),
|
||||||
|
])
|
||||||
|
});
|
||||||
|
|
||||||
let primitives = PrimitiveStore {
|
let primitives = PrimitiveStore {
|
||||||
int32,
|
int32,
|
||||||
int64,
|
int64,
|
||||||
|
@ -206,6 +224,7 @@ impl TopLevelComposer {
|
||||||
str,
|
str,
|
||||||
exception,
|
exception,
|
||||||
option,
|
option,
|
||||||
|
ndarray,
|
||||||
size_t,
|
size_t,
|
||||||
};
|
};
|
||||||
unifier.put_primitive_store(&primitives);
|
unifier.put_primitive_store(&primitives);
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
use inkwell::{IntPredicate, types::BasicType, values::{BasicValueEnum, PointerValue}};
|
use inkwell::{IntPredicate, types::BasicType, values::{BasicValueEnum, PointerValue}};
|
||||||
use inkwell::values::{AggregateValueEnum, ArrayValue, IntValue};
|
use inkwell::values::{AggregateValueEnum, ArrayValue, IntValue};
|
||||||
|
use itertools::Itertools;
|
||||||
use nac3parser::ast::StrRef;
|
use nac3parser::ast::StrRef;
|
||||||
use crate::{
|
use crate::{
|
||||||
codegen::{
|
codegen::{
|
||||||
|
@ -15,10 +16,68 @@ use crate::{
|
||||||
stmt::gen_for_callback
|
stmt::gen_for_callback
|
||||||
},
|
},
|
||||||
symbol_resolver::ValueEnum,
|
symbol_resolver::ValueEnum,
|
||||||
toplevel::DefinitionId,
|
toplevel::{DefinitionId, helper::PRIMITIVE_DEF_IDS},
|
||||||
typecheck::typedef::{FunSignature, Type, TypeEnum},
|
typecheck::{
|
||||||
|
type_inferencer::PrimitiveStore,
|
||||||
|
typedef::{FunSignature, Mapping, Type, TypeEnum, Unifier},
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// Creates a `ndarray` [`Type`] with the given type arguments.
|
||||||
|
///
|
||||||
|
/// * `dtype` - The element type of the `ndarray`, or [`None`] if the type variable is not
|
||||||
|
/// specialized.
|
||||||
|
/// * `ndims` - The number of dimensions of the `ndarray`, or [`None`] if the type variable is not
|
||||||
|
/// specialized.
|
||||||
|
pub fn make_ndarray_ty(
|
||||||
|
unifier: &mut Unifier,
|
||||||
|
primitives: &PrimitiveStore,
|
||||||
|
dtype: Option<Type>,
|
||||||
|
ndims: Option<Type>,
|
||||||
|
) -> Type {
|
||||||
|
let ndarray = primitives.ndarray;
|
||||||
|
|
||||||
|
let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else {
|
||||||
|
panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray))
|
||||||
|
};
|
||||||
|
debug_assert_eq!(*obj_id, PRIMITIVE_DEF_IDS.ndarray);
|
||||||
|
|
||||||
|
let tvar_ids = params.iter()
|
||||||
|
.map(|(obj_id, _)| *obj_id)
|
||||||
|
.sorted()
|
||||||
|
.collect_vec();
|
||||||
|
debug_assert_eq!(tvar_ids.len(), 2);
|
||||||
|
|
||||||
|
let mut tvar_subst = Mapping::new();
|
||||||
|
if let Some(dtype) = dtype {
|
||||||
|
tvar_subst.insert(tvar_ids[0], dtype);
|
||||||
|
}
|
||||||
|
if let Some(ndims) = ndims {
|
||||||
|
tvar_subst.insert(tvar_ids[1], ndims);
|
||||||
|
}
|
||||||
|
|
||||||
|
unifier.subst(ndarray, &tvar_subst).unwrap_or(ndarray)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Unpacks the type variables of `ndarray` into a tuple. The elements of the tuple corresponds to
|
||||||
|
/// `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray` respectively.
|
||||||
|
pub fn unpack_ndarray_tvars(
|
||||||
|
unifier: &mut Unifier,
|
||||||
|
ndarray: Type,
|
||||||
|
) -> (Type, Type) {
|
||||||
|
let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else {
|
||||||
|
panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray))
|
||||||
|
};
|
||||||
|
debug_assert_eq!(*obj_id, PRIMITIVE_DEF_IDS.ndarray);
|
||||||
|
debug_assert_eq!(params.len(), 2);
|
||||||
|
|
||||||
|
params.iter()
|
||||||
|
.sorted_by_key(|(obj_id, _)| *obj_id)
|
||||||
|
.map(|(_, ty)| *ty)
|
||||||
|
.collect_tuple()
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
/// Creates an `NDArray` instance from a constant shape.
|
/// Creates an `NDArray` instance from a constant shape.
|
||||||
///
|
///
|
||||||
/// * `elem_ty` - The element type of the `NDArray`.
|
/// * `elem_ty` - The element type of the `NDArray`.
|
||||||
|
@ -29,8 +88,7 @@ fn create_ndarray_const_shape<'ctx>(
|
||||||
elem_ty: Type,
|
elem_ty: Type,
|
||||||
shape: ArrayValue<'ctx>
|
shape: ArrayValue<'ctx>
|
||||||
) -> Result<NDArrayValue<'ctx>, String> {
|
) -> Result<NDArrayValue<'ctx>, String> {
|
||||||
let ndarray_ty_enum = TypeEnum::ndarray(&mut ctx.unifier, Some(elem_ty), None, &ctx.primitives);
|
let ndarray_ty = make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(elem_ty), None);
|
||||||
let ndarray_ty = ctx.unifier.add_ty(ndarray_ty_enum);
|
|
||||||
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
@ -147,8 +205,12 @@ fn call_ndarray_empty_impl<'ctx>(
|
||||||
elem_ty: Type,
|
elem_ty: Type,
|
||||||
shape: ListValue<'ctx>,
|
shape: ListValue<'ctx>,
|
||||||
) -> Result<NDArrayValue<'ctx>, String> {
|
) -> Result<NDArrayValue<'ctx>, String> {
|
||||||
let ndarray_ty_enum = TypeEnum::ndarray(&mut ctx.unifier, Some(elem_ty), None, &ctx.primitives);
|
let ndarray_ty = make_ndarray_ty(
|
||||||
let ndarray_ty = ctx.unifier.add_ty(ndarray_ty_enum);
|
&mut ctx.unifier,
|
||||||
|
&ctx.primitives,
|
||||||
|
Some(elem_ty),
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
|
||||||
let llvm_i32 = ctx.ctx.i32_type();
|
let llvm_i32 = ctx.ctx.i32_type();
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
|
@ -5,7 +5,7 @@ expression: res_vec
|
||||||
[
|
[
|
||||||
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
|
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
|
||||||
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [29]\n}\n",
|
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [28]\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",
|
||||||
|
|
|
@ -7,7 +7,7 @@ expression: res_vec
|
||||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B[typevar18]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar18\"]\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B[typevar17]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar17\"]\n}\n",
|
||||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",
|
||||||
|
|
|
@ -5,8 +5,8 @@ expression: res_vec
|
||||||
[
|
[
|
||||||
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
|
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
|
||||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [31]\n}\n",
|
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [30]\n}\n",
|
||||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [36]\n}\n",
|
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [35]\n}\n",
|
||||||
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[int32, list[float]]], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[int32, list[float]]], none]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
|
|
|
@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs
|
||||||
expression: res_vec
|
expression: res_vec
|
||||||
---
|
---
|
||||||
[
|
[
|
||||||
"Class {\nname: \"A\",\nancestors: [\"A[typevar17, typevar18]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[bool, float], b:B], none]\"), (\"fun\", \"fn[[a:A[bool, float]], A[bool, int32]]\")],\ntype_vars: [\"typevar17\", \"typevar18\"]\n}\n",
|
"Class {\nname: \"A\",\nancestors: [\"A[typevar16, typevar17]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[bool, float], b:B], none]\"), (\"fun\", \"fn[[a:A[bool, float]], A[bool, int32]]\")],\ntype_vars: [\"typevar16\", \"typevar17\"]\n}\n",
|
||||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[bool, float], b:B], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[bool, float], b:B], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[bool, float]], A[bool, int32]]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[bool, float]], A[bool, int32]]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[bool, float]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[int32, list[B]]], tuple[A[bool, virtual[A[B, int32]]], B]]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[bool, float]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[int32, list[B]]], tuple[A[bool, virtual[A[B, int32]]], B]]\")],\ntype_vars: []\n}\n",
|
||||||
|
|
|
@ -6,12 +6,12 @@ expression: res_vec
|
||||||
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [37]\n}\n",
|
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [36]\n}\n",
|
||||||
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||||
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
|
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
|
||||||
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [45]\n}\n",
|
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [44]\n}\n",
|
||||||
]
|
]
|
||||||
|
|
|
@ -492,24 +492,11 @@ pub fn get_type_from_type_annotation_kinds(
|
||||||
(*name, (subst_ty, *mutability))
|
(*name, (subst_ty, *mutability))
|
||||||
}));
|
}));
|
||||||
let need_subst = !subst.is_empty();
|
let need_subst = !subst.is_empty();
|
||||||
let ty = if obj_id == &PRIMITIVE_DEF_IDS.ndarray {
|
let ty = unifier.add_ty(TypeEnum::TObj {
|
||||||
assert_eq!(subst.len(), 2);
|
|
||||||
let tv_tys = subst.iter()
|
|
||||||
.sorted_by_key(|(k, _)| *k)
|
|
||||||
.map(|(_, v)| v)
|
|
||||||
.collect_vec();
|
|
||||||
|
|
||||||
unifier.add_ty(TypeEnum::TNDArray {
|
|
||||||
ty: *tv_tys[0],
|
|
||||||
ndims: *tv_tys[1],
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
unifier.add_ty(TypeEnum::TObj {
|
|
||||||
obj_id: *obj_id,
|
obj_id: *obj_id,
|
||||||
fields: tobj_fields,
|
fields: tobj_fields,
|
||||||
params: subst,
|
params: subst,
|
||||||
})
|
});
|
||||||
};
|
|
||||||
if need_subst {
|
if need_subst {
|
||||||
if let Some(wl) = subst_list.as_mut() {
|
if let Some(wl) = subst_list.as_mut() {
|
||||||
wl.push(ty);
|
wl.push(ty);
|
||||||
|
|
|
@ -5,7 +5,14 @@ use std::{cell::RefCell, sync::Arc};
|
||||||
|
|
||||||
use super::typedef::{Call, FunSignature, FuncArg, RecordField, Type, TypeEnum, Unifier};
|
use super::typedef::{Call, FunSignature, FuncArg, RecordField, Type, TypeEnum, Unifier};
|
||||||
use super::{magic_methods::*, typedef::CallId};
|
use super::{magic_methods::*, typedef::CallId};
|
||||||
use crate::{symbol_resolver::{SymbolResolver, SymbolValue}, toplevel::TopLevelContext};
|
use crate::{
|
||||||
|
symbol_resolver::{SymbolResolver, SymbolValue},
|
||||||
|
toplevel::{
|
||||||
|
helper::PRIMITIVE_DEF_IDS,
|
||||||
|
numpy::{make_ndarray_ty, unpack_ndarray_tvars},
|
||||||
|
TopLevelContext,
|
||||||
|
},
|
||||||
|
};
|
||||||
use itertools::izip;
|
use itertools::izip;
|
||||||
use nac3parser::ast::{
|
use nac3parser::ast::{
|
||||||
self,
|
self,
|
||||||
|
@ -47,6 +54,7 @@ pub struct PrimitiveStore {
|
||||||
pub str: Type,
|
pub str: Type,
|
||||||
pub exception: Type,
|
pub exception: Type,
|
||||||
pub option: Type,
|
pub option: Type,
|
||||||
|
pub ndarray: Type,
|
||||||
pub size_t: u32,
|
pub size_t: u32,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -226,7 +234,7 @@ impl<'a> Fold<()> for Inferencer<'a> {
|
||||||
} else {
|
} else {
|
||||||
let list_like_ty = match &*self.unifier.get_ty(iter.custom.unwrap()) {
|
let list_like_ty = match &*self.unifier.get_ty(iter.custom.unwrap()) {
|
||||||
TypeEnum::TList { .. } => self.unifier.add_ty(TypeEnum::TList { ty: target.custom.unwrap() }),
|
TypeEnum::TList { .. } => self.unifier.add_ty(TypeEnum::TList { ty: target.custom.unwrap() }),
|
||||||
TypeEnum::TNDArray { .. } => todo!(),
|
TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => todo!(),
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
};
|
};
|
||||||
self.unify(list_like_ty, iter.custom.unwrap(), &iter.location)?;
|
self.unify(list_like_ty, iter.custom.unwrap(), &iter.location)?;
|
||||||
|
@ -916,10 +924,12 @@ impl<'a> Inferencer<'a> {
|
||||||
vec![SymbolValue::U64(ndims)],
|
vec![SymbolValue::U64(ndims)],
|
||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
let ret = self.unifier.add_ty(TypeEnum::TNDArray {
|
let ret = make_ndarray_ty(
|
||||||
ty: self.primitives.float,
|
self.unifier,
|
||||||
ndims
|
self.primitives,
|
||||||
});
|
Some(self.primitives.float),
|
||||||
|
Some(ndims),
|
||||||
|
);
|
||||||
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||||
args: vec![
|
args: vec![
|
||||||
FuncArg {
|
FuncArg {
|
||||||
|
@ -966,11 +976,12 @@ impl<'a> Inferencer<'a> {
|
||||||
vec![SymbolValue::U64(ndims)],
|
vec![SymbolValue::U64(ndims)],
|
||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
|
let ret = make_ndarray_ty(
|
||||||
let ret = self.unifier.add_ty(TypeEnum::TNDArray {
|
self.unifier,
|
||||||
ty,
|
self.primitives,
|
||||||
ndims
|
Some(ty),
|
||||||
});
|
Some(ndims),
|
||||||
|
);
|
||||||
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||||
args: vec![
|
args: vec![
|
||||||
FuncArg {
|
FuncArg {
|
||||||
|
@ -1252,11 +1263,16 @@ impl<'a> Inferencer<'a> {
|
||||||
TypeEnum::TVar { is_const_generic: false, .. }
|
TypeEnum::TVar { is_const_generic: false, .. }
|
||||||
));
|
));
|
||||||
|
|
||||||
let constrained_ty = self.unifier.add_ty(TypeEnum::TNDArray { ty: dummy_tvar, ndims });
|
let constrained_ty = make_ndarray_ty(
|
||||||
|
self.unifier,
|
||||||
|
self.primitives,
|
||||||
|
Some(dummy_tvar),
|
||||||
|
Some(ndims),
|
||||||
|
);
|
||||||
self.constrain(value.custom.unwrap(), constrained_ty, &value.location)?;
|
self.constrain(value.custom.unwrap(), constrained_ty, &value.location)?;
|
||||||
|
|
||||||
let TypeEnum::TLiteral { values, .. } = &*self.unifier.get_ty_immutable(ndims) else {
|
let TypeEnum::TLiteral { values, .. } = &*self.unifier.get_ty_immutable(ndims) else {
|
||||||
panic!("Expected TLiteral for TNDArray.ndims, got {}", self.unifier.stringify(ndims))
|
panic!("Expected TLiteral for ndarray.ndims, got {}", self.unifier.stringify(ndims))
|
||||||
};
|
};
|
||||||
|
|
||||||
let ndims = values.iter()
|
let ndims = values.iter()
|
||||||
|
@ -1264,10 +1280,10 @@ impl<'a> Inferencer<'a> {
|
||||||
SymbolValue::U64(v) => Ok(v),
|
SymbolValue::U64(v) => Ok(v),
|
||||||
SymbolValue::U32(v) => Ok(v as u64),
|
SymbolValue::U32(v) => Ok(v as u64),
|
||||||
SymbolValue::I32(v) => u64::try_from(v).map_err(|_| HashSet::from([
|
SymbolValue::I32(v) => u64::try_from(v).map_err(|_| HashSet::from([
|
||||||
format!("Expected non-negative literal for TNDArray.ndims, got {v}"),
|
format!("Expected non-negative literal for ndarray.ndims, got {v}"),
|
||||||
])),
|
])),
|
||||||
SymbolValue::I64(v) => u64::try_from(v).map_err(|_| HashSet::from([
|
SymbolValue::I64(v) => u64::try_from(v).map_err(|_| HashSet::from([
|
||||||
format!("Expected non-negative literal for TNDArray.ndims, got {v}"),
|
format!("Expected non-negative literal for ndarray.ndims, got {v}"),
|
||||||
])),
|
])),
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
})
|
})
|
||||||
|
@ -1292,10 +1308,12 @@ impl<'a> Inferencer<'a> {
|
||||||
ndims.into_iter().map(|v| SymbolValue::U64(v - 1)).collect(),
|
ndims.into_iter().map(|v| SymbolValue::U64(v - 1)).collect(),
|
||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
let subscripted_ty = self.unifier.add_ty(TypeEnum::TNDArray {
|
let subscripted_ty = make_ndarray_ty(
|
||||||
ty: dummy_tvar,
|
self.unifier,
|
||||||
ndims: ndims_min_one_ty,
|
self.primitives,
|
||||||
});
|
Some(dummy_tvar),
|
||||||
|
Some(ndims_min_one_ty),
|
||||||
|
);
|
||||||
|
|
||||||
Ok(subscripted_ty)
|
Ok(subscripted_ty)
|
||||||
}
|
}
|
||||||
|
@ -1315,16 +1333,24 @@ impl<'a> Inferencer<'a> {
|
||||||
}
|
}
|
||||||
let list_like_ty = match &*self.unifier.get_ty(value.custom.unwrap()) {
|
let list_like_ty = match &*self.unifier.get_ty(value.custom.unwrap()) {
|
||||||
TypeEnum::TList { .. } => self.unifier.add_ty(TypeEnum::TList { ty }),
|
TypeEnum::TList { .. } => self.unifier.add_ty(TypeEnum::TList { ty }),
|
||||||
TypeEnum::TNDArray { ndims, .. } => self.unifier.add_ty(TypeEnum::TNDArray { ty, ndims: *ndims }),
|
TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
|
||||||
|
let (_, ndims) = unpack_ndarray_tvars(self.unifier, value.custom.unwrap());
|
||||||
|
|
||||||
|
make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims))
|
||||||
|
}
|
||||||
|
|
||||||
_ => unreachable!()
|
_ => unreachable!()
|
||||||
};
|
};
|
||||||
self.constrain(value.custom.unwrap(), list_like_ty, &value.location)?;
|
self.constrain(value.custom.unwrap(), list_like_ty, &value.location)?;
|
||||||
Ok(list_like_ty)
|
Ok(list_like_ty)
|
||||||
}
|
}
|
||||||
ExprKind::Constant { value: ast::Constant::Int(val), .. } => {
|
ExprKind::Constant { value: ast::Constant::Int(val), .. } => {
|
||||||
if let TypeEnum::TNDArray { ndims, .. } = &*self.unifier.get_ty(value.custom.unwrap()) {
|
match &*self.unifier.get_ty(value.custom.unwrap()) {
|
||||||
self.infer_subscript_ndarray(value, ty, *ndims)
|
TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
|
||||||
} else {
|
let (_, ndims) = unpack_ndarray_tvars(self.unifier, value.custom.unwrap());
|
||||||
|
self.infer_subscript_ndarray(value, ty, ndims)
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
// the index is a constant, so value can be a sequence.
|
// the index is a constant, so value can be a sequence.
|
||||||
let ind: Option<i32> = (*val).try_into().ok();
|
let ind: Option<i32> = (*val).try_into().ok();
|
||||||
let ind = ind.ok_or_else(|| HashSet::from(["Index must be int32".to_string()]))?;
|
let ind = ind.ok_or_else(|| HashSet::from(["Index must be int32".to_string()]))?;
|
||||||
|
@ -1338,6 +1364,7 @@ impl<'a> Inferencer<'a> {
|
||||||
Ok(ty)
|
Ok(ty)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
_ => {
|
_ => {
|
||||||
if let TypeEnum::TTuple { .. } = &*self.unifier.get_ty(value.custom.unwrap()) {
|
if let TypeEnum::TTuple { .. } = &*self.unifier.get_ty(value.custom.unwrap()) {
|
||||||
return report_error("Tuple index must be a constant (KernelInvariant is also not supported)", slice.location)
|
return report_error("Tuple index must be a constant (KernelInvariant is also not supported)", slice.location)
|
||||||
|
@ -1351,9 +1378,11 @@ impl<'a> Inferencer<'a> {
|
||||||
self.constrain(value.custom.unwrap(), list, &value.location)?;
|
self.constrain(value.custom.unwrap(), list, &value.location)?;
|
||||||
Ok(ty)
|
Ok(ty)
|
||||||
}
|
}
|
||||||
TypeEnum::TNDArray { ndims, .. } => {
|
TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
|
||||||
|
let (_, ndims) = unpack_ndarray_tvars(self.unifier, value.custom.unwrap());
|
||||||
|
|
||||||
self.constrain(slice.custom.unwrap(), self.primitives.usize(), &slice.location)?;
|
self.constrain(slice.custom.unwrap(), self.primitives.usize(), &slice.location)?;
|
||||||
self.infer_subscript_ndarray(value, ty, *ndims)
|
self.infer_subscript_ndarray(value, ty, ndims)
|
||||||
}
|
}
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
}
|
}
|
||||||
|
|
|
@ -135,6 +135,11 @@ impl TestEnvironment {
|
||||||
fields: HashMap::new(),
|
fields: HashMap::new(),
|
||||||
params: HashMap::new(),
|
params: HashMap::new(),
|
||||||
});
|
});
|
||||||
|
let ndarray = unifier.add_ty(TypeEnum::TObj {
|
||||||
|
obj_id: PRIMITIVE_DEF_IDS.ndarray,
|
||||||
|
fields: HashMap::new(),
|
||||||
|
params: HashMap::new(),
|
||||||
|
});
|
||||||
let primitives = PrimitiveStore {
|
let primitives = PrimitiveStore {
|
||||||
int32,
|
int32,
|
||||||
int64,
|
int64,
|
||||||
|
@ -147,6 +152,7 @@ impl TestEnvironment {
|
||||||
uint32,
|
uint32,
|
||||||
uint64,
|
uint64,
|
||||||
option,
|
option,
|
||||||
|
ndarray,
|
||||||
size_t: 64,
|
size_t: 64,
|
||||||
};
|
};
|
||||||
unifier.put_primitive_store(&primitives);
|
unifier.put_primitive_store(&primitives);
|
||||||
|
@ -262,6 +268,11 @@ impl TestEnvironment {
|
||||||
fields: HashMap::new(),
|
fields: HashMap::new(),
|
||||||
params: HashMap::new(),
|
params: HashMap::new(),
|
||||||
});
|
});
|
||||||
|
let ndarray = unifier.add_ty(TypeEnum::TObj {
|
||||||
|
obj_id: PRIMITIVE_DEF_IDS.ndarray,
|
||||||
|
fields: HashMap::new(),
|
||||||
|
params: HashMap::new(),
|
||||||
|
});
|
||||||
identifier_mapping.insert("None".into(), none);
|
identifier_mapping.insert("None".into(), none);
|
||||||
for (i, name) in ["int32", "int64", "float", "bool", "none", "range", "str", "Exception"]
|
for (i, name) in ["int32", "int64", "float", "bool", "none", "range", "str", "Exception"]
|
||||||
.iter()
|
.iter()
|
||||||
|
@ -296,6 +307,7 @@ impl TestEnvironment {
|
||||||
uint32,
|
uint32,
|
||||||
uint64,
|
uint64,
|
||||||
option,
|
option,
|
||||||
|
ndarray,
|
||||||
size_t: 64,
|
size_t: 64,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -159,11 +159,6 @@ pub enum TypeEnum {
|
||||||
ty: Type,
|
ty: Type,
|
||||||
},
|
},
|
||||||
|
|
||||||
TNDArray {
|
|
||||||
ty: Type,
|
|
||||||
ndims: Type,
|
|
||||||
},
|
|
||||||
|
|
||||||
/// An object type.
|
/// An object type.
|
||||||
TObj {
|
TObj {
|
||||||
/// The [DefintionId] of this object type.
|
/// The [DefintionId] of this object type.
|
||||||
|
@ -198,34 +193,12 @@ impl TypeEnum {
|
||||||
TypeEnum::TLiteral { .. } => "TConstant",
|
TypeEnum::TLiteral { .. } => "TConstant",
|
||||||
TypeEnum::TTuple { .. } => "TTuple",
|
TypeEnum::TTuple { .. } => "TTuple",
|
||||||
TypeEnum::TList { .. } => "TList",
|
TypeEnum::TList { .. } => "TList",
|
||||||
TypeEnum::TNDArray { .. } => "TNDArray",
|
|
||||||
TypeEnum::TObj { .. } => "TObj",
|
TypeEnum::TObj { .. } => "TObj",
|
||||||
TypeEnum::TVirtual { .. } => "TVirtual",
|
TypeEnum::TVirtual { .. } => "TVirtual",
|
||||||
TypeEnum::TCall { .. } => "TCall",
|
TypeEnum::TCall { .. } => "TCall",
|
||||||
TypeEnum::TFunc { .. } => "TFunc",
|
TypeEnum::TFunc { .. } => "TFunc",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a [`TypeEnum`] representing a generic `ndarray` type.
|
|
||||||
///
|
|
||||||
/// * `dtype` - The datatype of the `ndarray`, or `None` if the datatype is generic.
|
|
||||||
/// * `ndims` - The number of dimensions of the `ndarray`, or `None` if the number of dimensions is generic.
|
|
||||||
#[must_use]
|
|
||||||
pub fn ndarray(
|
|
||||||
unifier: &mut Unifier,
|
|
||||||
dtype: Option<Type>,
|
|
||||||
ndims: Option<Type>,
|
|
||||||
primitives: &PrimitiveStore
|
|
||||||
) -> TypeEnum {
|
|
||||||
let dtype = dtype.unwrap_or_else(|| unifier.get_fresh_var(Some("T".into()), None).0);
|
|
||||||
let ndims = ndims
|
|
||||||
.unwrap_or_else(|| unifier.get_fresh_const_generic_var(primitives.usize(), Some("N".into()), None).0);
|
|
||||||
|
|
||||||
TypeEnum::TNDArray {
|
|
||||||
ty: dtype,
|
|
||||||
ndims,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub type SharedUnifier = Arc<Mutex<(UnificationTable<TypeEnum>, u32, Vec<Call>)>>;
|
pub type SharedUnifier = Arc<Mutex<(UnificationTable<TypeEnum>, u32, Vec<Call>)>>;
|
||||||
|
@ -445,9 +418,6 @@ impl Unifier {
|
||||||
TypeEnum::TList { ty } => self
|
TypeEnum::TList { ty } => self
|
||||||
.get_instantiations(*ty)
|
.get_instantiations(*ty)
|
||||||
.map(|ty| ty.iter().map(|&ty| self.add_ty(TypeEnum::TList { ty })).collect_vec()),
|
.map(|ty| ty.iter().map(|&ty| self.add_ty(TypeEnum::TList { ty })).collect_vec()),
|
||||||
TypeEnum::TNDArray { ty, ndims } => self
|
|
||||||
.get_instantiations(*ty)
|
|
||||||
.map(|ty| ty.iter().map(|&ty| self.add_ty(TypeEnum::TNDArray { ty, ndims: *ndims })).collect_vec()),
|
|
||||||
TypeEnum::TVirtual { ty } => self.get_instantiations(*ty).map(|ty| {
|
TypeEnum::TVirtual { ty } => self.get_instantiations(*ty).map(|ty| {
|
||||||
ty.iter().map(|&ty| self.add_ty(TypeEnum::TVirtual { ty })).collect_vec()
|
ty.iter().map(|&ty| self.add_ty(TypeEnum::TVirtual { ty })).collect_vec()
|
||||||
}),
|
}),
|
||||||
|
@ -505,8 +475,7 @@ impl Unifier {
|
||||||
TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)),
|
TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)),
|
||||||
TCall { .. } => false,
|
TCall { .. } => false,
|
||||||
TList { ty }
|
TList { ty }
|
||||||
| TVirtual { ty }
|
| TVirtual { ty } => self.is_concrete(*ty, allowed_typevars),
|
||||||
| TNDArray { ty, .. } => self.is_concrete(*ty, allowed_typevars),
|
|
||||||
|
|
||||||
TTuple { ty } => ty.iter().all(|ty| self.is_concrete(*ty, allowed_typevars)),
|
TTuple { ty } => ty.iter().all(|ty| self.is_concrete(*ty, allowed_typevars)),
|
||||||
TObj { params: vars, .. } => {
|
TObj { params: vars, .. } => {
|
||||||
|
@ -752,7 +721,7 @@ impl Unifier {
|
||||||
self.unify_impl(x, b, false)?;
|
self.unify_impl(x, b, false)?;
|
||||||
self.set_a_to_b(a, x);
|
self.set_a_to_b(a, x);
|
||||||
}
|
}
|
||||||
(TVar { fields: Some(fields), range, is_const_generic: false, .. }, TList { ty } | TNDArray { ty, .. }) => {
|
(TVar { fields: Some(fields), range, is_const_generic: false, .. }, TList { ty }) => {
|
||||||
for (k, v) in fields {
|
for (k, v) in fields {
|
||||||
match *k {
|
match *k {
|
||||||
RecordKey::Int(_) => {
|
RecordKey::Int(_) => {
|
||||||
|
@ -792,7 +761,6 @@ impl Unifier {
|
||||||
|
|
||||||
// If the types don't match, try to implicitly promote integers
|
// If the types don't match, try to implicitly promote integers
|
||||||
if !self.unioned(ty, value_ty) {
|
if !self.unioned(ty, value_ty) {
|
||||||
|
|
||||||
let num_val = match *value {
|
let num_val = match *value {
|
||||||
SymbolValue::I32(v) => v as i128,
|
SymbolValue::I32(v) => v as i128,
|
||||||
SymbolValue::I64(v) => v as i128,
|
SymbolValue::I64(v) => v as i128,
|
||||||
|
@ -864,15 +832,6 @@ impl Unifier {
|
||||||
}
|
}
|
||||||
self.set_a_to_b(a, b);
|
self.set_a_to_b(a, b);
|
||||||
}
|
}
|
||||||
(TNDArray { ty: ty1, ndims: ndims1 }, TNDArray { ty: ty2, ndims: ndims2 }) => {
|
|
||||||
if self.unify_impl(*ty1, *ty2, false).is_err() {
|
|
||||||
return Self::incompatible_types(a, b)
|
|
||||||
}
|
|
||||||
if self.unify_impl(*ndims1, *ndims2, false).is_err() {
|
|
||||||
return Self::incompatible_types(a, b)
|
|
||||||
}
|
|
||||||
self.set_a_to_b(a, b);
|
|
||||||
}
|
|
||||||
(TVar { fields: Some(map), range, .. }, TObj { fields, .. }) => {
|
(TVar { fields: Some(map), range, .. }, TObj { fields, .. }) => {
|
||||||
for (k, field) in map {
|
for (k, field) in map {
|
||||||
match *k {
|
match *k {
|
||||||
|
@ -1120,13 +1079,6 @@ impl Unifier {
|
||||||
TypeEnum::TList { ty } => {
|
TypeEnum::TList { ty } => {
|
||||||
format!("list[{}]", self.internal_stringify(*ty, obj_to_name, var_to_name, notes))
|
format!("list[{}]", self.internal_stringify(*ty, obj_to_name, var_to_name, notes))
|
||||||
}
|
}
|
||||||
TypeEnum::TNDArray { ty, ndims } => {
|
|
||||||
format!(
|
|
||||||
"ndarray[{}, {}]",
|
|
||||||
self.internal_stringify(*ty, obj_to_name, var_to_name, notes),
|
|
||||||
self.internal_stringify(*ndims, obj_to_name, var_to_name, notes),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
TypeEnum::TVirtual { ty } => {
|
TypeEnum::TVirtual { ty } => {
|
||||||
format!(
|
format!(
|
||||||
"virtual[{}]",
|
"virtual[{}]",
|
||||||
|
@ -1264,19 +1216,6 @@ impl Unifier {
|
||||||
TypeEnum::TList { ty } => {
|
TypeEnum::TList { ty } => {
|
||||||
self.subst_impl(*ty, mapping, cache).map(|t| self.add_ty(TypeEnum::TList { ty: t }))
|
self.subst_impl(*ty, mapping, cache).map(|t| self.add_ty(TypeEnum::TList { ty: t }))
|
||||||
}
|
}
|
||||||
TypeEnum::TNDArray { ty, ndims } => {
|
|
||||||
let new_ty = self.subst_impl(*ty, mapping, cache);
|
|
||||||
let new_ndims = self.subst_impl(*ndims, mapping, cache);
|
|
||||||
|
|
||||||
if new_ty.is_some() || new_ndims.is_some() {
|
|
||||||
Some(self.add_ty(TypeEnum::TNDArray {
|
|
||||||
ty: new_ty.unwrap_or(*ty),
|
|
||||||
ndims: new_ndims.unwrap_or(*ndims)
|
|
||||||
}))
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
}
|
|
||||||
TypeEnum::TVirtual { ty } => self
|
TypeEnum::TVirtual { ty } => self
|
||||||
.subst_impl(*ty, mapping, cache)
|
.subst_impl(*ty, mapping, cache)
|
||||||
.map(|t| self.add_ty(TypeEnum::TVirtual { ty: t })),
|
.map(|t| self.add_ty(TypeEnum::TVirtual { ty: t })),
|
||||||
|
@ -1447,19 +1386,6 @@ impl Unifier {
|
||||||
(TList { ty: ty1 }, TList { ty: ty2 }) => {
|
(TList { ty: ty1 }, TList { ty: ty2 }) => {
|
||||||
Ok(self.get_intersection(*ty1, *ty2)?.map(|ty| self.add_ty(TList { ty })))
|
Ok(self.get_intersection(*ty1, *ty2)?.map(|ty| self.add_ty(TList { ty })))
|
||||||
}
|
}
|
||||||
(TNDArray { ty: ty1, ndims: ndims1 }, TNDArray { ty: ty2, ndims: ndims2 }) => {
|
|
||||||
let ty = self.get_intersection(*ty1, *ty2)?;
|
|
||||||
let ndims = self.get_intersection(*ndims1, *ndims2)?;
|
|
||||||
|
|
||||||
Ok(if ty.is_some() || ndims.is_some() {
|
|
||||||
Some(self.add_ty(TNDArray {
|
|
||||||
ty: ty.unwrap_or(*ty1),
|
|
||||||
ndims: ndims.unwrap_or(*ndims1),
|
|
||||||
}))
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
})
|
|
||||||
}
|
|
||||||
(TVirtual { ty: ty1 }, TVirtual { ty: ty2 }) => {
|
(TVirtual { ty: ty1 }, TVirtual { ty: ty2 }) => {
|
||||||
Ok(self.get_intersection(*ty1, *ty2)?.map(|ty| self.add_ty(TVirtual { ty })))
|
Ok(self.get_intersection(*ty1, *ty2)?.map(|ty| self.add_ty(TVirtual { ty })))
|
||||||
}
|
}
|
||||||
|
|
|
@ -40,7 +40,7 @@ impl Unifier {
|
||||||
TypeEnum::TObj { obj_id: id1, params: params1, .. },
|
TypeEnum::TObj { obj_id: id1, params: params1, .. },
|
||||||
TypeEnum::TObj { obj_id: id2, params: params2, .. },
|
TypeEnum::TObj { obj_id: id2, params: params2, .. },
|
||||||
) => id1 == id2 && self.map_eq(params1, params2),
|
) => id1 == id2 && self.map_eq(params1, params2),
|
||||||
// TNDArray, TLiteral, TCall and TFunc are not yet implemented
|
// TLiteral, TCall and TFunc are not yet implemented
|
||||||
_ => false,
|
_ => false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue