From 234a6bde2a67f5451fa0669121abd6d0d3e517e7 Mon Sep 17 00:00:00 2001 From: David Mak Date: Tue, 27 Feb 2024 13:39:05 +0800 Subject: [PATCH] core: Use TObj for NDArray --- nac3artiq/src/codegen.rs | 11 -- nac3artiq/src/symbol_resolver.rs | 29 +++-- nac3core/src/codegen/concrete_type.rs | 14 --- nac3core/src/codegen/expr.rs | 38 ++++--- nac3core/src/codegen/mod.rs | 80 +++++++------- nac3core/src/codegen/stmt.rs | 19 +++- nac3core/src/symbol_resolver.rs | 39 ++----- nac3core/src/toplevel/builtins.rs | 24 ++--- nac3core/src/toplevel/helper.rs | 19 ++++ nac3core/src/toplevel/numpy.rs | 74 +++++++++++-- ...el__test__test_analyze__generic_class.snap | 2 +- ...t__test_analyze__inheritance_override.snap | 2 +- ...est__test_analyze__list_tuple_generic.snap | 4 +- ...__toplevel__test__test_analyze__self1.snap | 2 +- ...t__test_analyze__simple_class_compose.snap | 4 +- nac3core/src/toplevel/type_annotation.rs | 23 +--- nac3core/src/typecheck/type_inferencer/mod.rs | 101 +++++++++++------- .../src/typecheck/type_inferencer/test.rs | 12 +++ nac3core/src/typecheck/typedef/mod.rs | 78 +------------- nac3core/src/typecheck/typedef/test.rs | 2 +- 20 files changed, 302 insertions(+), 275 deletions(-) diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index 1f8feed0..1b338aaf 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -397,9 +397,6 @@ fn gen_rpc_tag( buffer.push(b'l'); gen_rpc_tag(ctx, *ty, buffer)?; } - TNDArray { .. } => { - todo!() - } _ => return Err(format!("Unsupported type: {:?}", ctx.unifier.stringify(ty))), } } @@ -660,14 +657,6 @@ pub fn attributes_writeback( values.push((ty, inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap())); } }, - TypeEnum::TNDArray { ty: elem_ty, .. } => { - if gen_rpc_tag(ctx, *elem_ty, &mut scratch_buffer).is_ok() { - let pydict = PyDict::new(py); - pydict.set_item("obj", val)?; - host_attributes.append(pydict)?; - values.push((ty, inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap())); - } - }, _ => {} } } diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 3de2bdcd..6c1130d0 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -2,7 +2,12 @@ use inkwell::{types::BasicType, values::BasicValueEnum, AddressSpace}; use nac3core::{ codegen::{CodeGenContext, CodeGenerator}, symbol_resolver::{StaticValue, SymbolResolver, SymbolValue, ValueEnum}, - toplevel::{DefinitionId, TopLevelDef}, + toplevel::{ + DefinitionId, + helper::PRIMITIVE_DEF_IDS, + numpy::{make_ndarray_ty, unpack_ndarray_tvars}, + TopLevelDef, + }, typecheck::{ type_inferencer::PrimitiveStore, typedef::{Type, TypeEnum, Unifier}, @@ -306,7 +311,7 @@ impl InnerResolver { // do not handle type var param and concrete check here let var = unifier.get_dummy_var().0; let ndims = unifier.get_fresh_const_generic_var(primitives.usize(), None, None).0; - let ndarray = unifier.add_ty(TypeEnum::TNDArray { ty: var, ndims }); + let ndarray = make_ndarray_ty(unifier, primitives, Some(var), Some(ndims)); Ok(Ok((ndarray, false))) } else if ty_id == self.primitive_ids.tuple { // do not handle type var param and concrete check here @@ -452,7 +457,7 @@ impl InnerResolver { ))); } } - TypeEnum::TNDArray { .. } => { + TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { if args.len() != 2 { return Ok(Err(format!( "type list needs exactly 2 type parameters, found {}", @@ -648,11 +653,12 @@ impl InnerResolver { } } } - (TypeEnum::TNDArray { ty, ndims }, false) => { + (TypeEnum::TObj { obj_id, .. }, false) if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { + let (ty, ndims) = unpack_ndarray_tvars(unifier, extracted_ty); let len: usize = self.helper.len_fn.call1(py, (obj,))?.extract(py)?; if len == 0 { assert!(matches!( - &*unifier.get_ty(*ty), + &*unifier.get_ty(ty), TypeEnum::TVar { fields: None, range, .. } if range.is_empty() )); @@ -661,8 +667,17 @@ impl InnerResolver { let actual_ty = self.get_list_elem_type(py, obj, len, unifier, defs, primitives)?; match actual_ty { - Ok(t) => match unifier.unify(*ty, t) { - Ok(()) => Ok(Ok(unifier.add_ty(TypeEnum::TNDArray { ty: *ty, ndims: *ndims }))), + Ok(t) => match unifier.unify(ty, t) { + Ok(()) => { + let ndarray_ty = make_ndarray_ty( + unifier, + primitives, + Some(ty), + Some(ndims), + ); + + Ok(Ok(ndarray_ty)) + } Err(e) => Ok(Err(format!( "type error ({}) for the ndarray", e.to_display(unifier), diff --git a/nac3core/src/codegen/concrete_type.rs b/nac3core/src/codegen/concrete_type.rs index a440276c..77451600 100644 --- a/nac3core/src/codegen/concrete_type.rs +++ b/nac3core/src/codegen/concrete_type.rs @@ -47,10 +47,6 @@ pub enum ConcreteTypeEnum { TList { ty: ConcreteType, }, - TNDArray { - ty: ConcreteType, - ndims: ConcreteType, - }, TObj { obj_id: DefinitionId, fields: HashMap, @@ -171,10 +167,6 @@ impl ConcreteTypeStore { TypeEnum::TList { ty } => ConcreteTypeEnum::TList { ty: self.from_unifier_type(unifier, primitives, *ty, cache), }, - TypeEnum::TNDArray { ty, ndims } => ConcreteTypeEnum::TNDArray { - ty: self.from_unifier_type(unifier, primitives, *ty, cache), - ndims: self.from_unifier_type(unifier, primitives, *ndims, cache), - }, TypeEnum::TObj { obj_id, fields, params } => ConcreteTypeEnum::TObj { obj_id: *obj_id, fields: fields @@ -268,12 +260,6 @@ impl ConcreteTypeStore { ConcreteTypeEnum::TList { ty } => { TypeEnum::TList { ty: self.to_unifier_type(unifier, primitives, *ty, cache) } } - ConcreteTypeEnum::TNDArray { ty, ndims } => { - TypeEnum::TNDArray { - ty: self.to_unifier_type(unifier, primitives, *ty, cache), - ndims: self.to_unifier_type(unifier, primitives, *ndims, cache), - } - } ConcreteTypeEnum::TVirtual { ty } => { TypeEnum::TVirtual { ty: self.to_unifier_type(unifier, primitives, *ty, cache) } } diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index e0df6f0b..360d8734 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -13,7 +13,12 @@ use crate::{ CodeGenContext, CodeGenTask, }, symbol_resolver::{SymbolValue, ValueEnum}, - toplevel::{DefinitionId, TopLevelDef}, + toplevel::{ + DefinitionId, + helper::PRIMITIVE_DEF_IDS, + numpy::make_ndarray_ty, + TopLevelDef, + }, typecheck::{ typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier}, magic_methods::{binop_name, binop_assign_name}, @@ -181,7 +186,6 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { &mut self.unifier, self.top_level, &mut self.type_cache, - &self.primitives, ty, ) } @@ -1204,23 +1208,25 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>( SymbolValue::U64(v) => Ok(v), SymbolValue::U32(v) => Ok(v as u64), SymbolValue::I32(v) => u64::try_from(v) - .map_err(|_| format!("Expected non-negative literal for TNDArray.ndims, got {v}")), + .map_err(|_| format!("Expected non-negative literal for ndarray.ndims, got {v}")), SymbolValue::I64(v) => u64::try_from(v) - .map_err(|_| format!("Expected non-negative literal for TNDArray.ndims, got {v}")), + .map_err(|_| format!("Expected non-negative literal for ndarray.ndims, got {v}")), _ => unreachable!(), }) .collect::, _>>()?; assert!(!ndims.is_empty()); - let ndarray_ty_enum = TypeEnum::TNDArray { - ty, - ndims: ctx.unifier.get_fresh_literal( - ndims.iter().map(|v| SymbolValue::U64(v - 1)).collect(), - None, - ), - }; - let ndarray_ty = ctx.unifier.add_ty(ndarray_ty_enum); + let ndarray_ndims_ty = ctx.unifier.get_fresh_literal( + ndims.iter().map(|v| SymbolValue::U64(v - 1)).collect(), + None, + ); + let ndarray_ty = make_ndarray_ty( + &mut ctx.unifier, + &ctx.primitives, + Some(ty), + Some(ndarray_ndims_ty), + ); let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type(); let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type(); let llvm_ndarray_data_t = ctx.get_llvm_type(generator, ty).as_basic_type_enum(); @@ -1963,7 +1969,13 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( v.get_data().get(ctx, generator, index, None).into() } } - TypeEnum::TNDArray { ty, ndims } => { + TypeEnum::TObj { obj_id, params, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { + let (ty, ndims) = params.iter() + .sorted_by_key(|(var_id, _)| *var_id) + .map(|(_, ty)| ty) + .collect_tuple() + .unwrap(); + let v = if let Some(v) = generator.gen_expr(ctx, value)? { v.to_basic_value_enum(ctx, generator, value.custom.unwrap())?.into_pointer_value() } else { diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index bcc1f4be..01048790 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -1,6 +1,11 @@ use crate::{ symbol_resolver::{StaticValue, SymbolResolver}, - toplevel::{TopLevelContext, TopLevelDef}, + toplevel::{ + helper::PRIMITIVE_DEF_IDS, + numpy::unpack_ndarray_tvars, + TopLevelContext, + TopLevelDef, + }, typecheck::{ type_inferencer::{CodeLocation, PrimitiveStore}, typedef::{CallId, FuncArg, Type, TypeEnum, Unifier}, @@ -417,7 +422,6 @@ fn get_llvm_type<'ctx>( unifier: &mut Unifier, top_level: &TopLevelContext, type_cache: &mut HashMap>, - primitives: &PrimitiveStore, ty: Type, ) -> BasicTypeEnum<'ctx> { use TypeEnum::*; @@ -427,28 +431,50 @@ fn get_llvm_type<'ctx>( let ty_enum = unifier.get_ty(ty); let result = match &*ty_enum { TObj { obj_id, fields, .. } => { - // check to avoid treating primitives other than Option as classes - if obj_id.0 <= 10 { - match (unifier.get_ty(ty).as_ref(), unifier.get_ty(primitives.option).as_ref()) - { - ( - TObj { obj_id, params, .. }, - TObj { obj_id: opt_id, .. }, - ) if *obj_id == *opt_id => { - return get_llvm_type( + // check to avoid treating non-class primitives as classes + if obj_id.0 <= PRIMITIVE_DEF_IDS.max_id().0 { + return match &*unifier.get_ty_immutable(ty) { + TObj { obj_id, params, .. } if *obj_id == PRIMITIVE_DEF_IDS.option => { + get_llvm_type( ctx, module, generator, unifier, top_level, type_cache, - primitives, *params.iter().next().unwrap().1, ) .ptr_type(AddressSpace::default()) - .into(); + .into() } - _ => unreachable!("must be option type"), + + TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { + let llvm_usize = generator.get_size_type(ctx); + let (dtype, _) = unpack_ndarray_tvars(unifier, ty); + let element_type = get_llvm_type( + ctx, + module, + generator, + unifier, + top_level, + type_cache, + dtype, + ); + + // struct NDArray { num_dims: size_t, dims: size_t*, data: T* } + // + // * num_dims: Number of dimensions in the array + // * dims: Pointer to an array containing the size of each dimension + // * data: Pointer to an array containing the array data + let fields = [ + llvm_usize.into(), + llvm_usize.ptr_type(AddressSpace::default()).into(), + element_type.ptr_type(AddressSpace::default()).into(), + ]; + ctx.struct_type(&fields, false).ptr_type(AddressSpace::default()).into() + } + + _ => unreachable!("LLVM type for primitive {} is missing", unifier.stringify(ty)), } } // a struct with fields in the order of declaration @@ -477,7 +503,6 @@ fn get_llvm_type<'ctx>( unifier, top_level, type_cache, - primitives, fields[&f.0].0, ) }) @@ -493,7 +518,7 @@ fn get_llvm_type<'ctx>( .iter() .map(|ty| { get_llvm_type( - ctx, module, generator, unifier, top_level, type_cache, primitives, *ty, + ctx, module, generator, unifier, top_level, type_cache, *ty, ) }) .collect_vec(); @@ -502,7 +527,7 @@ fn get_llvm_type<'ctx>( TList { ty } => { // a struct with an integer and a pointer to an array let element_type = get_llvm_type( - ctx, module, generator, unifier, top_level, type_cache, primitives, *ty, + ctx, module, generator, unifier, top_level, type_cache, *ty, ); let fields = [ element_type.ptr_type(AddressSpace::default()).into(), @@ -510,24 +535,6 @@ fn get_llvm_type<'ctx>( ]; ctx.struct_type(&fields, false).ptr_type(AddressSpace::default()).into() } - TNDArray { ty, .. } => { - let llvm_usize = generator.get_size_type(ctx); - let element_type = get_llvm_type( - ctx, module, generator, unifier, top_level, type_cache, primitives, *ty, - ); - - // struct NDArray { num_dims: size_t, dims: size_t*, data: T* } - // - // * num_dims: Number of dimensions in the array - // * dims: Pointer to an array containing the size of each dimension - // * data: Pointer to an array containing the array data - let fields = [ - llvm_usize.into(), - llvm_usize.ptr_type(AddressSpace::default()).into(), - element_type.ptr_type(AddressSpace::default()).into(), - ]; - ctx.struct_type(&fields, false).ptr_type(AddressSpace::default()).into() - } TVirtual { .. } => unimplemented!(), _ => unreachable!("{}", ty_enum.get_type_name()), }; @@ -561,7 +568,7 @@ fn get_llvm_abi_type<'ctx>( return if unifier.unioned(ty, primitives.bool) { ctx.bool_type().into() } else { - get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, primitives, ty) + get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, ty) } } @@ -763,7 +770,6 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte &mut unifier, top_level_ctx.as_ref(), &mut type_cache, - &primitives, arg.ty, ); let alloca = builder diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index ef1ec629..d7a3d6f5 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -10,7 +10,12 @@ use crate::{ expr::gen_binop_expr, gen_in_range_check, }, - toplevel::{DefinitionId, TopLevelDef}, + toplevel::{ + DefinitionId, + helper::PRIMITIVE_DEF_IDS, + numpy::unpack_ndarray_tvars, + TopLevelDef, + }, typecheck::typedef::{FunSignature, Type, TypeEnum}, }; use inkwell::{ @@ -186,7 +191,7 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>( v.get_data().ptr_offset(ctx, generator, index, name) } - TypeEnum::TNDArray { .. } => { + TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { todo!() } @@ -242,11 +247,15 @@ pub fn gen_assign<'ctx, G: CodeGenerator>( .to_basic_value_enum(ctx, generator, target.custom.unwrap())? .into_pointer_value(); let value = ListValue::from_ptr_val(value, llvm_usize, None); - let (TypeEnum::TList { ty } | TypeEnum::TNDArray { ty, .. }) = &*ctx.unifier.get_ty(target.custom.unwrap()) else { - unreachable!() + let ty = match &*ctx.unifier.get_ty_immutable(target.custom.unwrap()) { + TypeEnum::TList { ty } => *ty, + TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { + unpack_ndarray_tvars(&mut ctx.unifier, target.custom.unwrap()).0 + } + _ => unreachable!(), }; - let ty = ctx.get_llvm_type(generator, *ty); + let ty = ctx.get_llvm_type(generator, ty); let Some(src_ind) = handle_slice_indices(&None, &None, &None, ctx, generator, value)? else { return Ok(()) }; diff --git a/nac3core/src/symbol_resolver.rs b/nac3core/src/symbol_resolver.rs index a2a663f5..74a79d3f 100644 --- a/nac3core/src/symbol_resolver.rs +++ b/nac3core/src/symbol_resolver.rs @@ -3,16 +3,12 @@ use std::sync::Arc; use std::{collections::HashMap, collections::HashSet, fmt::Display}; use std::rc::Rc; -use crate::typecheck::typedef::TypeEnum; use crate::{ - codegen::CodeGenContext, + codegen::{CodeGenContext, CodeGenerator}, toplevel::{DefinitionId, TopLevelDef, type_annotation::TypeAnnotation}, -}; -use crate::{ - codegen::CodeGenerator, typecheck::{ type_inferencer::PrimitiveStore, - typedef::{Type, Unifier}, + typedef::{Type, TypeEnum, Unifier}, }, }; use inkwell::values::{BasicValueEnum, FloatValue, IntValue, PointerValue, StructValue}; @@ -353,14 +349,13 @@ pub trait SymbolResolver { } thread_local! { - static IDENTIFIER_ID: [StrRef; 13] = [ + static IDENTIFIER_ID: [StrRef; 12] = [ "int32".into(), "int64".into(), "float".into(), "bool".into(), "virtual".into(), "list".into(), - "ndarray".into(), "tuple".into(), "str".into(), "Exception".into(), @@ -386,13 +381,12 @@ pub fn parse_type_annotation( let bool_id = ids[3]; let virtual_id = ids[4]; let list_id = ids[5]; - let ndarray_id = ids[6]; - let tuple_id = ids[7]; - let str_id = ids[8]; - let exn_id = ids[9]; - let uint32_id = ids[10]; - let uint64_id = ids[11]; - let literal_id = ids[12]; + let tuple_id = ids[6]; + let str_id = ids[7]; + let exn_id = ids[8]; + let uint32_id = ids[9]; + let uint64_id = ids[10]; + let literal_id = ids[11]; let name_handling = |id: &StrRef, loc: Location, unifier: &mut Unifier| { if *id == int32_id { @@ -463,21 +457,6 @@ pub fn parse_type_annotation( } else if *id == list_id { let ty = parse_type_annotation(resolver, top_level_defs, unifier, primitives, slice)?; Ok(unifier.add_ty(TypeEnum::TList { ty })) - } else if *id == ndarray_id { - let Tuple { elts, .. } = &slice.node else { - return Err(HashSet::from([ - String::from("Expected 2 type arguments for ndarray"), - ])) - }; - if elts.len() < 2 { - return Err(HashSet::from([ - String::from("Expected 2 type arguments for ndarray"), - ])) - } - - let ty = parse_type_annotation(resolver, top_level_defs, unifier, primitives, &elts[0])?; - let ndims = parse_type_annotation(resolver, top_level_defs, unifier, primitives, &elts[1])?; - Ok(unifier.add_ty(TypeEnum::TNDArray { ty, ndims })) } else if *id == tuple_id { if let Tuple { elts, .. } = &slice.node { let ty = elts diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index ed526177..635145d9 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -274,14 +274,8 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { let boolean = primitives.0.bool; let range = primitives.0.range; let string = primitives.0.str; - let ndarray = { - let ndarray_ty = TypeEnum::ndarray(&mut primitives.1, None, None, &primitives.0); - primitives.1.add_ty(ndarray_ty) - }; - let ndarray_float = { - let ndarray_ty_enum = TypeEnum::ndarray(&mut primitives.1, Some(float), None, &primitives.0); - primitives.1.add_ty(ndarray_ty_enum) - }; + let ndarray = primitives.0.ndarray; + let ndarray_float = make_ndarray_ty(&mut primitives.1, &primitives.0, Some(float), None); let ndarray_float_2d = { let value = match primitives.0.size_t { 64 => SymbolValue::U64(2u64), @@ -293,10 +287,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { loc: None, }); - primitives.1.add_ty(TypeEnum::TNDArray { - ty: float, - ndims, - }) + make_ndarray_ty(&mut primitives.1, &primitives.0, Some(float), Some(ndims)) }; let list_int32 = primitives.1.add_ty(TypeEnum::TList { ty: int32 }); let num_ty = primitives.1.get_fresh_var_with_range( @@ -1352,7 +1343,12 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { let tvar = primitives.1.get_fresh_var(Some("L".into()), None); let list = primitives.1.add_ty(TypeEnum::TList { ty: tvar.0 }); let ndims = primitives.1.get_fresh_const_generic_var(primitives.0.uint64, Some("N".into()), None); - let ndarray = primitives.1.add_ty(TypeEnum::TNDArray { ty: tvar.0, ndims: ndims.0 }); + let ndarray = make_ndarray_ty( + &mut primitives.1, + &primitives.0, + Some(tvar.0), + Some(ndims.0), + ); let arg_ty = primitives.1.get_fresh_var_with_range( &[list, ndarray, primitives.0.range], @@ -1404,7 +1400,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { ) } } - TypeEnum::TNDArray { .. } => { + TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { let llvm_i32 = ctx.ctx.i32_type(); let i32_zero = llvm_i32.const_zero(); diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 3d948e25..c07c0ab4 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -1,6 +1,7 @@ use std::convert::TryInto; use crate::symbol_resolver::SymbolValue; +use crate::typecheck::typedef::Mapping; use nac3parser::ast::{Constant, Location}; use super::*; @@ -194,6 +195,23 @@ impl TopLevelComposer { params: HashMap::from([(option_type_var.1, option_type_var.0)]), }); + let size_t_ty = match size_t { + 32 => uint32, + 64 => uint64, + _ => unreachable!(), + }; + + let ndarray_dtype_tvar = unifier.get_fresh_var(Some("ndarray_dtype".into()), None); + let ndarray_ndims_tvar = unifier.get_fresh_const_generic_var(size_t_ty, Some("ndarray_ndims".into()), None); + let ndarray = unifier.add_ty(TypeEnum::TObj { + obj_id: PRIMITIVE_DEF_IDS.ndarray, + fields: Mapping::new(), + params: Mapping::from([ + (ndarray_dtype_tvar.1, ndarray_dtype_tvar.0), + (ndarray_ndims_tvar.1, ndarray_ndims_tvar.0), + ]) + }); + let primitives = PrimitiveStore { int32, int64, @@ -206,6 +224,7 @@ impl TopLevelComposer { str, exception, option, + ndarray, size_t, }; unifier.put_primitive_store(&primitives); diff --git a/nac3core/src/toplevel/numpy.rs b/nac3core/src/toplevel/numpy.rs index a823f4fb..2f60930a 100644 --- a/nac3core/src/toplevel/numpy.rs +++ b/nac3core/src/toplevel/numpy.rs @@ -1,5 +1,6 @@ use inkwell::{IntPredicate, types::BasicType, values::{BasicValueEnum, PointerValue}}; use inkwell::values::{AggregateValueEnum, ArrayValue, IntValue}; +use itertools::Itertools; use nac3parser::ast::StrRef; use crate::{ codegen::{ @@ -15,10 +16,68 @@ use crate::{ stmt::gen_for_callback }, symbol_resolver::ValueEnum, - toplevel::DefinitionId, - typecheck::typedef::{FunSignature, Type, TypeEnum}, + toplevel::{DefinitionId, helper::PRIMITIVE_DEF_IDS}, + typecheck::{ + type_inferencer::PrimitiveStore, + typedef::{FunSignature, Mapping, Type, TypeEnum, Unifier}, + }, }; +/// Creates a `ndarray` [`Type`] with the given type arguments. +/// +/// * `dtype` - The element type of the `ndarray`, or [`None`] if the type variable is not +/// specialized. +/// * `ndims` - The number of dimensions of the `ndarray`, or [`None`] if the type variable is not +/// specialized. +pub fn make_ndarray_ty( + unifier: &mut Unifier, + primitives: &PrimitiveStore, + dtype: Option, + ndims: Option, +) -> Type { + let ndarray = primitives.ndarray; + + let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else { + panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray)) + }; + debug_assert_eq!(*obj_id, PRIMITIVE_DEF_IDS.ndarray); + + let tvar_ids = params.iter() + .map(|(obj_id, _)| *obj_id) + .sorted() + .collect_vec(); + debug_assert_eq!(tvar_ids.len(), 2); + + let mut tvar_subst = Mapping::new(); + if let Some(dtype) = dtype { + tvar_subst.insert(tvar_ids[0], dtype); + } + if let Some(ndims) = ndims { + tvar_subst.insert(tvar_ids[1], ndims); + } + + unifier.subst(ndarray, &tvar_subst).unwrap_or(ndarray) +} + +/// Unpacks the type variables of `ndarray` into a tuple. The elements of the tuple corresponds to +/// `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray` respectively. +pub fn unpack_ndarray_tvars( + unifier: &mut Unifier, + ndarray: Type, +) -> (Type, Type) { + let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else { + panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray)) + }; + debug_assert_eq!(*obj_id, PRIMITIVE_DEF_IDS.ndarray); + debug_assert_eq!(params.len(), 2); + + params.iter() + .sorted_by_key(|(obj_id, _)| *obj_id) + .map(|(_, ty)| *ty) + .collect_tuple() + .unwrap() +} + /// Creates an `NDArray` instance from a constant shape. /// /// * `elem_ty` - The element type of the `NDArray`. @@ -29,8 +88,7 @@ fn create_ndarray_const_shape<'ctx>( elem_ty: Type, shape: ArrayValue<'ctx> ) -> Result, String> { - let ndarray_ty_enum = TypeEnum::ndarray(&mut ctx.unifier, Some(elem_ty), None, &ctx.primitives); - let ndarray_ty = ctx.unifier.add_ty(ndarray_ty_enum); + let ndarray_ty = make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(elem_ty), None); let llvm_usize = generator.get_size_type(ctx.ctx); @@ -147,8 +205,12 @@ fn call_ndarray_empty_impl<'ctx>( elem_ty: Type, shape: ListValue<'ctx>, ) -> Result, String> { - let ndarray_ty_enum = TypeEnum::ndarray(&mut ctx.unifier, Some(elem_ty), None, &ctx.primitives); - let ndarray_ty = ctx.unifier.add_ty(ndarray_ty_enum); + let ndarray_ty = make_ndarray_ty( + &mut ctx.unifier, + &ctx.primitives, + Some(elem_ty), + None, + ); let llvm_i32 = ctx.ctx.i32_type(); let llvm_usize = generator.get_size_type(ctx.ctx); diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap index 5ac6ce4f..75a7ba11 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap @@ -5,7 +5,7 @@ expression: res_vec [ "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", "Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [29]\n}\n", + "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [28]\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap index dbd0cc79..e87a7331 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap @@ -7,7 +7,7 @@ expression: res_vec "Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n", - "Class {\nname: \"B\",\nancestors: [\"B[typevar18]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar18\"]\n}\n", + "Class {\nname: \"B\",\nancestors: [\"B[typevar17]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar17\"]\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap index a4b35ac4..4b37f03c 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap @@ -5,8 +5,8 @@ expression: res_vec [ "Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", - "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [31]\n}\n", - "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [36]\n}\n", + "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [30]\n}\n", + "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [35]\n}\n", "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[int32, list[float]]], none]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap index 86be6ed0..d194a053 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap @@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs expression: res_vec --- [ - "Class {\nname: \"A\",\nancestors: [\"A[typevar17, typevar18]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[bool, float], b:B], none]\"), (\"fun\", \"fn[[a:A[bool, float]], A[bool, int32]]\")],\ntype_vars: [\"typevar17\", \"typevar18\"]\n}\n", + "Class {\nname: \"A\",\nancestors: [\"A[typevar16, typevar17]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[bool, float], b:B], none]\"), (\"fun\", \"fn[[a:A[bool, float]], A[bool, int32]]\")],\ntype_vars: [\"typevar16\", \"typevar17\"]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[bool, float], b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[bool, float]], A[bool, int32]]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[bool, float]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[int32, list[B]]], tuple[A[bool, virtual[A[B, int32]]], B]]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap index c2fe4836..5508eef3 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap @@ -6,12 +6,12 @@ expression: res_vec "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [37]\n}\n", + "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [36]\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [45]\n}\n", + "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [44]\n}\n", ] diff --git a/nac3core/src/toplevel/type_annotation.rs b/nac3core/src/toplevel/type_annotation.rs index 3ae2e98d..2086cfe5 100644 --- a/nac3core/src/toplevel/type_annotation.rs +++ b/nac3core/src/toplevel/type_annotation.rs @@ -492,24 +492,11 @@ pub fn get_type_from_type_annotation_kinds( (*name, (subst_ty, *mutability)) })); let need_subst = !subst.is_empty(); - let ty = if obj_id == &PRIMITIVE_DEF_IDS.ndarray { - assert_eq!(subst.len(), 2); - let tv_tys = subst.iter() - .sorted_by_key(|(k, _)| *k) - .map(|(_, v)| v) - .collect_vec(); - - unifier.add_ty(TypeEnum::TNDArray { - ty: *tv_tys[0], - ndims: *tv_tys[1], - }) - } else { - unifier.add_ty(TypeEnum::TObj { - obj_id: *obj_id, - fields: tobj_fields, - params: subst, - }) - }; + let ty = unifier.add_ty(TypeEnum::TObj { + obj_id: *obj_id, + fields: tobj_fields, + params: subst, + }); if need_subst { if let Some(wl) = subst_list.as_mut() { wl.push(ty); diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index e1c638dd..e3790566 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -5,7 +5,14 @@ use std::{cell::RefCell, sync::Arc}; use super::typedef::{Call, FunSignature, FuncArg, RecordField, Type, TypeEnum, Unifier}; use super::{magic_methods::*, typedef::CallId}; -use crate::{symbol_resolver::{SymbolResolver, SymbolValue}, toplevel::TopLevelContext}; +use crate::{ + symbol_resolver::{SymbolResolver, SymbolValue}, + toplevel::{ + helper::PRIMITIVE_DEF_IDS, + numpy::{make_ndarray_ty, unpack_ndarray_tvars}, + TopLevelContext, + }, +}; use itertools::izip; use nac3parser::ast::{ self, @@ -47,6 +54,7 @@ pub struct PrimitiveStore { pub str: Type, pub exception: Type, pub option: Type, + pub ndarray: Type, pub size_t: u32, } @@ -226,7 +234,7 @@ impl<'a> Fold<()> for Inferencer<'a> { } else { let list_like_ty = match &*self.unifier.get_ty(iter.custom.unwrap()) { TypeEnum::TList { .. } => self.unifier.add_ty(TypeEnum::TList { ty: target.custom.unwrap() }), - TypeEnum::TNDArray { .. } => todo!(), + TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => todo!(), _ => unreachable!(), }; self.unify(list_like_ty, iter.custom.unwrap(), &iter.location)?; @@ -916,10 +924,12 @@ impl<'a> Inferencer<'a> { vec![SymbolValue::U64(ndims)], None, ); - let ret = self.unifier.add_ty(TypeEnum::TNDArray { - ty: self.primitives.float, - ndims - }); + let ret = make_ndarray_ty( + self.unifier, + self.primitives, + Some(self.primitives.float), + Some(ndims), + ); let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature { args: vec![ FuncArg { @@ -966,11 +976,12 @@ impl<'a> Inferencer<'a> { vec![SymbolValue::U64(ndims)], None, ); - - let ret = self.unifier.add_ty(TypeEnum::TNDArray { - ty, - ndims - }); + let ret = make_ndarray_ty( + self.unifier, + self.primitives, + Some(ty), + Some(ndims), + ); let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature { args: vec![ FuncArg { @@ -1252,11 +1263,16 @@ impl<'a> Inferencer<'a> { TypeEnum::TVar { is_const_generic: false, .. } )); - let constrained_ty = self.unifier.add_ty(TypeEnum::TNDArray { ty: dummy_tvar, ndims }); + let constrained_ty = make_ndarray_ty( + self.unifier, + self.primitives, + Some(dummy_tvar), + Some(ndims), + ); self.constrain(value.custom.unwrap(), constrained_ty, &value.location)?; let TypeEnum::TLiteral { values, .. } = &*self.unifier.get_ty_immutable(ndims) else { - panic!("Expected TLiteral for TNDArray.ndims, got {}", self.unifier.stringify(ndims)) + panic!("Expected TLiteral for ndarray.ndims, got {}", self.unifier.stringify(ndims)) }; let ndims = values.iter() @@ -1264,10 +1280,10 @@ impl<'a> Inferencer<'a> { SymbolValue::U64(v) => Ok(v), SymbolValue::U32(v) => Ok(v as u64), SymbolValue::I32(v) => u64::try_from(v).map_err(|_| HashSet::from([ - format!("Expected non-negative literal for TNDArray.ndims, got {v}"), + format!("Expected non-negative literal for ndarray.ndims, got {v}"), ])), SymbolValue::I64(v) => u64::try_from(v).map_err(|_| HashSet::from([ - format!("Expected non-negative literal for TNDArray.ndims, got {v}"), + format!("Expected non-negative literal for ndarray.ndims, got {v}"), ])), _ => unreachable!(), }) @@ -1292,10 +1308,12 @@ impl<'a> Inferencer<'a> { ndims.into_iter().map(|v| SymbolValue::U64(v - 1)).collect(), None, ); - let subscripted_ty = self.unifier.add_ty(TypeEnum::TNDArray { - ty: dummy_tvar, - ndims: ndims_min_one_ty, - }); + let subscripted_ty = make_ndarray_ty( + self.unifier, + self.primitives, + Some(dummy_tvar), + Some(ndims_min_one_ty), + ); Ok(subscripted_ty) } @@ -1315,27 +1333,36 @@ impl<'a> Inferencer<'a> { } let list_like_ty = match &*self.unifier.get_ty(value.custom.unwrap()) { TypeEnum::TList { .. } => self.unifier.add_ty(TypeEnum::TList { ty }), - TypeEnum::TNDArray { ndims, .. } => self.unifier.add_ty(TypeEnum::TNDArray { ty, ndims: *ndims }), + TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { + let (_, ndims) = unpack_ndarray_tvars(self.unifier, value.custom.unwrap()); + + make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims)) + } + _ => unreachable!() }; self.constrain(value.custom.unwrap(), list_like_ty, &value.location)?; Ok(list_like_ty) } ExprKind::Constant { value: ast::Constant::Int(val), .. } => { - if let TypeEnum::TNDArray { ndims, .. } = &*self.unifier.get_ty(value.custom.unwrap()) { - self.infer_subscript_ndarray(value, ty, *ndims) - } else { - // the index is a constant, so value can be a sequence. - let ind: Option = (*val).try_into().ok(); - let ind = ind.ok_or_else(|| HashSet::from(["Index must be int32".to_string()]))?; - let map = once(( - ind.into(), - RecordField::new(ty, ctx == &ExprContext::Store, Some(value.location)), - )) - .collect(); - let seq = self.unifier.add_record(map); - self.constrain(value.custom.unwrap(), seq, &value.location)?; - Ok(ty) + match &*self.unifier.get_ty(value.custom.unwrap()) { + TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { + let (_, ndims) = unpack_ndarray_tvars(self.unifier, value.custom.unwrap()); + self.infer_subscript_ndarray(value, ty, ndims) + } + _ => { + // the index is a constant, so value can be a sequence. + let ind: Option = (*val).try_into().ok(); + let ind = ind.ok_or_else(|| HashSet::from(["Index must be int32".to_string()]))?; + let map = once(( + ind.into(), + RecordField::new(ty, ctx == &ExprContext::Store, Some(value.location)), + )) + .collect(); + let seq = self.unifier.add_record(map); + self.constrain(value.custom.unwrap(), seq, &value.location)?; + Ok(ty) + } } } _ => { @@ -1351,9 +1378,11 @@ impl<'a> Inferencer<'a> { self.constrain(value.custom.unwrap(), list, &value.location)?; Ok(ty) } - TypeEnum::TNDArray { ndims, .. } => { + TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { + let (_, ndims) = unpack_ndarray_tvars(self.unifier, value.custom.unwrap()); + self.constrain(slice.custom.unwrap(), self.primitives.usize(), &slice.location)?; - self.infer_subscript_ndarray(value, ty, *ndims) + self.infer_subscript_ndarray(value, ty, ndims) } _ => unreachable!(), } diff --git a/nac3core/src/typecheck/type_inferencer/test.rs b/nac3core/src/typecheck/type_inferencer/test.rs index 2abffe3a..44747b7d 100644 --- a/nac3core/src/typecheck/type_inferencer/test.rs +++ b/nac3core/src/typecheck/type_inferencer/test.rs @@ -135,6 +135,11 @@ impl TestEnvironment { fields: HashMap::new(), params: HashMap::new(), }); + let ndarray = unifier.add_ty(TypeEnum::TObj { + obj_id: PRIMITIVE_DEF_IDS.ndarray, + fields: HashMap::new(), + params: HashMap::new(), + }); let primitives = PrimitiveStore { int32, int64, @@ -147,6 +152,7 @@ impl TestEnvironment { uint32, uint64, option, + ndarray, size_t: 64, }; unifier.put_primitive_store(&primitives); @@ -262,6 +268,11 @@ impl TestEnvironment { fields: HashMap::new(), params: HashMap::new(), }); + let ndarray = unifier.add_ty(TypeEnum::TObj { + obj_id: PRIMITIVE_DEF_IDS.ndarray, + fields: HashMap::new(), + params: HashMap::new(), + }); identifier_mapping.insert("None".into(), none); for (i, name) in ["int32", "int64", "float", "bool", "none", "range", "str", "Exception"] .iter() @@ -296,6 +307,7 @@ impl TestEnvironment { uint32, uint64, option, + ndarray, size_t: 64, }; diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index 7a1c2b9d..5cb99b90 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -159,11 +159,6 @@ pub enum TypeEnum { ty: Type, }, - TNDArray { - ty: Type, - ndims: Type, - }, - /// An object type. TObj { /// The [DefintionId] of this object type. @@ -198,34 +193,12 @@ impl TypeEnum { TypeEnum::TLiteral { .. } => "TConstant", TypeEnum::TTuple { .. } => "TTuple", TypeEnum::TList { .. } => "TList", - TypeEnum::TNDArray { .. } => "TNDArray", TypeEnum::TObj { .. } => "TObj", TypeEnum::TVirtual { .. } => "TVirtual", TypeEnum::TCall { .. } => "TCall", TypeEnum::TFunc { .. } => "TFunc", } } - - /// Returns a [`TypeEnum`] representing a generic `ndarray` type. - /// - /// * `dtype` - The datatype of the `ndarray`, or `None` if the datatype is generic. - /// * `ndims` - The number of dimensions of the `ndarray`, or `None` if the number of dimensions is generic. - #[must_use] - pub fn ndarray( - unifier: &mut Unifier, - dtype: Option, - ndims: Option, - primitives: &PrimitiveStore - ) -> TypeEnum { - let dtype = dtype.unwrap_or_else(|| unifier.get_fresh_var(Some("T".into()), None).0); - let ndims = ndims - .unwrap_or_else(|| unifier.get_fresh_const_generic_var(primitives.usize(), Some("N".into()), None).0); - - TypeEnum::TNDArray { - ty: dtype, - ndims, - } - } } pub type SharedUnifier = Arc, u32, Vec)>>; @@ -445,9 +418,6 @@ impl Unifier { TypeEnum::TList { ty } => self .get_instantiations(*ty) .map(|ty| ty.iter().map(|&ty| self.add_ty(TypeEnum::TList { ty })).collect_vec()), - TypeEnum::TNDArray { ty, ndims } => self - .get_instantiations(*ty) - .map(|ty| ty.iter().map(|&ty| self.add_ty(TypeEnum::TNDArray { ty, ndims: *ndims })).collect_vec()), TypeEnum::TVirtual { ty } => self.get_instantiations(*ty).map(|ty| { ty.iter().map(|&ty| self.add_ty(TypeEnum::TVirtual { ty })).collect_vec() }), @@ -505,8 +475,7 @@ impl Unifier { TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)), TCall { .. } => false, TList { ty } - | TVirtual { ty } - | TNDArray { ty, .. } => self.is_concrete(*ty, allowed_typevars), + | TVirtual { ty } => self.is_concrete(*ty, allowed_typevars), TTuple { ty } => ty.iter().all(|ty| self.is_concrete(*ty, allowed_typevars)), TObj { params: vars, .. } => { @@ -752,7 +721,7 @@ impl Unifier { self.unify_impl(x, b, false)?; self.set_a_to_b(a, x); } - (TVar { fields: Some(fields), range, is_const_generic: false, .. }, TList { ty } | TNDArray { ty, .. }) => { + (TVar { fields: Some(fields), range, is_const_generic: false, .. }, TList { ty }) => { for (k, v) in fields { match *k { RecordKey::Int(_) => { @@ -792,7 +761,6 @@ impl Unifier { // If the types don't match, try to implicitly promote integers if !self.unioned(ty, value_ty) { - let num_val = match *value { SymbolValue::I32(v) => v as i128, SymbolValue::I64(v) => v as i128, @@ -864,15 +832,6 @@ impl Unifier { } self.set_a_to_b(a, b); } - (TNDArray { ty: ty1, ndims: ndims1 }, TNDArray { ty: ty2, ndims: ndims2 }) => { - if self.unify_impl(*ty1, *ty2, false).is_err() { - return Self::incompatible_types(a, b) - } - if self.unify_impl(*ndims1, *ndims2, false).is_err() { - return Self::incompatible_types(a, b) - } - self.set_a_to_b(a, b); - } (TVar { fields: Some(map), range, .. }, TObj { fields, .. }) => { for (k, field) in map { match *k { @@ -1120,13 +1079,6 @@ impl Unifier { TypeEnum::TList { ty } => { format!("list[{}]", self.internal_stringify(*ty, obj_to_name, var_to_name, notes)) } - TypeEnum::TNDArray { ty, ndims } => { - format!( - "ndarray[{}, {}]", - self.internal_stringify(*ty, obj_to_name, var_to_name, notes), - self.internal_stringify(*ndims, obj_to_name, var_to_name, notes), - ) - } TypeEnum::TVirtual { ty } => { format!( "virtual[{}]", @@ -1264,19 +1216,6 @@ impl Unifier { TypeEnum::TList { ty } => { self.subst_impl(*ty, mapping, cache).map(|t| self.add_ty(TypeEnum::TList { ty: t })) } - TypeEnum::TNDArray { ty, ndims } => { - let new_ty = self.subst_impl(*ty, mapping, cache); - let new_ndims = self.subst_impl(*ndims, mapping, cache); - - if new_ty.is_some() || new_ndims.is_some() { - Some(self.add_ty(TypeEnum::TNDArray { - ty: new_ty.unwrap_or(*ty), - ndims: new_ndims.unwrap_or(*ndims) - })) - } else { - None - } - } TypeEnum::TVirtual { ty } => self .subst_impl(*ty, mapping, cache) .map(|t| self.add_ty(TypeEnum::TVirtual { ty: t })), @@ -1447,19 +1386,6 @@ impl Unifier { (TList { ty: ty1 }, TList { ty: ty2 }) => { Ok(self.get_intersection(*ty1, *ty2)?.map(|ty| self.add_ty(TList { ty }))) } - (TNDArray { ty: ty1, ndims: ndims1 }, TNDArray { ty: ty2, ndims: ndims2 }) => { - let ty = self.get_intersection(*ty1, *ty2)?; - let ndims = self.get_intersection(*ndims1, *ndims2)?; - - Ok(if ty.is_some() || ndims.is_some() { - Some(self.add_ty(TNDArray { - ty: ty.unwrap_or(*ty1), - ndims: ndims.unwrap_or(*ndims1), - })) - } else { - None - }) - } (TVirtual { ty: ty1 }, TVirtual { ty: ty2 }) => { Ok(self.get_intersection(*ty1, *ty2)?.map(|ty| self.add_ty(TVirtual { ty }))) } diff --git a/nac3core/src/typecheck/typedef/test.rs b/nac3core/src/typecheck/typedef/test.rs index ac671591..4ebf2e1a 100644 --- a/nac3core/src/typecheck/typedef/test.rs +++ b/nac3core/src/typecheck/typedef/test.rs @@ -40,7 +40,7 @@ impl Unifier { TypeEnum::TObj { obj_id: id1, params: params1, .. }, TypeEnum::TObj { obj_id: id2, params: params2, .. }, ) => id1 == id2 && self.map_eq(params1, params2), - // TNDArray, TLiteral, TCall and TFunc are not yet implemented + // TLiteral, TCall and TFunc are not yet implemented _ => false, } }