forked from M-Labs/nac3
Add Auto type support for generic class parameters
This commit is contained in:
@@ -7,33 +7,45 @@ from numpy import int32, int64
|
||||
|
||||
|
||||
@compile
|
||||
class ProtoRev8:
|
||||
"""Simulates a hardware revision with limited features."""
|
||||
core: KernelInvariant[Core]
|
||||
|
||||
def __init__(self, core: Core):
|
||||
self.core = core
|
||||
class CPLDVersion:
|
||||
"""Base class for version-specific CPLD implementations"""
|
||||
def __init__(self, cpld):
|
||||
self.cpld = cpld
|
||||
|
||||
@kernel
|
||||
def cfg_write(self, data: int32):
|
||||
def cfg_write(self, cfg: int32):
|
||||
pass
|
||||
|
||||
@kernel
|
||||
def cfg_att_en(self, channel: int32, on: bool):
|
||||
pass
|
||||
|
||||
|
||||
@compile
|
||||
class ProtoRev8(CPLDVersion):
|
||||
"""Simulates ARTIQ's ProtoRev8 - with self-referential CPLD type."""
|
||||
|
||||
# Self-referential type ProtoRev8 references CPLD[ProtoRev8]
|
||||
cpld: KernelInvariant[CPLD[ProtoRev8]]
|
||||
|
||||
@kernel
|
||||
def cfg_write(self, cfg: int32):
|
||||
self.cpld.cfg_reg = cfg
|
||||
|
||||
@kernel
|
||||
def cfg_att_en(self, channel: int32, on: bool):
|
||||
raise ValueError("cfg_att_en not supported on ProtoRev8")
|
||||
|
||||
|
||||
@compile
|
||||
class ProtoRev9:
|
||||
"""Simulates a hardware revision with full features."""
|
||||
core: KernelInvariant[Core]
|
||||
|
||||
def __init__(self, core: Core):
|
||||
self.core = core
|
||||
class ProtoRev9(CPLDVersion):
|
||||
"""Simulates ARTIQ's ProtoRev9 - with self-referential CPLD type."""
|
||||
# Self-referential type ProtoRev9 references CPLD[ProtoRev9]
|
||||
cpld: KernelInvariant[CPLD[ProtoRev9]]
|
||||
|
||||
@kernel
|
||||
def cfg_write(self, data: int32):
|
||||
pass
|
||||
def cfg_write(self, cfg: int32):
|
||||
self.cpld.cfg_reg = cfg
|
||||
|
||||
@kernel
|
||||
def cfg_att_en(self, channel: int32, on: bool):
|
||||
@@ -50,14 +62,18 @@ class CPLD(Generic[V]):
|
||||
version: KernelInvariant[V]
|
||||
cfg_reg: Kernel[int32]
|
||||
|
||||
def __init__(self, core: Core, version: V):
|
||||
def __init__(self, core: Core, version_cls, proto_rev: int32 = int32(0x09)):
|
||||
self.core = core
|
||||
self.version = version
|
||||
self.cfg_reg = int32(0)
|
||||
# version is created with self-reference
|
||||
if proto_rev == int32(0x08):
|
||||
self.version = ProtoRev8(self)
|
||||
else:
|
||||
self.version = ProtoRev9(self)
|
||||
|
||||
@kernel
|
||||
def cfg_write(self, cfg: int32):
|
||||
self.cfg_reg = cfg
|
||||
self.version.cfg_write(cfg)
|
||||
|
||||
@kernel
|
||||
def set_att_en(self, channel: int32, on: bool):
|
||||
@@ -66,8 +82,9 @@ class CPLD(Generic[V]):
|
||||
|
||||
@compile
|
||||
class RegIOUpdate:
|
||||
"""Simulates ARTIQ's RegIOUpdate - constructed with required args."""
|
||||
cpld: KernelInvariant[CPLD[ProtoRev9]]
|
||||
"""Simulates ARTIQ's RegIOUpdate - KEY: uses CPLD[Auto] not CPLD[ProtoRev9]."""
|
||||
# Auto in a non-generic class field
|
||||
cpld: KernelInvariant[CPLD[Auto]]
|
||||
chip_select: KernelInvariant[int32]
|
||||
|
||||
def __init__(self, cpld, chip_select):
|
||||
@@ -85,14 +102,10 @@ IoUpdateT = TypeVar("IoUpdateT", RegIOUpdate, TTLOut)
|
||||
|
||||
@compile
|
||||
class AD9910(Generic[IoUpdateT]):
|
||||
"""Simulates ARTIQ's AD9910 - generic over IO update type.
|
||||
|
||||
This is the key test case: AD9910 has required constructor args
|
||||
(core, cpld, chip_select) and a field `cpld: CPLD[Auto]` that
|
||||
nac3 should be able to infer from the runtime value.
|
||||
"""
|
||||
"""Simulates ARTIQ's AD9910 - generic over IO update type."""
|
||||
core: KernelInvariant[Core]
|
||||
cpld: KernelInvariant[CPLD[Auto]] # <-- This should work: infer CPLD variant from runtime
|
||||
# Auto in generic class field
|
||||
cpld: KernelInvariant[CPLD[Auto]]
|
||||
chip_select: KernelInvariant[int32]
|
||||
io_update: KernelInvariant[IoUpdateT]
|
||||
|
||||
@@ -111,6 +124,46 @@ class AD9910(Generic[IoUpdateT]):
|
||||
return self.chip_select
|
||||
|
||||
|
||||
@compile
|
||||
class SUServo:
|
||||
"""Simulates ARTIQ's SUServo - KEY: uses list[AD9910[Auto]] and list[CPLD[Auto]]."""
|
||||
core: KernelInvariant[Core]
|
||||
# list of generic types with Auto
|
||||
ddses: KernelInvariant[list[AD9910[Auto]]]
|
||||
cplds: KernelInvariant[list[CPLD[Auto]]]
|
||||
|
||||
def __init__(self, core: Core, ddses, cplds):
|
||||
self.core = core
|
||||
self.ddses = ddses
|
||||
self.cplds = cplds
|
||||
|
||||
@kernel
|
||||
def init(self):
|
||||
for i in range(len(self.cplds)):
|
||||
cpld = self.cplds[i]
|
||||
dds = self.ddses[i]
|
||||
cpld.cfg_write(int32(0))
|
||||
dds.init()
|
||||
|
||||
|
||||
@compile
|
||||
class Channel:
|
||||
"""Simulates ARTIQ's SUServo Channel - KEY: uses AD9910[Auto]."""
|
||||
core: KernelInvariant[Core]
|
||||
servo: KernelInvariant[SUServo]
|
||||
# AD9910[Auto] in non-generic class
|
||||
dds: KernelInvariant[AD9910[Auto]]
|
||||
|
||||
def __init__(self, servo: SUServo, channel: int32):
|
||||
self.core = servo.core
|
||||
self.servo = servo
|
||||
self.dds = servo.ddses[channel]
|
||||
|
||||
@kernel
|
||||
def set_dds(self) -> int32:
|
||||
return self.dds.get_chip_select()
|
||||
|
||||
|
||||
@compile
|
||||
class Inner:
|
||||
core: KernelInvariant[Core]
|
||||
@@ -144,6 +197,9 @@ class AutoDemo:
|
||||
# This mirrors ARTIQ's AD9910.cpld: CPLD[Auto] pattern
|
||||
dds: KernelInvariant[Auto] # inferred as AD9910[RegIOUpdate]
|
||||
|
||||
servo: KernelInvariant[SUServo]
|
||||
channel: KernelInvariant[Channel]
|
||||
|
||||
def __init__(self):
|
||||
self.core = Core()
|
||||
self.x_auto = int32(42)
|
||||
@@ -154,8 +210,14 @@ class AutoDemo:
|
||||
self.obj_auto = Inner(self.core, int32(99))
|
||||
self.list_auto = [int32(1), int32(2), int32(3)]
|
||||
|
||||
cpld = CPLD(self.core, ProtoRev9(self.core))
|
||||
self.dds = AD9910(self.core, cpld, int32(4))
|
||||
cpld1 = CPLD(self.core, ProtoRev9, int32(0x09))
|
||||
cpld2 = CPLD(self.core, ProtoRev9, int32(0x09))
|
||||
dds1 = AD9910(self.core, cpld1, int32(4))
|
||||
dds2 = AD9910(self.core, cpld2, int32(5))
|
||||
self.dds = dds1
|
||||
|
||||
self.servo = SUServo(self.core, [dds1, dds2], [cpld1, cpld2])
|
||||
self.channel = Channel(self.servo, int32(0))
|
||||
|
||||
@kernel
|
||||
def test_auto_int(self) -> int32:
|
||||
@@ -192,6 +254,16 @@ class AutoDemo:
|
||||
self.dds.init()
|
||||
return self.dds.get_chip_select()
|
||||
|
||||
@kernel
|
||||
def test_suservo_list_auto(self):
|
||||
"""Test list[AD9910[Auto]] and list[CPLD[Auto]] - mirrors ARTIQ SUServo."""
|
||||
self.servo.init()
|
||||
|
||||
@kernel
|
||||
def test_channel_auto(self) -> int32:
|
||||
"""Test AD9910[Auto] via Channel - mirrors ARTIQ SUServo Channel."""
|
||||
return self.channel.set_dds()
|
||||
|
||||
@kernel
|
||||
def run(self):
|
||||
x = self.test_auto_int()
|
||||
@@ -201,6 +273,8 @@ class AutoDemo:
|
||||
o = self.test_auto_object()
|
||||
l = self.test_auto_list()
|
||||
n = self.test_auto_nested_generic()
|
||||
self.test_suservo_list_auto()
|
||||
c = self.test_channel_auto()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
103
nac3artiq/demo/auto_type_stress.py
Normal file
103
nac3artiq/demo/auto_type_stress.py
Normal file
@@ -0,0 +1,103 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from min_artiq import *
|
||||
from min_artiq import Auto
|
||||
from typing import Generic, TypeVar
|
||||
from numpy import int32
|
||||
|
||||
# Self-referential ProtoRev types
|
||||
|
||||
class ProtoRev8:
|
||||
cpld: KernelInvariant[CPLD[ProtoRev8]]
|
||||
|
||||
def __init__(self, cpld):
|
||||
self.cpld = cpld
|
||||
|
||||
@kernel
|
||||
def cfg_write(self, cfg: int32):
|
||||
self.cpld.cfg_reg = cfg
|
||||
|
||||
@kernel
|
||||
def sta_read(self) -> int32:
|
||||
return self.cpld.cfg_reg
|
||||
|
||||
@compile
|
||||
class ProtoRev9:
|
||||
cpld: KernelInvariant[CPLD[ProtoRev9]]
|
||||
|
||||
def __init__(self, cpld):
|
||||
self.cpld = cpld
|
||||
|
||||
@kernel
|
||||
def cfg_write(self, cfg: int32):
|
||||
self.cpld.cfg_reg = cfg
|
||||
|
||||
@kernel
|
||||
def sta_read(self) -> int32:
|
||||
return self.cpld.cfg_reg
|
||||
|
||||
|
||||
V = TypeVar("V", ProtoRev8, ProtoRev9)
|
||||
|
||||
|
||||
@compile
|
||||
class CPLD(Generic[V]):
|
||||
core: KernelInvariant[Core]
|
||||
version: KernelInvariant[V]
|
||||
cfg_reg: Kernel[int32]
|
||||
|
||||
def __init__(self, core: Core, version_cls, proto_rev: int32 = int32(0x09)):
|
||||
self.core = core
|
||||
self.cfg_reg = int32(0)
|
||||
if proto_rev == int32(0x08):
|
||||
self.version = ProtoRev8(self)
|
||||
else:
|
||||
self.version = ProtoRev9(self)
|
||||
|
||||
@kernel
|
||||
def cfg_write(self, cfg: int32):
|
||||
self.version.cfg_write(cfg)
|
||||
|
||||
@kernel
|
||||
def sta_read(self) -> int32:
|
||||
return self.version.sta_read()
|
||||
|
||||
@kernel
|
||||
def init(self):
|
||||
pass
|
||||
|
||||
@compile
|
||||
class UninstantiatedDevice:
|
||||
core: KernelInvariant[Core]
|
||||
cpld: KernelInvariant[CPLD[Auto]]
|
||||
chip_select: KernelInvariant[int32]
|
||||
|
||||
def __init__(self, core: Core, cpld, chip_select: int32):
|
||||
self.core = core
|
||||
self.cpld = cpld
|
||||
self.chip_select = chip_select
|
||||
|
||||
@kernel
|
||||
def init(self):
|
||||
self.cpld.cfg_write(self.chip_select)
|
||||
|
||||
@kernel
|
||||
def get_status(self) -> int32:
|
||||
return self.cpld.sta_read()
|
||||
|
||||
@compile
|
||||
class StressDemo:
|
||||
core: KernelInvariant[Core]
|
||||
cpld: KernelInvariant[CPLD[Auto]]
|
||||
|
||||
def __init__(self):
|
||||
self.core = Core()
|
||||
self.cpld = CPLD(self.core, ProtoRev9, int32(0x09))
|
||||
|
||||
@kernel
|
||||
def run(self):
|
||||
self.cpld.init()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
StressDemo().run()
|
||||
@@ -1073,7 +1073,6 @@ impl InnerResolver {
|
||||
Ok(Ok(res))
|
||||
}
|
||||
(TypeEnum::TObj { params, fields, .. }, false) => {
|
||||
self.pyid_to_type.write().insert(py_obj_id, extracted_ty);
|
||||
let var_map = into_var_map(iter_type_vars(params).map(|tvar| {
|
||||
let TypeEnum::TVar { id, range, name, loc, .. } = &*unifier.get_ty(tvar.ty)
|
||||
else {
|
||||
@@ -1084,6 +1083,9 @@ impl InnerResolver {
|
||||
let ty = unifier.get_fresh_var_with_range(range, *name, *loc).ty;
|
||||
TypeVar { id: *id, ty }
|
||||
}));
|
||||
// Cache with fresh type variables to prevent circular references
|
||||
let cache_ty = unifier.subst(extracted_ty, &var_map).unwrap_or(extracted_ty);
|
||||
self.pyid_to_type.write().insert(py_obj_id, cache_ty);
|
||||
let mut instantiate_obj = || {
|
||||
// loop through non-function fields of the class to get the instantiated value
|
||||
for field in fields {
|
||||
@@ -2015,9 +2017,10 @@ impl SymbolResolver for Resolver {
|
||||
Some(Python::attach(|py| -> PyResult<Result<Type, String>> {
|
||||
let module = self.0.module.bind(py);
|
||||
let class_name_str = class_name.to_string();
|
||||
let simple_name = class_name_str.rsplit_once('.').map_or(class_name_str.as_str(), |(_, n)| n);
|
||||
|
||||
// Get the class object from the module
|
||||
let Ok(class_obj) = module.getattr(class_name_str.as_str()) else {
|
||||
let Ok(class_obj) = module.getattr(simple_name) else {
|
||||
return Ok(Err(format!(
|
||||
"Auto type error: cannot find class `{class_name}` in module"
|
||||
)));
|
||||
|
||||
@@ -828,18 +828,23 @@ impl TopLevelComposer {
|
||||
|
||||
for (class_def, class_ast) in def_list.iter().skip(self.builtin_num) {
|
||||
if class_ast.is_some() && matches!(&*class_def.read(), TopLevelDef::Class { .. }) {
|
||||
// Collect new entries from this class into a temporary map
|
||||
let mut new_entries: HashMap<Type, TypeAnnotation> = HashMap::new();
|
||||
if let Err(e) = Self::analyze_single_class_methods_fields(
|
||||
class_def,
|
||||
&class_ast.as_ref().unwrap().node,
|
||||
&temp_def_list,
|
||||
unifier,
|
||||
primitives_store,
|
||||
&mut type_var_to_concrete_def,
|
||||
&mut new_entries,
|
||||
&self.builtin_registry,
|
||||
) {
|
||||
errors.extend(e);
|
||||
}
|
||||
|
||||
// Merge new entries into the main map
|
||||
type_var_to_concrete_def.extend(new_entries.iter().map(|(k, v)| (*k, v.clone())));
|
||||
|
||||
// The errors need to be reported before copying methods from parent to child classes
|
||||
if !errors.is_empty() {
|
||||
return Err(errors);
|
||||
@@ -865,23 +870,22 @@ impl TopLevelComposer {
|
||||
}
|
||||
|
||||
let mut subst_list = Some(Vec::new());
|
||||
// unification of previously assigned typevar
|
||||
let mut unification_helper = |ty, def| -> Result<(), HashSet<String>> {
|
||||
let target_ty = get_type_from_type_annotation_kinds(
|
||||
for (ty, def) in &new_entries {
|
||||
match get_type_from_type_annotation_kinds(
|
||||
&temp_def_list,
|
||||
unifier,
|
||||
primitives_store,
|
||||
&def,
|
||||
def,
|
||||
&mut subst_list,
|
||||
)?;
|
||||
unifier
|
||||
.unify(ty, target_ty)
|
||||
.map_err(|e| HashSet::from([e.to_display(unifier).to_string()]))?;
|
||||
Ok(())
|
||||
};
|
||||
for (ty, def) in &type_var_to_concrete_def {
|
||||
if let Err(e) = unification_helper(*ty, def.clone()) {
|
||||
errors.extend(e);
|
||||
) {
|
||||
Ok(target_ty) => {
|
||||
if let Err(e) = unifier.unify(*ty, target_ty) {
|
||||
errors.insert(e.to_display(unifier).to_string());
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
errors.extend(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
for ty in subst_list.unwrap() {
|
||||
@@ -910,6 +914,51 @@ impl TopLevelComposer {
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
let mut subst_list = Some(Vec::new());
|
||||
for (ty, def) in &type_var_to_concrete_def {
|
||||
match get_type_from_type_annotation_kinds(
|
||||
&temp_def_list,
|
||||
unifier,
|
||||
primitives_store,
|
||||
def,
|
||||
&mut subst_list,
|
||||
) {
|
||||
Ok(target_ty) => {
|
||||
if let Err(e) = unifier.unify(*ty, target_ty) {
|
||||
errors.insert(e.to_display(unifier).to_string());
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
errors.extend(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
for ty in subst_list.unwrap() {
|
||||
let TypeEnum::TObj { obj_id, params, fields } = &*unifier.get_ty(ty) else {
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
let mut new_fields = HashMap::new();
|
||||
let mut need_subst = false;
|
||||
for (name, (ty, mutable)) in fields {
|
||||
let substituted = unifier.subst(*ty, params);
|
||||
need_subst |= substituted.is_some();
|
||||
new_fields.insert(*name, (substituted.unwrap_or(*ty), *mutable));
|
||||
}
|
||||
if need_subst {
|
||||
let new_ty = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: *obj_id,
|
||||
params: params.clone(),
|
||||
fields: new_fields,
|
||||
});
|
||||
if let Err(e) = unifier.unify(ty, new_ty) {
|
||||
errors.insert(e.to_display(unifier).to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (def, _) in def_list.iter().skip(self.builtin_num) {
|
||||
match &*def.read() {
|
||||
TopLevelDef::Class { resolver: Some(resolver), .. }
|
||||
@@ -1965,12 +2014,14 @@ impl TopLevelComposer {
|
||||
// None if is not class method
|
||||
let uninst_self_type = {
|
||||
if let Some(class_id) = method_class.get(&DefinitionId(id)) {
|
||||
let TopLevelDef::Class { type_vars, .. } =
|
||||
let TopLevelDef::Class { type_vars, fields, .. } =
|
||||
&*definition_ast_list.get(class_id.0).unwrap().0.read()
|
||||
else {
|
||||
unreachable!("must be class def")
|
||||
};
|
||||
|
||||
let field_types: Vec<Type> = fields.iter().map(|(_, ty, _)| *ty).collect();
|
||||
|
||||
let ty_ann = make_self_type_annotation(type_vars, *class_id);
|
||||
let self_ty = get_type_from_type_annotation_kinds(
|
||||
&def_list,
|
||||
@@ -1986,11 +2037,20 @@ impl TopLevelComposer {
|
||||
|
||||
(*id, *ty)
|
||||
}));
|
||||
Some((self_ty, type_vars.clone()))
|
||||
Some((self_ty, type_vars.clone(), field_types))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
// Collect TVars from class field types so that Auto TVars are treated as
|
||||
// bound by is_concrete in function_check.
|
||||
let mut field_tvars: Vec<Type> = Vec::new();
|
||||
if let Some((_, _, ref field_types)) = uninst_self_type {
|
||||
for &ty in field_types {
|
||||
unifier.collect_tvar_handles(ty, &mut field_tvars);
|
||||
}
|
||||
}
|
||||
// carefully handle those with bounds, without bounds and no typevars
|
||||
// if class methods, `vars` also contains all class typevars here
|
||||
let (type_var_subst_comb, no_range_vars) = {
|
||||
@@ -2037,7 +2097,7 @@ impl TopLevelComposer {
|
||||
.collect_vec()
|
||||
};
|
||||
let self_type = {
|
||||
uninst_self_type.clone().map(|(self_type, type_vars)| {
|
||||
uninst_self_type.clone().map(|(self_type, type_vars, _)| {
|
||||
let subst_for_self = {
|
||||
let class_ty_var_ids = type_vars
|
||||
.iter()
|
||||
@@ -2083,7 +2143,11 @@ impl TopLevelComposer {
|
||||
Some(inst_ret)
|
||||
},
|
||||
// NOTE: allowed type vars
|
||||
bound_variables: no_range_vars.clone(),
|
||||
bound_variables: {
|
||||
let mut bv = no_range_vars.clone();
|
||||
bv.extend(&field_tvars);
|
||||
bv
|
||||
},
|
||||
},
|
||||
unifier,
|
||||
variable_mapping: {
|
||||
|
||||
@@ -625,6 +625,24 @@ impl Unifier {
|
||||
}
|
||||
}
|
||||
|
||||
/// Used to collect Auto type variables from class field types so they can be
|
||||
/// added to `bound_variables` for the [`is_concrete`][Self::is_concrete] check.
|
||||
pub fn collect_tvar_handles(&mut self, ty: Type, result: &mut Vec<Type>) {
|
||||
let to_recurse: Vec<Type> = match self.get_ty(ty).as_ref() {
|
||||
TypeEnum::TVar { .. } => {
|
||||
result.push(ty);
|
||||
return;
|
||||
}
|
||||
TypeEnum::TObj { params, .. } => params.values().copied().collect(),
|
||||
TypeEnum::TTuple { ty: types, .. } => types.clone(),
|
||||
TypeEnum::TVirtual { ty: inner } => vec![*inner],
|
||||
_ => return,
|
||||
};
|
||||
for t in to_recurse {
|
||||
self.collect_tvar_handles(t, result);
|
||||
}
|
||||
}
|
||||
|
||||
fn restore_snapshot(&mut self) {
|
||||
if let Some(snapshot) = self.snapshot.take() {
|
||||
self.unification_table.restore_snapshot(snapshot);
|
||||
|
||||
Reference in New Issue
Block a user