forked from M-Labs/nac3
core: Initial infrastructure for ndarray
This commit is contained in:
parent
03870f222d
commit
c395472094
|
@ -400,6 +400,9 @@ fn gen_rpc_tag(
|
||||||
buffer.push(b'l');
|
buffer.push(b'l');
|
||||||
gen_rpc_tag(ctx, *ty, buffer)?;
|
gen_rpc_tag(ctx, *ty, buffer)?;
|
||||||
}
|
}
|
||||||
|
TNDArray { .. } => {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
_ => return Err(format!("Unsupported type: {:?}", ctx.unifier.stringify(ty))),
|
_ => return Err(format!("Unsupported type: {:?}", ctx.unifier.stringify(ty))),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -673,6 +676,14 @@ pub fn attributes_writeback(
|
||||||
values.push((ty, inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap()));
|
values.push((ty, inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap()));
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
TypeEnum::TNDArray { ty: elem_ty, .. } => {
|
||||||
|
if gen_rpc_tag(ctx, *elem_ty, &mut scratch_buffer).is_ok() {
|
||||||
|
let pydict = PyDict::new(py);
|
||||||
|
pydict.set_item("obj", val)?;
|
||||||
|
host_attributes.append(pydict)?;
|
||||||
|
values.push((ty, inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap()));
|
||||||
|
}
|
||||||
|
},
|
||||||
_ => {}
|
_ => {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -85,6 +85,7 @@ pub struct PrimitivePythonId {
|
||||||
float64: u64,
|
float64: u64,
|
||||||
bool: u64,
|
bool: u64,
|
||||||
list: u64,
|
list: u64,
|
||||||
|
ndarray: u64,
|
||||||
tuple: u64,
|
tuple: u64,
|
||||||
typevar: u64,
|
typevar: u64,
|
||||||
const_generic_marker: u64,
|
const_generic_marker: u64,
|
||||||
|
@ -879,6 +880,7 @@ impl Nac3 {
|
||||||
float: get_attr_id(builtins_mod, "float"),
|
float: get_attr_id(builtins_mod, "float"),
|
||||||
float64: get_attr_id(numpy_mod, "float64"),
|
float64: get_attr_id(numpy_mod, "float64"),
|
||||||
list: get_attr_id(builtins_mod, "list"),
|
list: get_attr_id(builtins_mod, "list"),
|
||||||
|
ndarray: get_attr_id(numpy_mod, "NDArray"),
|
||||||
tuple: get_attr_id(builtins_mod, "tuple"),
|
tuple: get_attr_id(builtins_mod, "tuple"),
|
||||||
exception: get_attr_id(builtins_mod, "Exception"),
|
exception: get_attr_id(builtins_mod, "Exception"),
|
||||||
option: get_id(artiq_builtins.get_item("Option").ok().flatten().unwrap()),
|
option: get_id(artiq_builtins.get_item("Option").ok().flatten().unwrap()),
|
||||||
|
|
|
@ -302,6 +302,12 @@ impl InnerResolver {
|
||||||
let var = unifier.get_dummy_var().0;
|
let var = unifier.get_dummy_var().0;
|
||||||
let list = unifier.add_ty(TypeEnum::TList { ty: var });
|
let list = unifier.add_ty(TypeEnum::TList { ty: var });
|
||||||
Ok(Ok((list, false)))
|
Ok(Ok((list, false)))
|
||||||
|
} else if ty_id == self.primitive_ids.ndarray {
|
||||||
|
// do not handle type var param and concrete check here
|
||||||
|
let var = unifier.get_dummy_var().0;
|
||||||
|
let ndims = unifier.get_fresh_const_generic_var(primitives.usize(), None, None).0;
|
||||||
|
let ndarray = unifier.add_ty(TypeEnum::TNDArray { ty: var, ndims });
|
||||||
|
Ok(Ok((ndarray, false)))
|
||||||
} else if ty_id == self.primitive_ids.tuple {
|
} else if ty_id == self.primitive_ids.tuple {
|
||||||
// do not handle type var param and concrete check here
|
// do not handle type var param and concrete check here
|
||||||
Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: vec![] }), false)))
|
Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: vec![] }), false)))
|
||||||
|
@ -446,6 +452,16 @@ impl InnerResolver {
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
TypeEnum::TNDArray { .. } => {
|
||||||
|
if args.len() != 2 {
|
||||||
|
return Ok(Err(format!(
|
||||||
|
"type list needs exactly 2 type parameters, found {}",
|
||||||
|
args.len()
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
TypeEnum::TTuple { .. } => {
|
TypeEnum::TTuple { .. } => {
|
||||||
let args = match args
|
let args = match args
|
||||||
.iter()
|
.iter()
|
||||||
|
@ -607,7 +623,7 @@ impl InnerResolver {
|
||||||
Err(e) => return Ok(Err(e)),
|
Err(e) => return Ok(Err(e)),
|
||||||
};
|
};
|
||||||
match (&*unifier.get_ty(extracted_ty), inst_check) {
|
match (&*unifier.get_ty(extracted_ty), inst_check) {
|
||||||
// do the instantiation for these three types
|
// do the instantiation for these four types
|
||||||
(TypeEnum::TList { ty }, false) => {
|
(TypeEnum::TList { ty }, false) => {
|
||||||
let len: usize = self.helper.len_fn.call1(py, (obj,))?.extract(py)?;
|
let len: usize = self.helper.len_fn.call1(py, (obj,))?.extract(py)?;
|
||||||
if len == 0 {
|
if len == 0 {
|
||||||
|
@ -632,6 +648,30 @@ impl InnerResolver {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
(TypeEnum::TNDArray { ty, ndims }, false) => {
|
||||||
|
let len: usize = self.helper.len_fn.call1(py, (obj,))?.extract(py)?;
|
||||||
|
if len == 0 {
|
||||||
|
assert!(matches!(
|
||||||
|
&*unifier.get_ty(*ty),
|
||||||
|
TypeEnum::TVar { fields: None, range, .. }
|
||||||
|
if range.is_empty()
|
||||||
|
));
|
||||||
|
Ok(Ok(extracted_ty))
|
||||||
|
} else {
|
||||||
|
let actual_ty =
|
||||||
|
self.get_list_elem_type(py, obj, len, unifier, defs, primitives)?;
|
||||||
|
match actual_ty {
|
||||||
|
Ok(t) => match unifier.unify(*ty, t) {
|
||||||
|
Ok(_) => Ok(Ok(unifier.add_ty(TypeEnum::TNDArray { ty: *ty, ndims: *ndims }))),
|
||||||
|
Err(e) => Ok(Err(format!(
|
||||||
|
"type error ({}) for the ndarray",
|
||||||
|
e.to_display(unifier).to_string()
|
||||||
|
))),
|
||||||
|
},
|
||||||
|
Err(e) => Ok(Err(e)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
(TypeEnum::TTuple { .. }, false) => {
|
(TypeEnum::TTuple { .. }, false) => {
|
||||||
let elements: &PyTuple = obj.downcast()?;
|
let elements: &PyTuple = obj.downcast()?;
|
||||||
let types: Result<Result<Vec<_>, _>, _> = elements
|
let types: Result<Result<Vec<_>, _>, _> = elements
|
||||||
|
@ -898,6 +938,8 @@ impl InnerResolver {
|
||||||
global.set_initializer(&val);
|
global.set_initializer(&val);
|
||||||
|
|
||||||
Ok(Some(global.as_pointer_value().into()))
|
Ok(Some(global.as_pointer_value().into()))
|
||||||
|
} else if ty_id == self.primitive_ids.ndarray {
|
||||||
|
todo!()
|
||||||
} else if ty_id == self.primitive_ids.tuple {
|
} else if ty_id == self.primitive_ids.tuple {
|
||||||
let expected_ty_enum = ctx.unifier.get_ty_immutable(expected_ty);
|
let expected_ty_enum = ctx.unifier.get_ty_immutable(expected_ty);
|
||||||
let TypeEnum::TTuple { ty } = expected_ty_enum.as_ref() else {
|
let TypeEnum::TTuple { ty } = expected_ty_enum.as_ref() else {
|
||||||
|
|
|
@ -47,6 +47,10 @@ pub enum ConcreteTypeEnum {
|
||||||
TList {
|
TList {
|
||||||
ty: ConcreteType,
|
ty: ConcreteType,
|
||||||
},
|
},
|
||||||
|
TNDArray {
|
||||||
|
ty: ConcreteType,
|
||||||
|
ndims: ConcreteType,
|
||||||
|
},
|
||||||
TObj {
|
TObj {
|
||||||
obj_id: DefinitionId,
|
obj_id: DefinitionId,
|
||||||
fields: HashMap<StrRef, (ConcreteType, bool)>,
|
fields: HashMap<StrRef, (ConcreteType, bool)>,
|
||||||
|
@ -167,6 +171,10 @@ impl ConcreteTypeStore {
|
||||||
TypeEnum::TList { ty } => ConcreteTypeEnum::TList {
|
TypeEnum::TList { ty } => ConcreteTypeEnum::TList {
|
||||||
ty: self.from_unifier_type(unifier, primitives, *ty, cache),
|
ty: self.from_unifier_type(unifier, primitives, *ty, cache),
|
||||||
},
|
},
|
||||||
|
TypeEnum::TNDArray { ty, ndims } => ConcreteTypeEnum::TNDArray {
|
||||||
|
ty: self.from_unifier_type(unifier, primitives, *ty, cache),
|
||||||
|
ndims: self.from_unifier_type(unifier, primitives, *ndims, cache),
|
||||||
|
},
|
||||||
TypeEnum::TObj { obj_id, fields, params } => ConcreteTypeEnum::TObj {
|
TypeEnum::TObj { obj_id, fields, params } => ConcreteTypeEnum::TObj {
|
||||||
obj_id: *obj_id,
|
obj_id: *obj_id,
|
||||||
fields: fields
|
fields: fields
|
||||||
|
@ -260,6 +268,12 @@ impl ConcreteTypeStore {
|
||||||
ConcreteTypeEnum::TList { ty } => {
|
ConcreteTypeEnum::TList { ty } => {
|
||||||
TypeEnum::TList { ty: self.to_unifier_type(unifier, primitives, *ty, cache) }
|
TypeEnum::TList { ty: self.to_unifier_type(unifier, primitives, *ty, cache) }
|
||||||
}
|
}
|
||||||
|
ConcreteTypeEnum::TNDArray { ty, ndims } => {
|
||||||
|
TypeEnum::TNDArray {
|
||||||
|
ty: self.to_unifier_type(unifier, primitives, *ty, cache),
|
||||||
|
ndims: self.to_unifier_type(unifier, primitives, *ndims, cache),
|
||||||
|
}
|
||||||
|
}
|
||||||
ConcreteTypeEnum::TVirtual { ty } => {
|
ConcreteTypeEnum::TVirtual { ty } => {
|
||||||
TypeEnum::TVirtual { ty: self.to_unifier_type(unifier, primitives, *ty, cache) }
|
TypeEnum::TVirtual { ty: self.to_unifier_type(unifier, primitives, *ty, cache) }
|
||||||
}
|
}
|
||||||
|
|
|
@ -1846,6 +1846,9 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
||||||
ctx.build_gep_and_load(arr_ptr, &[index], None).into()
|
ctx.build_gep_and_load(arr_ptr, &[index], None).into()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
TypeEnum::TNDArray { .. } => {
|
||||||
|
return Err(String::from("subscript operator for ndarray not implemented"))
|
||||||
|
}
|
||||||
TypeEnum::TTuple { .. } => {
|
TypeEnum::TTuple { .. } => {
|
||||||
let index: u32 =
|
let index: u32 =
|
||||||
if let ExprKind::Constant { value: Constant::Int(v), .. } = &slice.node {
|
if let ExprKind::Constant { value: Constant::Int(v), .. } = &slice.node {
|
||||||
|
|
|
@ -507,6 +507,24 @@ fn get_llvm_type<'ctx>(
|
||||||
];
|
];
|
||||||
ctx.struct_type(&fields, false).ptr_type(AddressSpace::default()).into()
|
ctx.struct_type(&fields, false).ptr_type(AddressSpace::default()).into()
|
||||||
}
|
}
|
||||||
|
TNDArray { ty, .. } => {
|
||||||
|
let llvm_usize = generator.get_size_type(ctx);
|
||||||
|
let element_type = get_llvm_type(
|
||||||
|
ctx, module, generator, unifier, top_level, type_cache, primitives, *ty,
|
||||||
|
);
|
||||||
|
|
||||||
|
// struct NDArray { num_dims: size_t, dims: size_t*, data: T* }
|
||||||
|
//
|
||||||
|
// * num_dims: Number of dimensions in the array
|
||||||
|
// * dims: Pointer to an array containing the size of each dimension
|
||||||
|
// * data: Pointer to an array containing the array data
|
||||||
|
let fields = [
|
||||||
|
llvm_usize.into(),
|
||||||
|
llvm_usize.ptr_type(AddressSpace::default()).into(),
|
||||||
|
element_type.ptr_type(AddressSpace::default()).into(),
|
||||||
|
];
|
||||||
|
ctx.struct_type(&fields, false).ptr_type(AddressSpace::default()).into()
|
||||||
|
}
|
||||||
TVirtual { .. } => unimplemented!(),
|
TVirtual { .. } => unimplemented!(),
|
||||||
_ => unreachable!("{}", ty_enum.get_type_name()),
|
_ => unreachable!("{}", ty_enum.get_type_name()),
|
||||||
};
|
};
|
||||||
|
|
|
@ -99,25 +99,23 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ExprKind::Subscript { value, slice, .. } => {
|
ExprKind::Subscript { value, slice, .. } => {
|
||||||
assert!(matches!(
|
match ctx.unifier.get_ty_immutable(value.custom.unwrap()).as_ref() {
|
||||||
ctx.unifier.get_ty_immutable(value.custom.unwrap()).as_ref(),
|
TypeEnum::TList { .. } => {
|
||||||
TypeEnum::TList { .. },
|
|
||||||
));
|
|
||||||
let i32_type = ctx.ctx.i32_type();
|
let i32_type = ctx.ctx.i32_type();
|
||||||
let zero = i32_type.const_zero();
|
let zero = i32_type.const_zero();
|
||||||
let v = if let Some(v) = generator.gen_expr(ctx, value)? {
|
let v = generator
|
||||||
v.to_basic_value_enum(ctx, generator, value.custom.unwrap())?.into_pointer_value()
|
.gen_expr(ctx, value)?
|
||||||
} else {
|
.unwrap()
|
||||||
return Ok(None)
|
.to_basic_value_enum(ctx, generator, value.custom.unwrap())?
|
||||||
};
|
.into_pointer_value();
|
||||||
let len = ctx
|
let len = ctx
|
||||||
.build_gep_and_load(v, &[zero, i32_type.const_int(1, false)], Some("len"))
|
.build_gep_and_load(v, &[zero, i32_type.const_int(1, false)], Some("len"))
|
||||||
.into_int_value();
|
.into_int_value();
|
||||||
let raw_index = if let Some(v) = generator.gen_expr(ctx, slice)? {
|
let raw_index = generator
|
||||||
v.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value()
|
.gen_expr(ctx, slice)?
|
||||||
} else {
|
.unwrap()
|
||||||
return Ok(None)
|
.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?
|
||||||
};
|
.into_int_value();
|
||||||
let raw_index = ctx.builder.build_int_s_extend(
|
let raw_index = ctx.builder.build_int_s_extend(
|
||||||
raw_index,
|
raw_index,
|
||||||
generator.get_size_type(ctx.ctx),
|
generator.get_size_type(ctx.ctx),
|
||||||
|
@ -158,6 +156,14 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>(
|
||||||
ctx.builder.build_gep(arr_ptr, &[index], name.unwrap_or(""))
|
ctx.builder.build_gep(arr_ptr, &[index], name.unwrap_or(""))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TypeEnum::TNDArray { .. } => {
|
||||||
|
todo!()
|
||||||
|
}
|
||||||
|
|
||||||
|
_ => unreachable!(),
|
||||||
|
}
|
||||||
|
}
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
@ -203,7 +209,7 @@ pub fn gen_assign<'ctx, G: CodeGenerator>(
|
||||||
let value = value
|
let value = value
|
||||||
.to_basic_value_enum(ctx, generator, target.custom.unwrap())?
|
.to_basic_value_enum(ctx, generator, target.custom.unwrap())?
|
||||||
.into_pointer_value();
|
.into_pointer_value();
|
||||||
let TypeEnum::TList { ty } = &*ctx.unifier.get_ty(target.custom.unwrap()) else {
|
let (TypeEnum::TList { ty } | TypeEnum::TNDArray { ty, .. }) = &*ctx.unifier.get_ty(target.custom.unwrap()) else {
|
||||||
unreachable!()
|
unreachable!()
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -354,13 +354,14 @@ pub trait SymbolResolver {
|
||||||
}
|
}
|
||||||
|
|
||||||
thread_local! {
|
thread_local! {
|
||||||
static IDENTIFIER_ID: [StrRef; 11] = [
|
static IDENTIFIER_ID: [StrRef; 12] = [
|
||||||
"int32".into(),
|
"int32".into(),
|
||||||
"int64".into(),
|
"int64".into(),
|
||||||
"float".into(),
|
"float".into(),
|
||||||
"bool".into(),
|
"bool".into(),
|
||||||
"virtual".into(),
|
"virtual".into(),
|
||||||
"list".into(),
|
"list".into(),
|
||||||
|
"ndarray".into(),
|
||||||
"tuple".into(),
|
"tuple".into(),
|
||||||
"str".into(),
|
"str".into(),
|
||||||
"Exception".into(),
|
"Exception".into(),
|
||||||
|
@ -385,11 +386,12 @@ pub fn parse_type_annotation<T>(
|
||||||
let bool_id = ids[3];
|
let bool_id = ids[3];
|
||||||
let virtual_id = ids[4];
|
let virtual_id = ids[4];
|
||||||
let list_id = ids[5];
|
let list_id = ids[5];
|
||||||
let tuple_id = ids[6];
|
let ndarray_id = ids[6];
|
||||||
let str_id = ids[7];
|
let tuple_id = ids[7];
|
||||||
let exn_id = ids[8];
|
let str_id = ids[8];
|
||||||
let uint32_id = ids[9];
|
let exn_id = ids[9];
|
||||||
let uint64_id = ids[10];
|
let uint32_id = ids[10];
|
||||||
|
let uint64_id = ids[11];
|
||||||
|
|
||||||
let name_handling = |id: &StrRef, loc: Location, unifier: &mut Unifier| {
|
let name_handling = |id: &StrRef, loc: Location, unifier: &mut Unifier| {
|
||||||
if *id == int32_id {
|
if *id == int32_id {
|
||||||
|
@ -460,6 +462,21 @@ pub fn parse_type_annotation<T>(
|
||||||
} else if *id == list_id {
|
} else if *id == list_id {
|
||||||
let ty = parse_type_annotation(resolver, top_level_defs, unifier, primitives, slice)?;
|
let ty = parse_type_annotation(resolver, top_level_defs, unifier, primitives, slice)?;
|
||||||
Ok(unifier.add_ty(TypeEnum::TList { ty }))
|
Ok(unifier.add_ty(TypeEnum::TList { ty }))
|
||||||
|
} else if *id == ndarray_id {
|
||||||
|
let Tuple { elts, .. } = &slice.node else {
|
||||||
|
return Err(HashSet::from([
|
||||||
|
String::from("Expected 2 type arguments for ndarray"),
|
||||||
|
]))
|
||||||
|
};
|
||||||
|
if elts.len() < 2 {
|
||||||
|
return Err(HashSet::from([
|
||||||
|
String::from("Expected 2 type arguments for ndarray"),
|
||||||
|
]))
|
||||||
|
}
|
||||||
|
|
||||||
|
let ty = parse_type_annotation(resolver, top_level_defs, unifier, primitives, &elts[0])?;
|
||||||
|
let ndims = parse_type_annotation(resolver, top_level_defs, unifier, primitives, &elts[1])?;
|
||||||
|
Ok(unifier.add_ty(TypeEnum::TNDArray { ty, ndims }))
|
||||||
} else if *id == tuple_id {
|
} else if *id == tuple_id {
|
||||||
if let Tuple { elts, .. } = &slice.node {
|
if let Tuple { elts, .. } = &slice.node {
|
||||||
let ty = elts
|
let ty = elts
|
||||||
|
|
|
@ -470,6 +470,22 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
||||||
)))),
|
)))),
|
||||||
loc: None,
|
loc: None,
|
||||||
})),
|
})),
|
||||||
|
{
|
||||||
|
let tvar = primitives.1.get_fresh_var(Some("T".into()), None);
|
||||||
|
let ndims = primitives.1.get_fresh_const_generic_var(primitives.0.uint64, Some("N".into()), None);
|
||||||
|
|
||||||
|
Arc::new(RwLock::new(TopLevelDef::Class {
|
||||||
|
name: "ndarray".into(),
|
||||||
|
object_id: DefinitionId(14),
|
||||||
|
type_vars: vec![tvar.0, ndims.0],
|
||||||
|
fields: Vec::default(),
|
||||||
|
methods: Vec::default(),
|
||||||
|
ancestors: Vec::default(),
|
||||||
|
constructor: None,
|
||||||
|
resolver: None,
|
||||||
|
loc: None,
|
||||||
|
}))
|
||||||
|
},
|
||||||
Arc::new(RwLock::new(TopLevelDef::Function {
|
Arc::new(RwLock::new(TopLevelDef::Function {
|
||||||
name: "int32".into(),
|
name: "int32".into(),
|
||||||
simple_name: "int32".into(),
|
simple_name: "int32".into(),
|
||||||
|
@ -1265,10 +1281,13 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
||||||
}),
|
}),
|
||||||
),
|
),
|
||||||
Arc::new(RwLock::new({
|
Arc::new(RwLock::new({
|
||||||
let list_var = primitives.1.get_fresh_var(Some("L".into()), None);
|
let tvar = primitives.1.get_fresh_var(Some("L".into()), None);
|
||||||
let list = primitives.1.add_ty(TypeEnum::TList { ty: list_var.0 });
|
let list = primitives.1.add_ty(TypeEnum::TList { ty: tvar.0 });
|
||||||
|
let ndims = primitives.1.get_fresh_const_generic_var(primitives.0.uint64, Some("N".into()), None);
|
||||||
|
let ndarray = primitives.1.add_ty(TypeEnum::TNDArray { ty: tvar.0, ndims: ndims.0 });
|
||||||
|
|
||||||
let arg_ty = primitives.1.get_fresh_var_with_range(
|
let arg_ty = primitives.1.get_fresh_var_with_range(
|
||||||
&[list, primitives.0.range],
|
&[list, ndarray, primitives.0.range],
|
||||||
Some("I".into()),
|
Some("I".into()),
|
||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
|
@ -1278,7 +1297,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
||||||
signature: primitives.1.add_ty(TypeEnum::TFunc(FunSignature {
|
signature: primitives.1.add_ty(TypeEnum::TFunc(FunSignature {
|
||||||
args: vec![FuncArg { name: "ls".into(), ty: arg_ty.0, default_value: None }],
|
args: vec![FuncArg { name: "ls".into(), ty: arg_ty.0, default_value: None }],
|
||||||
ret: int32,
|
ret: int32,
|
||||||
vars: vec![(list_var.1, list_var.0), (arg_ty.1, arg_ty.0)]
|
vars: vec![(tvar.1, tvar.0), (arg_ty.1, arg_ty.0)]
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.collect(),
|
.collect(),
|
||||||
})),
|
})),
|
||||||
|
@ -1296,6 +1315,8 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
||||||
let (start, end, step) = destructure_range(ctx, arg);
|
let (start, end, step) = destructure_range(ctx, arg);
|
||||||
Some(calculate_len_for_slice_range(generator, ctx, start, end, step).into())
|
Some(calculate_len_for_slice_range(generator, ctx, start, end, step).into())
|
||||||
} else {
|
} else {
|
||||||
|
match &*ctx.unifier.get_ty_immutable(arg_ty) {
|
||||||
|
TypeEnum::TList { .. } => {
|
||||||
let int32 = ctx.ctx.i32_type();
|
let int32 = ctx.ctx.i32_type();
|
||||||
let zero = int32.const_zero();
|
let zero = int32.const_zero();
|
||||||
let len = ctx
|
let len = ctx
|
||||||
|
@ -1310,6 +1331,10 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
||||||
} else {
|
} else {
|
||||||
Some(ctx.builder.build_int_truncate(len, int32, "len2i32").into())
|
Some(ctx.builder.build_int_truncate(len, int32, "len2i32").into())
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
TypeEnum::TNDArray { .. } => todo!(),
|
||||||
|
_ => unreachable!(),
|
||||||
|
}
|
||||||
})
|
})
|
||||||
},
|
},
|
||||||
)))),
|
)))),
|
||||||
|
|
|
@ -491,11 +491,24 @@ pub fn get_type_from_type_annotation_kinds(
|
||||||
(*name, (subst_ty, *mutability))
|
(*name, (subst_ty, *mutability))
|
||||||
}));
|
}));
|
||||||
let need_subst = !subst.is_empty();
|
let need_subst = !subst.is_empty();
|
||||||
let ty = unifier.add_ty(TypeEnum::TObj {
|
let ty = if obj_id == &DefinitionId(14) {
|
||||||
|
assert_eq!(subst.len(), 2);
|
||||||
|
let tv_tys = subst.iter()
|
||||||
|
.sorted_by_key(|(k, _)| *k)
|
||||||
|
.map(|(_, v)| v)
|
||||||
|
.collect_vec();
|
||||||
|
|
||||||
|
unifier.add_ty(TypeEnum::TNDArray {
|
||||||
|
ty: *tv_tys[0],
|
||||||
|
ndims: *tv_tys[1],
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
unifier.add_ty(TypeEnum::TObj {
|
||||||
obj_id: *obj_id,
|
obj_id: *obj_id,
|
||||||
fields: tobj_fields,
|
fields: tobj_fields,
|
||||||
params: subst,
|
params: subst,
|
||||||
});
|
})
|
||||||
|
};
|
||||||
if need_subst {
|
if need_subst {
|
||||||
if let Some(wl) = subst_list.as_mut() {
|
if let Some(wl) = subst_list.as_mut() {
|
||||||
wl.push(ty);
|
wl.push(ty);
|
||||||
|
|
|
@ -223,8 +223,12 @@ impl<'a> Fold<()> for Inferencer<'a> {
|
||||||
if self.unifier.unioned(iter.custom.unwrap(), self.primitives.range) {
|
if self.unifier.unioned(iter.custom.unwrap(), self.primitives.range) {
|
||||||
self.unify(self.primitives.int32, target.custom.unwrap(), &target.location)?;
|
self.unify(self.primitives.int32, target.custom.unwrap(), &target.location)?;
|
||||||
} else {
|
} else {
|
||||||
let list = self.unifier.add_ty(TypeEnum::TList { ty: target.custom.unwrap() });
|
let list_like_ty = match &*self.unifier.get_ty(iter.custom.unwrap()) {
|
||||||
self.unify(list, iter.custom.unwrap(), &iter.location)?;
|
TypeEnum::TList { .. } => self.unifier.add_ty(TypeEnum::TList { ty: target.custom.unwrap() }),
|
||||||
|
TypeEnum::TNDArray { .. } => todo!(),
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
self.unify(list_like_ty, iter.custom.unwrap(), &iter.location)?;
|
||||||
}
|
}
|
||||||
let body =
|
let body =
|
||||||
body.into_iter().map(|b| self.fold_stmt(b)).collect::<Result<Vec<_>, _>>()?;
|
body.into_iter().map(|b| self.fold_stmt(b)).collect::<Result<Vec<_>, _>>()?;
|
||||||
|
@ -1137,9 +1141,13 @@ impl<'a> Inferencer<'a> {
|
||||||
for v in [lower.as_ref(), upper.as_ref(), step.as_ref()].iter().flatten() {
|
for v in [lower.as_ref(), upper.as_ref(), step.as_ref()].iter().flatten() {
|
||||||
self.constrain(v.custom.unwrap(), self.primitives.int32, &v.location)?;
|
self.constrain(v.custom.unwrap(), self.primitives.int32, &v.location)?;
|
||||||
}
|
}
|
||||||
let list = self.unifier.add_ty(TypeEnum::TList { ty });
|
let list_like_ty = match &*self.unifier.get_ty(value.custom.unwrap()) {
|
||||||
self.constrain(value.custom.unwrap(), list, &value.location)?;
|
TypeEnum::TList { .. } => self.unifier.add_ty(TypeEnum::TList { ty }),
|
||||||
Ok(list)
|
TypeEnum::TNDArray { ndims, .. } => self.unifier.add_ty(TypeEnum::TNDArray { ty, ndims: *ndims }),
|
||||||
|
_ => unreachable!()
|
||||||
|
};
|
||||||
|
self.constrain(value.custom.unwrap(), list_like_ty, &value.location)?;
|
||||||
|
Ok(list_like_ty)
|
||||||
}
|
}
|
||||||
ExprKind::Constant { value: ast::Constant::Int(val), .. } => {
|
ExprKind::Constant { value: ast::Constant::Int(val), .. } => {
|
||||||
// the index is a constant, so value can be a sequence.
|
// the index is a constant, so value can be a sequence.
|
||||||
|
@ -1159,10 +1167,15 @@ impl<'a> Inferencer<'a> {
|
||||||
{
|
{
|
||||||
return report_error("Tuple index must be a constant (KernelInvariant is also not supported)", slice.location)
|
return report_error("Tuple index must be a constant (KernelInvariant is also not supported)", slice.location)
|
||||||
}
|
}
|
||||||
|
|
||||||
// the index is not a constant, so value can only be a list
|
// the index is not a constant, so value can only be a list
|
||||||
self.constrain(slice.custom.unwrap(), self.primitives.int32, &slice.location)?;
|
self.constrain(slice.custom.unwrap(), self.primitives.int32, &slice.location)?;
|
||||||
let list = self.unifier.add_ty(TypeEnum::TList { ty });
|
let list_like_ty = match &*self.unifier.get_ty(value.custom.unwrap()) {
|
||||||
self.constrain(value.custom.unwrap(), list, &value.location)?;
|
TypeEnum::TList { .. } => self.unifier.add_ty(TypeEnum::TList { ty }),
|
||||||
|
TypeEnum::TNDArray { .. } => todo!(),
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
self.constrain(value.custom.unwrap(), list_like_ty, &value.location)?;
|
||||||
Ok(ty)
|
Ok(ty)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -159,6 +159,11 @@ pub enum TypeEnum {
|
||||||
ty: Type,
|
ty: Type,
|
||||||
},
|
},
|
||||||
|
|
||||||
|
TNDArray {
|
||||||
|
ty: Type,
|
||||||
|
ndims: Type,
|
||||||
|
},
|
||||||
|
|
||||||
/// An object type.
|
/// An object type.
|
||||||
TObj {
|
TObj {
|
||||||
/// The [DefintionId] of this object type.
|
/// The [DefintionId] of this object type.
|
||||||
|
@ -193,6 +198,7 @@ impl TypeEnum {
|
||||||
TypeEnum::TLiteral { .. } => "TConstant",
|
TypeEnum::TLiteral { .. } => "TConstant",
|
||||||
TypeEnum::TTuple { .. } => "TTuple",
|
TypeEnum::TTuple { .. } => "TTuple",
|
||||||
TypeEnum::TList { .. } => "TList",
|
TypeEnum::TList { .. } => "TList",
|
||||||
|
TypeEnum::TNDArray { .. } => "TNDArray",
|
||||||
TypeEnum::TObj { .. } => "TObj",
|
TypeEnum::TObj { .. } => "TObj",
|
||||||
TypeEnum::TVirtual { .. } => "TVirtual",
|
TypeEnum::TVirtual { .. } => "TVirtual",
|
||||||
TypeEnum::TCall { .. } => "TCall",
|
TypeEnum::TCall { .. } => "TCall",
|
||||||
|
@ -418,6 +424,9 @@ impl Unifier {
|
||||||
TypeEnum::TList { ty } => self
|
TypeEnum::TList { ty } => self
|
||||||
.get_instantiations(*ty)
|
.get_instantiations(*ty)
|
||||||
.map(|ty| ty.iter().map(|&ty| self.add_ty(TypeEnum::TList { ty })).collect_vec()),
|
.map(|ty| ty.iter().map(|&ty| self.add_ty(TypeEnum::TList { ty })).collect_vec()),
|
||||||
|
TypeEnum::TNDArray { ty, ndims } => self
|
||||||
|
.get_instantiations(*ty)
|
||||||
|
.map(|ty| ty.iter().map(|&ty| self.add_ty(TypeEnum::TNDArray { ty, ndims: *ndims })).collect_vec()),
|
||||||
TypeEnum::TVirtual { ty } => self.get_instantiations(*ty).map(|ty| {
|
TypeEnum::TVirtual { ty } => self.get_instantiations(*ty).map(|ty| {
|
||||||
ty.iter().map(|&ty| self.add_ty(TypeEnum::TVirtual { ty })).collect_vec()
|
ty.iter().map(|&ty| self.add_ty(TypeEnum::TVirtual { ty })).collect_vec()
|
||||||
}),
|
}),
|
||||||
|
@ -470,6 +479,7 @@ impl Unifier {
|
||||||
TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)),
|
TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)),
|
||||||
TCall { .. } => false,
|
TCall { .. } => false,
|
||||||
TList { ty } | TVirtual { ty } => self.is_concrete(*ty, allowed_typevars),
|
TList { ty } | TVirtual { ty } => self.is_concrete(*ty, allowed_typevars),
|
||||||
|
TNDArray { ty, .. } => self.is_concrete(*ty, allowed_typevars),
|
||||||
TTuple { ty } => ty.iter().all(|ty| self.is_concrete(*ty, allowed_typevars)),
|
TTuple { ty } => ty.iter().all(|ty| self.is_concrete(*ty, allowed_typevars)),
|
||||||
TObj { params: vars, .. } => {
|
TObj { params: vars, .. } => {
|
||||||
vars.values().all(|ty| self.is_concrete(*ty, allowed_typevars))
|
vars.values().all(|ty| self.is_concrete(*ty, allowed_typevars))
|
||||||
|
@ -717,7 +727,8 @@ impl Unifier {
|
||||||
self.unify_impl(x, b, false)?;
|
self.unify_impl(x, b, false)?;
|
||||||
self.set_a_to_b(a, x);
|
self.set_a_to_b(a, x);
|
||||||
}
|
}
|
||||||
(TVar { fields: Some(fields), range, is_const_generic: false, .. }, TList { ty }) => {
|
(TVar { fields: Some(fields), range, is_const_generic: false, .. }, TList { ty }) |
|
||||||
|
(TVar { fields: Some(fields), range, is_const_generic: false, .. }, TNDArray { ty, .. }) => {
|
||||||
for (k, v) in fields {
|
for (k, v) in fields {
|
||||||
match *k {
|
match *k {
|
||||||
RecordKey::Int(_) => {
|
RecordKey::Int(_) => {
|
||||||
|
@ -829,6 +840,15 @@ impl Unifier {
|
||||||
}
|
}
|
||||||
self.set_a_to_b(a, b);
|
self.set_a_to_b(a, b);
|
||||||
}
|
}
|
||||||
|
(TNDArray { ty: ty1, ndims: ndims1 }, TNDArray { ty: ty2, ndims: ndims2 }) => {
|
||||||
|
if self.unify_impl(*ty1, *ty2, false).is_err() {
|
||||||
|
return self.incompatible_types(a, b)
|
||||||
|
}
|
||||||
|
if self.unify_impl(*ndims1, *ndims2, false).is_err() {
|
||||||
|
return self.incompatible_types(a, b)
|
||||||
|
}
|
||||||
|
self.set_a_to_b(a, b);
|
||||||
|
}
|
||||||
(TVar { fields: Some(map), range, .. }, TObj { fields, .. }) => {
|
(TVar { fields: Some(map), range, .. }, TObj { fields, .. }) => {
|
||||||
for (k, field) in map {
|
for (k, field) in map {
|
||||||
match *k {
|
match *k {
|
||||||
|
@ -1076,6 +1096,13 @@ impl Unifier {
|
||||||
TypeEnum::TList { ty } => {
|
TypeEnum::TList { ty } => {
|
||||||
format!("list[{}]", self.internal_stringify(*ty, obj_to_name, var_to_name, notes))
|
format!("list[{}]", self.internal_stringify(*ty, obj_to_name, var_to_name, notes))
|
||||||
}
|
}
|
||||||
|
TypeEnum::TNDArray { ty, ndims } => {
|
||||||
|
format!(
|
||||||
|
"ndarray[{}, {}]",
|
||||||
|
self.internal_stringify(*ty, obj_to_name, var_to_name, notes),
|
||||||
|
self.internal_stringify(*ndims, obj_to_name, var_to_name, notes),
|
||||||
|
)
|
||||||
|
}
|
||||||
TypeEnum::TVirtual { ty } => {
|
TypeEnum::TVirtual { ty } => {
|
||||||
format!(
|
format!(
|
||||||
"virtual[{}]",
|
"virtual[{}]",
|
||||||
|
@ -1195,7 +1222,7 @@ impl Unifier {
|
||||||
// variables, i.e. things like TRecord, TCall should not occur, and we
|
// variables, i.e. things like TRecord, TCall should not occur, and we
|
||||||
// should be safe to not implement the substitution for those variants.
|
// should be safe to not implement the substitution for those variants.
|
||||||
match &*ty {
|
match &*ty {
|
||||||
TypeEnum::TRigidVar { .. } => None,
|
TypeEnum::TRigidVar { .. } | TypeEnum::TLiteral { .. } => None,
|
||||||
TypeEnum::TVar { id, .. } => mapping.get(id).copied(),
|
TypeEnum::TVar { id, .. } => mapping.get(id).copied(),
|
||||||
TypeEnum::TTuple { ty } => {
|
TypeEnum::TTuple { ty } => {
|
||||||
let mut new_ty = Cow::from(ty);
|
let mut new_ty = Cow::from(ty);
|
||||||
|
@ -1213,6 +1240,19 @@ impl Unifier {
|
||||||
TypeEnum::TList { ty } => {
|
TypeEnum::TList { ty } => {
|
||||||
self.subst_impl(*ty, mapping, cache).map(|t| self.add_ty(TypeEnum::TList { ty: t }))
|
self.subst_impl(*ty, mapping, cache).map(|t| self.add_ty(TypeEnum::TList { ty: t }))
|
||||||
}
|
}
|
||||||
|
TypeEnum::TNDArray { ty, ndims } => {
|
||||||
|
let new_ty = self.subst_impl(*ty, mapping, cache);
|
||||||
|
let new_ndims = self.subst_impl(*ndims, mapping, cache);
|
||||||
|
|
||||||
|
if new_ty.is_some() || new_ndims.is_some() {
|
||||||
|
Some(self.add_ty(TypeEnum::TNDArray {
|
||||||
|
ty: new_ty.unwrap_or(*ty),
|
||||||
|
ndims: new_ndims.unwrap_or(*ndims)
|
||||||
|
}))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
TypeEnum::TVirtual { ty } => self
|
TypeEnum::TVirtual { ty } => self
|
||||||
.subst_impl(*ty, mapping, cache)
|
.subst_impl(*ty, mapping, cache)
|
||||||
.map(|t| self.add_ty(TypeEnum::TVirtual { ty: t })),
|
.map(|t| self.add_ty(TypeEnum::TVirtual { ty: t })),
|
||||||
|
@ -1383,6 +1423,19 @@ impl Unifier {
|
||||||
(TList { ty: ty1 }, TList { ty: ty2 }) => {
|
(TList { ty: ty1 }, TList { ty: ty2 }) => {
|
||||||
Ok(self.get_intersection(*ty1, *ty2)?.map(|ty| self.add_ty(TList { ty })))
|
Ok(self.get_intersection(*ty1, *ty2)?.map(|ty| self.add_ty(TList { ty })))
|
||||||
}
|
}
|
||||||
|
(TNDArray { ty: ty1, ndims: ndims1 }, TNDArray { ty: ty2, ndims: ndims2 }) => {
|
||||||
|
let ty = self.get_intersection(*ty1, *ty2)?;
|
||||||
|
let ndims = self.get_intersection(*ndims1, *ndims2)?;
|
||||||
|
|
||||||
|
Ok(if ty.is_some() || ndims.is_some() {
|
||||||
|
Some(self.add_ty(TNDArray {
|
||||||
|
ty: ty.unwrap_or(*ty1),
|
||||||
|
ndims: ndims.unwrap_or(*ndims1),
|
||||||
|
}))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
})
|
||||||
|
}
|
||||||
(TVirtual { ty: ty1 }, TVirtual { ty: ty2 }) => {
|
(TVirtual { ty: ty1 }, TVirtual { ty: ty2 }) => {
|
||||||
Ok(self.get_intersection(*ty1, *ty2)?.map(|ty| self.add_ty(TVirtual { ty })))
|
Ok(self.get_intersection(*ty1, *ty2)?.map(|ty| self.add_ty(TVirtual { ty })))
|
||||||
}
|
}
|
||||||
|
|
|
@ -33,6 +33,7 @@ impl Unifier {
|
||||||
&& ty1.iter().zip(ty2.iter()).all(|(t1, t2)| self.eq(*t1, *t2))
|
&& ty1.iter().zip(ty2.iter()).all(|(t1, t2)| self.eq(*t1, *t2))
|
||||||
}
|
}
|
||||||
(TypeEnum::TList { ty: ty1 }, TypeEnum::TList { ty: ty2 })
|
(TypeEnum::TList { ty: ty1 }, TypeEnum::TList { ty: ty2 })
|
||||||
|
| (TypeEnum::TNDArray { ty: ty1 }, TypeEnum::TNDArray { ty: ty2 })
|
||||||
| (TypeEnum::TVirtual { ty: ty1 }, TypeEnum::TVirtual { ty: ty2 }) => {
|
| (TypeEnum::TVirtual { ty: ty1 }, TypeEnum::TVirtual { ty: ty2 }) => {
|
||||||
self.eq(*ty1, *ty2)
|
self.eq(*ty1, *ty2)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue