artiq: Implement handling for const generic variables

This commit is contained in:
David Mak 2023-12-06 18:28:44 +08:00
parent 638d9f8a30
commit 649874868a
4 changed files with 52 additions and 13 deletions

View File

@ -2,7 +2,7 @@ from inspect import getfullargspec
from functools import wraps
from types import SimpleNamespace
from numpy import int32, int64
from typing import Generic, TypeVar
from typing import Any, Generic, TypeVar
from math import floor, ceil
import nac3artiq
@ -10,7 +10,7 @@ from embedding_map import EmbeddingMap
__all__ = [
"Kernel", "KernelInvariant", "virtual",
"Kernel", "KernelInvariant", "virtual", "ConstGeneric",
"Option", "Some", "none", "UnwrapNoneError",
"round64", "floor64", "ceil64",
"extern", "kernel", "portable", "nac3",
@ -67,6 +67,12 @@ def Some(v: T) -> Option[T]:
none = Option(None)
class _ConstGenericDummy:
pass
def ConstGeneric(name, constraint):
return TypeVar(name, _ConstGenericDummy, constraint)
def round64(x):
return round(x)

View File

@ -277,7 +277,9 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> {
node: ExprKind::Name { id: end, ctx: name_ctx.clone() },
custom: Some(ctx.primitives.int64),
};
let end = self.gen_store_target(ctx, &end_expr, Some("end.addr"))?.unwrap();
let end = self
.gen_store_target(ctx, &end_expr, Some("end.addr"))?
.unwrap();
ctx.builder.build_store(end, now);
self.end = Some(end_expr);
self.name_counter += 1;

View File

@ -75,6 +75,7 @@ pub struct PrimitivePythonId {
list: u64,
tuple: u64,
typevar: u64,
const_generic_dummy: u64,
none: u64,
exception: u64,
generic_alias: (u64, u64),
@ -877,6 +878,15 @@ impl Nac3 {
.extract()
.unwrap(),
typevar: get_attr_id(typing_mod, "TypeVar"),
const_generic_dummy: id_fn
.call1((
builtins_mod.getattr("globals")
.and_then(|v| v.call0())
.and_then(|v| v.get_item("_ConstGenericDummy"))
.unwrap(),
))
.and_then(|v| v.extract())
.unwrap(),
int: get_attr_id(builtins_mod, "int"),
int32: get_attr_id(numpy_mod, "int32"),
int64: get_attr_id(numpy_mod, "int64"),

View File

@ -266,10 +266,12 @@ impl InnerResolver {
Ok(Ok(ty))
}
// handle python objects that represent types themselves
// primitives and class types should be themselves, use `ty_id` to check,
// TypeVars and GenericAlias(`A[int, bool]`) should use `ty_ty_id` to check
// the `bool` value returned indicates whether they are instantiated or not
/// handle python objects that represent types themselves
///
/// primitives and class types should be themselves, use `ty_id` to check,
/// TypeVars and GenericAlias(`A[int, bool]`) should use `ty_ty_id` to check
///
/// the `bool` value returned indicates whether they are instantiated or not
fn get_pyty_obj_type(
&self,
py: Python,
@ -345,13 +347,21 @@ impl InnerResolver {
}
} else if ty_ty_id == self.primitive_ids.typevar {
let name: &str = pyty.getattr("__name__").unwrap().extract().unwrap();
let constraint_types = {
let (constraint_types, is_const_generic) = {
let constraints = pyty.getattr("__constraints__").unwrap();
let mut result: Vec<Type> = vec![];
let needs_defer = self.deferred_eval_store.needs_defer.load(Relaxed);
let mut is_const_generic = false;
for i in 0usize.. {
if let Ok(constr) = constraints.get_item(i) {
if needs_defer {
let constr_id: u64 = self.helper.id_fn.call1(py, (constr,))?.extract(py)?;
if constr_id == self.primitive_ids.const_generic_dummy {
is_const_generic = true;
continue
}
if !is_const_generic && needs_defer {
result.push(unifier.get_dummy_var().0);
} else {
result.push({
@ -375,17 +385,28 @@ impl InnerResolver {
break;
}
}
if needs_defer {
if !is_const_generic && needs_defer {
self.deferred_eval_store.store.write()
.push((result.clone(),
constraints.extract()?,
pyty.getattr("__name__")?.extract::<String>()?
))
}
result
(result, is_const_generic)
};
let res =
unifier.get_fresh_var_with_range(&constraint_types, Some(name.into()), None).0;
let res = if is_const_generic {
if constraint_types.len() != 1 {
return Ok(Err(format!("ConstGeneric expects 1 argument, got {}", constraint_types.len())))
}
unifier.get_fresh_const_generic_var(constraint_types[0], Some(name.into()), None).0
} else {
unifier.get_fresh_var_with_range(&constraint_types, Some(name.into()), None).0
};
Ok(Ok((res, true)))
} else if ty_ty_id == self.primitive_ids.generic_alias.0
|| ty_ty_id == self.primitive_ids.generic_alias.1