diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index 142fe51..98f51bd 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -400,6 +400,9 @@ fn gen_rpc_tag( buffer.push(b'l'); gen_rpc_tag(ctx, *ty, buffer)?; } + TNDArray { .. } => { + todo!() + } _ => return Err(format!("Unsupported type: {:?}", ctx.unifier.stringify(ty))), } } @@ -673,6 +676,14 @@ pub fn attributes_writeback( values.push((ty, inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap())); } }, + TypeEnum::TNDArray { ty: elem_ty, .. } => { + if gen_rpc_tag(ctx, *elem_ty, &mut scratch_buffer).is_ok() { + let pydict = PyDict::new(py); + pydict.set_item("obj", val)?; + host_attributes.append(pydict)?; + values.push((ty, inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap())); + } + }, _ => {} } } diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index 9a0ad86..7a1183d 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -85,6 +85,7 @@ pub struct PrimitivePythonId { float64: u64, bool: u64, list: u64, + ndarray: u64, tuple: u64, typevar: u64, const_generic_marker: u64, @@ -879,6 +880,7 @@ impl Nac3 { float: get_attr_id(builtins_mod, "float"), float64: get_attr_id(numpy_mod, "float64"), list: get_attr_id(builtins_mod, "list"), + ndarray: get_attr_id(numpy_mod, "NDArray"), tuple: get_attr_id(builtins_mod, "tuple"), exception: get_attr_id(builtins_mod, "Exception"), option: get_id(artiq_builtins.get_item("Option").ok().flatten().unwrap()), diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 631e2de..aeb99fa 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -302,6 +302,12 @@ impl InnerResolver { let var = unifier.get_dummy_var().0; let list = unifier.add_ty(TypeEnum::TList { ty: var }); Ok(Ok((list, false))) + } else if ty_id == self.primitive_ids.ndarray { + // do not handle type var param and concrete check here + let var = unifier.get_dummy_var().0; + let ndims = unifier.get_fresh_const_generic_var(primitives.usize(), None, None).0; + let ndarray = unifier.add_ty(TypeEnum::TNDArray { ty: var, ndims }); + Ok(Ok((ndarray, false))) } else if ty_id == self.primitive_ids.tuple { // do not handle type var param and concrete check here Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: vec![] }), false))) @@ -446,6 +452,16 @@ impl InnerResolver { ))); } } + TypeEnum::TNDArray { .. } => { + if args.len() != 2 { + return Ok(Err(format!( + "type list needs exactly 2 type parameters, found {}", + args.len() + ))); + } + + todo!() + } TypeEnum::TTuple { .. } => { let args = match args .iter() @@ -607,7 +623,7 @@ impl InnerResolver { Err(e) => return Ok(Err(e)), }; match (&*unifier.get_ty(extracted_ty), inst_check) { - // do the instantiation for these three types + // do the instantiation for these four types (TypeEnum::TList { ty }, false) => { let len: usize = self.helper.len_fn.call1(py, (obj,))?.extract(py)?; if len == 0 { @@ -632,6 +648,30 @@ impl InnerResolver { } } } + (TypeEnum::TNDArray { ty, ndims }, false) => { + let len: usize = self.helper.len_fn.call1(py, (obj,))?.extract(py)?; + if len == 0 { + assert!(matches!( + &*unifier.get_ty(*ty), + TypeEnum::TVar { fields: None, range, .. } + if range.is_empty() + )); + Ok(Ok(extracted_ty)) + } else { + let actual_ty = + self.get_list_elem_type(py, obj, len, unifier, defs, primitives)?; + match actual_ty { + Ok(t) => match unifier.unify(*ty, t) { + Ok(_) => Ok(Ok(unifier.add_ty(TypeEnum::TNDArray { ty: *ty, ndims: *ndims }))), + Err(e) => Ok(Err(format!( + "type error ({}) for the ndarray", + e.to_display(unifier).to_string() + ))), + }, + Err(e) => Ok(Err(e)), + } + } + } (TypeEnum::TTuple { .. }, false) => { let elements: &PyTuple = obj.downcast()?; let types: Result, _>, _> = elements @@ -898,6 +938,8 @@ impl InnerResolver { global.set_initializer(&val); Ok(Some(global.as_pointer_value().into())) + } else if ty_id == self.primitive_ids.ndarray { + todo!() } else if ty_id == self.primitive_ids.tuple { let expected_ty_enum = ctx.unifier.get_ty_immutable(expected_ty); let TypeEnum::TTuple { ty } = expected_ty_enum.as_ref() else { diff --git a/nac3core/src/codegen/concrete_type.rs b/nac3core/src/codegen/concrete_type.rs index 7745160..a440276 100644 --- a/nac3core/src/codegen/concrete_type.rs +++ b/nac3core/src/codegen/concrete_type.rs @@ -47,6 +47,10 @@ pub enum ConcreteTypeEnum { TList { ty: ConcreteType, }, + TNDArray { + ty: ConcreteType, + ndims: ConcreteType, + }, TObj { obj_id: DefinitionId, fields: HashMap, @@ -167,6 +171,10 @@ impl ConcreteTypeStore { TypeEnum::TList { ty } => ConcreteTypeEnum::TList { ty: self.from_unifier_type(unifier, primitives, *ty, cache), }, + TypeEnum::TNDArray { ty, ndims } => ConcreteTypeEnum::TNDArray { + ty: self.from_unifier_type(unifier, primitives, *ty, cache), + ndims: self.from_unifier_type(unifier, primitives, *ndims, cache), + }, TypeEnum::TObj { obj_id, fields, params } => ConcreteTypeEnum::TObj { obj_id: *obj_id, fields: fields @@ -260,6 +268,12 @@ impl ConcreteTypeStore { ConcreteTypeEnum::TList { ty } => { TypeEnum::TList { ty: self.to_unifier_type(unifier, primitives, *ty, cache) } } + ConcreteTypeEnum::TNDArray { ty, ndims } => { + TypeEnum::TNDArray { + ty: self.to_unifier_type(unifier, primitives, *ty, cache), + ndims: self.to_unifier_type(unifier, primitives, *ndims, cache), + } + } ConcreteTypeEnum::TVirtual { ty } => { TypeEnum::TVirtual { ty: self.to_unifier_type(unifier, primitives, *ty, cache) } } diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 3a894a4..79ebe4a 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -1846,6 +1846,9 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( ctx.build_gep_and_load(arr_ptr, &[index], None).into() } } + TypeEnum::TNDArray { .. } => { + return Err(String::from("subscript operator for ndarray not implemented")) + } TypeEnum::TTuple { .. } => { let index: u32 = if let ExprKind::Constant { value: Constant::Int(v), .. } = &slice.node { diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 41bd2a8..21943d4 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -507,6 +507,24 @@ fn get_llvm_type<'ctx>( ]; ctx.struct_type(&fields, false).ptr_type(AddressSpace::default()).into() } + TNDArray { ty, .. } => { + let llvm_usize = generator.get_size_type(ctx); + let element_type = get_llvm_type( + ctx, module, generator, unifier, top_level, type_cache, primitives, *ty, + ); + + // struct NDArray { num_dims: size_t, dims: size_t*, data: T* } + // + // * num_dims: Number of dimensions in the array + // * dims: Pointer to an array containing the size of each dimension + // * data: Pointer to an array containing the array data + let fields = [ + llvm_usize.into(), + llvm_usize.ptr_type(AddressSpace::default()).into(), + element_type.ptr_type(AddressSpace::default()).into(), + ]; + ctx.struct_type(&fields, false).ptr_type(AddressSpace::default()).into() + } TVirtual { .. } => unimplemented!(), _ => unreachable!("{}", ty_enum.get_type_name()), }; diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index f15c2b9..e53ea32 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -99,63 +99,69 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>( } } ExprKind::Subscript { value, slice, .. } => { - assert!(matches!( - ctx.unifier.get_ty_immutable(value.custom.unwrap()).as_ref(), - TypeEnum::TList { .. }, - )); - let i32_type = ctx.ctx.i32_type(); - let zero = i32_type.const_zero(); - let v = if let Some(v) = generator.gen_expr(ctx, value)? { - v.to_basic_value_enum(ctx, generator, value.custom.unwrap())?.into_pointer_value() - } else { - return Ok(None) - }; - let len = ctx - .build_gep_and_load(v, &[zero, i32_type.const_int(1, false)], Some("len")) - .into_int_value(); - let raw_index = if let Some(v) = generator.gen_expr(ctx, slice)? { - v.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value() - } else { - return Ok(None) - }; - let raw_index = ctx.builder.build_int_s_extend( - raw_index, - generator.get_size_type(ctx.ctx), - "sext", - ); - // handle negative index - let is_negative = ctx.builder.build_int_compare( - IntPredicate::SLT, - raw_index, - generator.get_size_type(ctx.ctx).const_zero(), - "is_neg", - ); - let adjusted = ctx.builder.build_int_add(raw_index, len, "adjusted"); - let index = ctx - .builder - .build_select(is_negative, adjusted, raw_index, "index") - .into_int_value(); - // unsigned less than is enough, because negative index after adjustment is - // bigger than the length (for unsigned cmp) - let bound_check = ctx.builder.build_int_compare( - IntPredicate::ULT, - index, - len, - "inbound", - ); - ctx.make_assert( - generator, - bound_check, - "0:IndexError", - "index {0} out of bounds 0:{1}", - [Some(raw_index), Some(len), None], - slice.location, - ); - unsafe { - let arr_ptr = ctx - .build_gep_and_load(v, &[i32_type.const_zero(), i32_type.const_zero()], Some("arr.addr")) - .into_pointer_value(); - ctx.builder.build_gep(arr_ptr, &[index], name.unwrap_or("")) + match ctx.unifier.get_ty_immutable(value.custom.unwrap()).as_ref() { + TypeEnum::TList { .. } => { + let i32_type = ctx.ctx.i32_type(); + let zero = i32_type.const_zero(); + let v = generator + .gen_expr(ctx, value)? + .unwrap() + .to_basic_value_enum(ctx, generator, value.custom.unwrap())? + .into_pointer_value(); + let len = ctx + .build_gep_and_load(v, &[zero, i32_type.const_int(1, false)], Some("len")) + .into_int_value(); + let raw_index = generator + .gen_expr(ctx, slice)? + .unwrap() + .to_basic_value_enum(ctx, generator, slice.custom.unwrap())? + .into_int_value(); + let raw_index = ctx.builder.build_int_s_extend( + raw_index, + generator.get_size_type(ctx.ctx), + "sext", + ); + // handle negative index + let is_negative = ctx.builder.build_int_compare( + IntPredicate::SLT, + raw_index, + generator.get_size_type(ctx.ctx).const_zero(), + "is_neg", + ); + let adjusted = ctx.builder.build_int_add(raw_index, len, "adjusted"); + let index = ctx + .builder + .build_select(is_negative, adjusted, raw_index, "index") + .into_int_value(); + // unsigned less than is enough, because negative index after adjustment is + // bigger than the length (for unsigned cmp) + let bound_check = ctx.builder.build_int_compare( + IntPredicate::ULT, + index, + len, + "inbound", + ); + ctx.make_assert( + generator, + bound_check, + "0:IndexError", + "index {0} out of bounds 0:{1}", + [Some(raw_index), Some(len), None], + slice.location, + ); + unsafe { + let arr_ptr = ctx + .build_gep_and_load(v, &[i32_type.const_zero(), i32_type.const_zero()], Some("arr.addr")) + .into_pointer_value(); + ctx.builder.build_gep(arr_ptr, &[index], name.unwrap_or("")) + } + } + + TypeEnum::TNDArray { .. } => { + todo!() + } + + _ => unreachable!(), } } _ => unreachable!(), @@ -203,7 +209,7 @@ pub fn gen_assign<'ctx, G: CodeGenerator>( let value = value .to_basic_value_enum(ctx, generator, target.custom.unwrap())? .into_pointer_value(); - let TypeEnum::TList { ty } = &*ctx.unifier.get_ty(target.custom.unwrap()) else { + let (TypeEnum::TList { ty } | TypeEnum::TNDArray { ty, .. }) = &*ctx.unifier.get_ty(target.custom.unwrap()) else { unreachable!() }; diff --git a/nac3core/src/symbol_resolver.rs b/nac3core/src/symbol_resolver.rs index 0932bea..3ff5e0e 100644 --- a/nac3core/src/symbol_resolver.rs +++ b/nac3core/src/symbol_resolver.rs @@ -354,13 +354,14 @@ pub trait SymbolResolver { } thread_local! { - static IDENTIFIER_ID: [StrRef; 11] = [ + static IDENTIFIER_ID: [StrRef; 12] = [ "int32".into(), "int64".into(), "float".into(), "bool".into(), "virtual".into(), "list".into(), + "ndarray".into(), "tuple".into(), "str".into(), "Exception".into(), @@ -385,11 +386,12 @@ pub fn parse_type_annotation( let bool_id = ids[3]; let virtual_id = ids[4]; let list_id = ids[5]; - let tuple_id = ids[6]; - let str_id = ids[7]; - let exn_id = ids[8]; - let uint32_id = ids[9]; - let uint64_id = ids[10]; + let ndarray_id = ids[6]; + let tuple_id = ids[7]; + let str_id = ids[8]; + let exn_id = ids[9]; + let uint32_id = ids[10]; + let uint64_id = ids[11]; let name_handling = |id: &StrRef, loc: Location, unifier: &mut Unifier| { if *id == int32_id { @@ -460,6 +462,21 @@ pub fn parse_type_annotation( } else if *id == list_id { let ty = parse_type_annotation(resolver, top_level_defs, unifier, primitives, slice)?; Ok(unifier.add_ty(TypeEnum::TList { ty })) + } else if *id == ndarray_id { + let Tuple { elts, .. } = &slice.node else { + return Err(HashSet::from([ + String::from("Expected 2 type arguments for ndarray"), + ])) + }; + if elts.len() < 2 { + return Err(HashSet::from([ + String::from("Expected 2 type arguments for ndarray"), + ])) + } + + let ty = parse_type_annotation(resolver, top_level_defs, unifier, primitives, &elts[0])?; + let ndims = parse_type_annotation(resolver, top_level_defs, unifier, primitives, &elts[1])?; + Ok(unifier.add_ty(TypeEnum::TNDArray { ty, ndims })) } else if *id == tuple_id { if let Tuple { elts, .. } = &slice.node { let ty = elts diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 0b8ccde..00da280 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -470,6 +470,22 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { )))), loc: None, })), + { + let tvar = primitives.1.get_fresh_var(Some("T".into()), None); + let ndims = primitives.1.get_fresh_const_generic_var(primitives.0.uint64, Some("N".into()), None); + + Arc::new(RwLock::new(TopLevelDef::Class { + name: "ndarray".into(), + object_id: DefinitionId(14), + type_vars: vec![tvar.0, ndims.0], + fields: Vec::default(), + methods: Vec::default(), + ancestors: Vec::default(), + constructor: None, + resolver: None, + loc: None, + })) + }, Arc::new(RwLock::new(TopLevelDef::Function { name: "int32".into(), simple_name: "int32".into(), @@ -1265,10 +1281,13 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { }), ), Arc::new(RwLock::new({ - let list_var = primitives.1.get_fresh_var(Some("L".into()), None); - let list = primitives.1.add_ty(TypeEnum::TList { ty: list_var.0 }); + let tvar = primitives.1.get_fresh_var(Some("L".into()), None); + let list = primitives.1.add_ty(TypeEnum::TList { ty: tvar.0 }); + let ndims = primitives.1.get_fresh_const_generic_var(primitives.0.uint64, Some("N".into()), None); + let ndarray = primitives.1.add_ty(TypeEnum::TNDArray { ty: tvar.0, ndims: ndims.0 }); + let arg_ty = primitives.1.get_fresh_var_with_range( - &[list, primitives.0.range], + &[list, ndarray, primitives.0.range], Some("I".into()), None, ); @@ -1278,7 +1297,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { signature: primitives.1.add_ty(TypeEnum::TFunc(FunSignature { args: vec![FuncArg { name: "ls".into(), ty: arg_ty.0, default_value: None }], ret: int32, - vars: vec![(list_var.1, list_var.0), (arg_ty.1, arg_ty.0)] + vars: vec![(tvar.1, tvar.0), (arg_ty.1, arg_ty.0)] .into_iter() .collect(), })), @@ -1296,19 +1315,25 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { let (start, end, step) = destructure_range(ctx, arg); Some(calculate_len_for_slice_range(generator, ctx, start, end, step).into()) } else { - let int32 = ctx.ctx.i32_type(); - let zero = int32.const_zero(); - let len = ctx - .build_gep_and_load( - arg.into_pointer_value(), - &[zero, int32.const_int(1, false)], - None, - ) - .into_int_value(); - if len.get_type().get_bit_width() == 32 { - Some(len.into()) - } else { - Some(ctx.builder.build_int_truncate(len, int32, "len2i32").into()) + match &*ctx.unifier.get_ty_immutable(arg_ty) { + TypeEnum::TList { .. } => { + let int32 = ctx.ctx.i32_type(); + let zero = int32.const_zero(); + let len = ctx + .build_gep_and_load( + arg.into_pointer_value(), + &[zero, int32.const_int(1, false)], + None, + ) + .into_int_value(); + if len.get_type().get_bit_width() == 32 { + Some(len.into()) + } else { + Some(ctx.builder.build_int_truncate(len, int32, "len2i32").into()) + } + } + TypeEnum::TNDArray { .. } => todo!(), + _ => unreachable!(), } }) }, diff --git a/nac3core/src/toplevel/type_annotation.rs b/nac3core/src/toplevel/type_annotation.rs index 4482bc5..94897ff 100644 --- a/nac3core/src/toplevel/type_annotation.rs +++ b/nac3core/src/toplevel/type_annotation.rs @@ -491,11 +491,24 @@ pub fn get_type_from_type_annotation_kinds( (*name, (subst_ty, *mutability)) })); let need_subst = !subst.is_empty(); - let ty = unifier.add_ty(TypeEnum::TObj { - obj_id: *obj_id, - fields: tobj_fields, - params: subst, - }); + let ty = if obj_id == &DefinitionId(14) { + assert_eq!(subst.len(), 2); + let tv_tys = subst.iter() + .sorted_by_key(|(k, _)| *k) + .map(|(_, v)| v) + .collect_vec(); + + unifier.add_ty(TypeEnum::TNDArray { + ty: *tv_tys[0], + ndims: *tv_tys[1], + }) + } else { + unifier.add_ty(TypeEnum::TObj { + obj_id: *obj_id, + fields: tobj_fields, + params: subst, + }) + }; if need_subst { if let Some(wl) = subst_list.as_mut() { wl.push(ty); diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 7ff3e0b..f9d3396 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -223,8 +223,12 @@ impl<'a> Fold<()> for Inferencer<'a> { if self.unifier.unioned(iter.custom.unwrap(), self.primitives.range) { self.unify(self.primitives.int32, target.custom.unwrap(), &target.location)?; } else { - let list = self.unifier.add_ty(TypeEnum::TList { ty: target.custom.unwrap() }); - self.unify(list, iter.custom.unwrap(), &iter.location)?; + let list_like_ty = match &*self.unifier.get_ty(iter.custom.unwrap()) { + TypeEnum::TList { .. } => self.unifier.add_ty(TypeEnum::TList { ty: target.custom.unwrap() }), + TypeEnum::TNDArray { .. } => todo!(), + _ => unreachable!(), + }; + self.unify(list_like_ty, iter.custom.unwrap(), &iter.location)?; } let body = body.into_iter().map(|b| self.fold_stmt(b)).collect::, _>>()?; @@ -1137,9 +1141,13 @@ impl<'a> Inferencer<'a> { for v in [lower.as_ref(), upper.as_ref(), step.as_ref()].iter().flatten() { self.constrain(v.custom.unwrap(), self.primitives.int32, &v.location)?; } - let list = self.unifier.add_ty(TypeEnum::TList { ty }); - self.constrain(value.custom.unwrap(), list, &value.location)?; - Ok(list) + let list_like_ty = match &*self.unifier.get_ty(value.custom.unwrap()) { + TypeEnum::TList { .. } => self.unifier.add_ty(TypeEnum::TList { ty }), + TypeEnum::TNDArray { ndims, .. } => self.unifier.add_ty(TypeEnum::TNDArray { ty, ndims: *ndims }), + _ => unreachable!() + }; + self.constrain(value.custom.unwrap(), list_like_ty, &value.location)?; + Ok(list_like_ty) } ExprKind::Constant { value: ast::Constant::Int(val), .. } => { // the index is a constant, so value can be a sequence. @@ -1159,10 +1167,15 @@ impl<'a> Inferencer<'a> { { return report_error("Tuple index must be a constant (KernelInvariant is also not supported)", slice.location) } + // the index is not a constant, so value can only be a list self.constrain(slice.custom.unwrap(), self.primitives.int32, &slice.location)?; - let list = self.unifier.add_ty(TypeEnum::TList { ty }); - self.constrain(value.custom.unwrap(), list, &value.location)?; + let list_like_ty = match &*self.unifier.get_ty(value.custom.unwrap()) { + TypeEnum::TList { .. } => self.unifier.add_ty(TypeEnum::TList { ty }), + TypeEnum::TNDArray { .. } => todo!(), + _ => unreachable!(), + }; + self.constrain(value.custom.unwrap(), list_like_ty, &value.location)?; Ok(ty) } } diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index 04a905d..e0a72e8 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -159,6 +159,11 @@ pub enum TypeEnum { ty: Type, }, + TNDArray { + ty: Type, + ndims: Type, + }, + /// An object type. TObj { /// The [DefintionId] of this object type. @@ -193,6 +198,7 @@ impl TypeEnum { TypeEnum::TLiteral { .. } => "TConstant", TypeEnum::TTuple { .. } => "TTuple", TypeEnum::TList { .. } => "TList", + TypeEnum::TNDArray { .. } => "TNDArray", TypeEnum::TObj { .. } => "TObj", TypeEnum::TVirtual { .. } => "TVirtual", TypeEnum::TCall { .. } => "TCall", @@ -418,6 +424,9 @@ impl Unifier { TypeEnum::TList { ty } => self .get_instantiations(*ty) .map(|ty| ty.iter().map(|&ty| self.add_ty(TypeEnum::TList { ty })).collect_vec()), + TypeEnum::TNDArray { ty, ndims } => self + .get_instantiations(*ty) + .map(|ty| ty.iter().map(|&ty| self.add_ty(TypeEnum::TNDArray { ty, ndims: *ndims })).collect_vec()), TypeEnum::TVirtual { ty } => self.get_instantiations(*ty).map(|ty| { ty.iter().map(|&ty| self.add_ty(TypeEnum::TVirtual { ty })).collect_vec() }), @@ -470,6 +479,7 @@ impl Unifier { TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)), TCall { .. } => false, TList { ty } | TVirtual { ty } => self.is_concrete(*ty, allowed_typevars), + TNDArray { ty, .. } => self.is_concrete(*ty, allowed_typevars), TTuple { ty } => ty.iter().all(|ty| self.is_concrete(*ty, allowed_typevars)), TObj { params: vars, .. } => { vars.values().all(|ty| self.is_concrete(*ty, allowed_typevars)) @@ -717,7 +727,8 @@ impl Unifier { self.unify_impl(x, b, false)?; self.set_a_to_b(a, x); } - (TVar { fields: Some(fields), range, is_const_generic: false, .. }, TList { ty }) => { + (TVar { fields: Some(fields), range, is_const_generic: false, .. }, TList { ty }) | + (TVar { fields: Some(fields), range, is_const_generic: false, .. }, TNDArray { ty, .. }) => { for (k, v) in fields { match *k { RecordKey::Int(_) => { @@ -829,6 +840,15 @@ impl Unifier { } self.set_a_to_b(a, b); } + (TNDArray { ty: ty1, ndims: ndims1 }, TNDArray { ty: ty2, ndims: ndims2 }) => { + if self.unify_impl(*ty1, *ty2, false).is_err() { + return self.incompatible_types(a, b) + } + if self.unify_impl(*ndims1, *ndims2, false).is_err() { + return self.incompatible_types(a, b) + } + self.set_a_to_b(a, b); + } (TVar { fields: Some(map), range, .. }, TObj { fields, .. }) => { for (k, field) in map { match *k { @@ -1076,6 +1096,13 @@ impl Unifier { TypeEnum::TList { ty } => { format!("list[{}]", self.internal_stringify(*ty, obj_to_name, var_to_name, notes)) } + TypeEnum::TNDArray { ty, ndims } => { + format!( + "ndarray[{}, {}]", + self.internal_stringify(*ty, obj_to_name, var_to_name, notes), + self.internal_stringify(*ndims, obj_to_name, var_to_name, notes), + ) + } TypeEnum::TVirtual { ty } => { format!( "virtual[{}]", @@ -1195,7 +1222,7 @@ impl Unifier { // variables, i.e. things like TRecord, TCall should not occur, and we // should be safe to not implement the substitution for those variants. match &*ty { - TypeEnum::TRigidVar { .. } => None, + TypeEnum::TRigidVar { .. } | TypeEnum::TLiteral { .. } => None, TypeEnum::TVar { id, .. } => mapping.get(id).copied(), TypeEnum::TTuple { ty } => { let mut new_ty = Cow::from(ty); @@ -1213,6 +1240,19 @@ impl Unifier { TypeEnum::TList { ty } => { self.subst_impl(*ty, mapping, cache).map(|t| self.add_ty(TypeEnum::TList { ty: t })) } + TypeEnum::TNDArray { ty, ndims } => { + let new_ty = self.subst_impl(*ty, mapping, cache); + let new_ndims = self.subst_impl(*ndims, mapping, cache); + + if new_ty.is_some() || new_ndims.is_some() { + Some(self.add_ty(TypeEnum::TNDArray { + ty: new_ty.unwrap_or(*ty), + ndims: new_ndims.unwrap_or(*ndims) + })) + } else { + None + } + } TypeEnum::TVirtual { ty } => self .subst_impl(*ty, mapping, cache) .map(|t| self.add_ty(TypeEnum::TVirtual { ty: t })), @@ -1383,6 +1423,19 @@ impl Unifier { (TList { ty: ty1 }, TList { ty: ty2 }) => { Ok(self.get_intersection(*ty1, *ty2)?.map(|ty| self.add_ty(TList { ty }))) } + (TNDArray { ty: ty1, ndims: ndims1 }, TNDArray { ty: ty2, ndims: ndims2 }) => { + let ty = self.get_intersection(*ty1, *ty2)?; + let ndims = self.get_intersection(*ndims1, *ndims2)?; + + Ok(if ty.is_some() || ndims.is_some() { + Some(self.add_ty(TNDArray { + ty: ty.unwrap_or(*ty1), + ndims: ndims.unwrap_or(*ndims1), + })) + } else { + None + }) + } (TVirtual { ty: ty1 }, TVirtual { ty: ty2 }) => { Ok(self.get_intersection(*ty1, *ty2)?.map(|ty| self.add_ty(TVirtual { ty }))) } diff --git a/nac3core/src/typecheck/typedef/test.rs b/nac3core/src/typecheck/typedef/test.rs index 3069b57..eb44c22 100644 --- a/nac3core/src/typecheck/typedef/test.rs +++ b/nac3core/src/typecheck/typedef/test.rs @@ -33,6 +33,7 @@ impl Unifier { && ty1.iter().zip(ty2.iter()).all(|(t1, t2)| self.eq(*t1, *t2)) } (TypeEnum::TList { ty: ty1 }, TypeEnum::TList { ty: ty2 }) + | (TypeEnum::TNDArray { ty: ty1 }, TypeEnum::TNDArray { ty: ty2 }) | (TypeEnum::TVirtual { ty: ty1 }, TypeEnum::TVirtual { ty: ty2 }) => { self.eq(*ty1, *ty2) }