From da4dec08a5d99936b9febca046e36aca57f68851 Mon Sep 17 00:00:00 2001 From: David Mak Date: Tue, 25 Jun 2024 18:39:34 +0800 Subject: [PATCH] core/typedef: Add GenericObjectType --- nac3artiq/src/symbol_resolver.rs | 63 ++++++++++------- nac3core/src/toplevel/builtins.rs | 17 ++--- nac3core/src/typecheck/typedef/mod.rs | 98 ++++++++++++++++++++++++++- 3 files changed, 141 insertions(+), 37 deletions(-) diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 62597b8..5289e6d 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -4,6 +4,7 @@ use inkwell::{ AddressSpace, }; use itertools::Itertools; +use nac3core::typecheck::typedef::{GenericObjectType, GenericTypeAdapter}; use nac3core::{ codegen::{ classes::{NDArrayType, ProxyType}, @@ -17,7 +18,7 @@ use nac3core::{ }, typecheck::{ type_inferencer::PrimitiveStore, - typedef::{into_var_map, iter_type_vars, Type, TypeEnum, TypeVar, Unifier, VarMap}, + typedef::{Type, TypeEnum, TypeVar, Unifier, VarMap}, }, }; use nac3parser::ast::{self, StrRef}; @@ -767,21 +768,23 @@ impl InnerResolver { // if is `none` let zelf_id: u64 = self.helper.id_fn.call1(py, (obj,))?.extract(py)?; if zelf_id == self.primitive_ids.none { - let ty_enum = unifier.get_ty_immutable(primitives.option); - let TypeEnum::TObj { params, .. } = ty_enum.as_ref() else { - unreachable!("must be tobj") - }; + let extracted_ty = GenericTypeAdapter::create(extracted_ty, unifier); + let var_map = extracted_ty.iter_var_map(unifier, |tvar_iter, unifier| { + tvar_iter + .map(|tvar| { + let TypeEnum::TVar { id, range, name, loc, .. } = + &*unifier.get_ty(tvar.ty) + else { + unreachable!() + }; - let var_map = into_var_map(iter_type_vars(params).map(|tvar| { - let TypeEnum::TVar { id, range, name, loc, .. } = &*unifier.get_ty(tvar.ty) - else { - unreachable!() - }; - - assert_eq!(*id, tvar.id); - let ty = unifier.get_fresh_var_with_range(range, *name, *loc).ty; - TypeVar { id: *id, ty } - })); + assert_eq!(*id, tvar.id); + let ty = unifier.get_fresh_var_with_range(range, *name, *loc).ty; + TypeVar { id: *id, ty } + }) + .map(TypeVar::into) + .collect::() + }); return Ok(Ok(unifier.subst(primitives.option, &var_map).unwrap())); } @@ -797,19 +800,26 @@ impl InnerResolver { let res = unifier.subst(extracted_ty, &new_var_map).unwrap_or(extracted_ty); Ok(Ok(res)) } - (TypeEnum::TObj { params, fields, .. }, false) => { + (TypeEnum::TObj { fields, .. }, false) => { self.pyid_to_type.write().insert(py_obj_id, extracted_ty); - let var_map = into_var_map(iter_type_vars(params).map(|tvar| { - let TypeEnum::TVar { id, range, name, loc, .. } = &*unifier.get_ty(tvar.ty) - else { - unreachable!() - }; + let extracted_ty = GenericTypeAdapter::create(extracted_ty, unifier); + let var_map = extracted_ty.iter_var_map(unifier, |tvar_iter, unifier| { + tvar_iter + .map(|tvar| { + let TypeEnum::TVar { id, range, name, loc, .. } = + &*unifier.get_ty(tvar.ty) + else { + unreachable!() + }; - assert_eq!(*id, tvar.id); - let ty = unifier.get_fresh_var_with_range(range, *name, *loc).ty; - TypeVar { id: *id, ty } - })); - let mut instantiate_obj = || { + assert_eq!(*id, tvar.id); + let ty = unifier.get_fresh_var_with_range(range, *name, *loc).ty; + TypeVar { id: *id, ty } + }) + .map(TypeVar::into) + .collect::() + }); + let instantiate_obj = || { // loop through non-function fields of the class to get the instantiated value for field in fields { let name: String = (*field.0).into(); @@ -844,6 +854,7 @@ impl InnerResolver { return Ok(Err("object is not of concrete type".into())); } } + let extracted_ty = extracted_ty.into(); let extracted_ty = unifier.subst(extracted_ty, &var_map).unwrap_or(extracted_ty); Ok(Ok(extracted_ty)) diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 2524b9a..121038d 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -11,6 +11,7 @@ use inkwell::{ use itertools::Either; use strum::IntoEnumIterator; +use crate::typecheck::typedef::{GenericObjectType, GenericTypeAdapter}; use crate::{ codegen::{ builtin_fns, @@ -25,7 +26,7 @@ use crate::{ }, symbol_resolver::SymbolValue, toplevel::{helper::PrimDef, numpy::make_ndarray_ty}, - typecheck::typedef::{into_var_map, iter_type_vars, TypeVar, VarMap}, + typecheck::typedef::{into_var_map, TypeVar, VarMap}, }; use super::*; @@ -345,23 +346,23 @@ impl<'a> BuiltinBuilder<'a> { // Option-related let (is_some_ty, unwrap_ty, option_tvar) = - if let TypeEnum::TObj { fields, params, .. } = unifier.get_ty(option).as_ref() { + if let TypeEnum::TObj { fields, .. } = &*unifier.get_ty(option) { + let option = GenericTypeAdapter::create(option, unifier); ( *fields.get(&PrimDef::OptionIsSome.simple_name().into()).unwrap(), *fields.get(&PrimDef::OptionUnwrap.simple_name().into()).unwrap(), - iter_type_vars(params).next().unwrap(), + option.get_var_at(unifier, 0).unwrap(), ) } else { unreachable!() }; - let TypeEnum::TObj { fields: ndarray_fields, params: ndarray_params, .. } = - &*unifier.get_ty(ndarray) - else { + let TypeEnum::TObj { fields: ndarray_fields, .. } = &*unifier.get_ty(ndarray) else { unreachable!() }; - let ndarray_dtype_tvar = iter_type_vars(ndarray_params).next().unwrap(); - let ndarray_ndims_tvar = iter_type_vars(ndarray_params).nth(1).unwrap(); + let ndarray = GenericTypeAdapter::create(ndarray, unifier); + let ndarray_dtype_tvar = ndarray.get_var_at(unifier, 0).unwrap(); + let ndarray_ndims_tvar = ndarray.get_var_at(unifier, 1).unwrap(); let ndarray_copy_ty = *ndarray_fields.get(&PrimDef::NDArrayCopy.simple_name().into()).unwrap(); let ndarray_fill_ty = diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index 594ecce..d268455 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -89,6 +89,24 @@ pub struct TypeVar { pub ty: Type, } +impl From<(TypeVarId, Type)> for TypeVar { + fn from((id, ty): (TypeVarId, Type)) -> Self { + TypeVar { id, ty } + } +} + +impl From<(&TypeVarId, &Type)> for TypeVar { + fn from((id, ty): (&TypeVarId, &Type)) -> Self { + TypeVar { id: *id, ty: *ty } + } +} + +impl From for (TypeVarId, Type) { + fn from(value: TypeVar) -> Self { + (value.id, value.ty) + } +} + /// The mapping between [`TypeVarId`] and [unifier type][`Type`]. pub type VarMap = IndexMapping; @@ -102,9 +120,83 @@ where vars.into_iter().map(|var| (var.id, var.ty)).collect() } -/// Get an iterator of [`TypeVar`]s from a [`VarMap`] -pub fn iter_type_vars(var_map: &VarMap) -> impl Iterator + '_ { - var_map.iter().map(|(&id, &ty)| TypeVar { id, ty }) +/// A trait representing a possibly generic object type. +pub trait GenericObjectType +where + Self: Sized, +{ + fn try_create(ty: Type, unifier: &mut Unifier) -> Option; + + /// Creates an instance from a [`Type`]. + #[must_use] + fn create(ty: Type, unifier: &mut Unifier) -> Self { + Self::try_create(ty, unifier).unwrap() + } + + /// Returns the [`Type`] underlying this instance. + #[must_use] + fn get_type(&self) -> Type; + + /// See [`Type::obj_id`]. + #[must_use] + fn obj_id(&self, unifier: &Unifier) -> DefinitionId { + self.get_type().obj_id(unifier).unwrap() + } + + /// Returns a copy of the [`VarMap`] of this object type. + #[must_use] + fn var_map(&self, unifier: &mut Unifier) -> VarMap { + let TypeEnum::TObj { params, .. } = &*unifier.get_ty(self.get_type()) else { + unreachable!() + }; + + params.clone() + } + + /// Creates an iterator over the [`VarMap`] of this object type, applying `iter_fn` on the + /// created [`Iterator`]. + #[must_use] + fn iter_var_map, &mut Unifier) -> R>( + &self, + unifier: &mut Unifier, + iter_fn: IterFn, + ) -> R { + let TypeEnum::TObj { params, .. } = &*unifier.get_ty(self.get_type()) else { + unreachable!() + }; + + let res = iter_fn(&mut params.iter().map(TypeVar::from), unifier); + res + } + + /// Returns the [`TypeVar`] instance at the given index. + #[must_use] + fn get_var_at(&self, unifier: &mut Unifier, i: usize) -> Option { + self.iter_var_map(unifier, |iter, _| iter.nth(i)) + } +} + +impl From for Type { + fn from(value: T) -> Self { + value.get_type() + } +} + +/// An adapter that converts [`Type`] into +pub struct GenericTypeAdapter(Type); + +impl GenericObjectType for GenericTypeAdapter { + fn try_create(ty: Type, unifier: &mut Unifier) -> Option { + if let TypeEnum::TObj { .. } = &*unifier.get_ty_immutable(ty) { + Some(GenericTypeAdapter(ty)) + } else { + None + } + } + + fn get_type(&self) -> Type { + self.0 + } } #[derive(Clone)]