1
0
forked from M-Labs/nac3

artiq: Implement handling for const generic variables

This commit is contained in:
David Mak 2023-12-06 18:28:44 +08:00
parent 031e660f18
commit 983f080ea7
4 changed files with 51 additions and 12 deletions

View File

@ -10,7 +10,7 @@ from embedding_map import EmbeddingMap
__all__ = [ __all__ = [
"Kernel", "KernelInvariant", "virtual", "Kernel", "KernelInvariant", "virtual", "ConstGeneric",
"Option", "Some", "none", "UnwrapNoneError", "Option", "Some", "none", "UnwrapNoneError",
"round64", "floor64", "ceil64", "round64", "floor64", "ceil64",
"extern", "kernel", "portable", "nac3", "extern", "kernel", "portable", "nac3",
@ -67,6 +67,12 @@ def Some(v: T) -> Option[T]:
none = Option(None) none = Option(None)
class _ConstGenericMarker:
pass
def ConstGeneric(name, constraint):
return TypeVar(name, _ConstGenericMarker, constraint)
def round64(x): def round64(x):
return round(x) return round(x)

View File

@ -277,7 +277,9 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> {
node: ExprKind::Name { id: end, ctx: name_ctx.clone() }, node: ExprKind::Name { id: end, ctx: name_ctx.clone() },
custom: Some(ctx.primitives.int64), custom: Some(ctx.primitives.int64),
}; };
let end = self.gen_store_target(ctx, &end_expr, Some("end.addr"))?.unwrap(); let end = self
.gen_store_target(ctx, &end_expr, Some("end.addr"))?
.unwrap();
ctx.builder.build_store(end, now); ctx.builder.build_store(end, now);
self.end = Some(end_expr); self.end = Some(end_expr);
self.name_counter += 1; self.name_counter += 1;

View File

@ -75,6 +75,7 @@ pub struct PrimitivePythonId {
list: u64, list: u64,
tuple: u64, tuple: u64,
typevar: u64, typevar: u64,
const_generic_dummy: u64,
none: u64, none: u64,
exception: u64, exception: u64,
generic_alias: (u64, u64), generic_alias: (u64, u64),
@ -877,6 +878,15 @@ impl Nac3 {
.extract() .extract()
.unwrap(), .unwrap(),
typevar: get_attr_id(typing_mod, "TypeVar"), typevar: get_attr_id(typing_mod, "TypeVar"),
const_generic_dummy: id_fn
.call1((
builtins_mod.getattr("globals")
.and_then(|v| v.call0())
.and_then(|v| v.get_item("_ConstGenericMarker"))
.unwrap(),
))
.and_then(|v| v.extract())
.unwrap(),
int: get_attr_id(builtins_mod, "int"), int: get_attr_id(builtins_mod, "int"),
int32: get_attr_id(numpy_mod, "int32"), int32: get_attr_id(numpy_mod, "int32"),
int64: get_attr_id(numpy_mod, "int64"), int64: get_attr_id(numpy_mod, "int64"),

View File

@ -266,10 +266,12 @@ impl InnerResolver {
Ok(Ok(ty)) Ok(Ok(ty))
} }
// handle python objects that represent types themselves /// handle python objects that represent types themselves
// primitives and class types should be themselves, use `ty_id` to check, ///
// TypeVars and GenericAlias(`A[int, bool]`) should use `ty_ty_id` to check /// primitives and class types should be themselves, use `ty_id` to check,
// the `bool` value returned indicates whether they are instantiated or not /// TypeVars and GenericAlias(`A[int, bool]`) should use `ty_ty_id` to check
///
/// the `bool` value returned indicates whether they are instantiated or not
fn get_pyty_obj_type( fn get_pyty_obj_type(
&self, &self,
py: Python, py: Python,
@ -345,13 +347,21 @@ impl InnerResolver {
} }
} else if ty_ty_id == self.primitive_ids.typevar { } else if ty_ty_id == self.primitive_ids.typevar {
let name: &str = pyty.getattr("__name__").unwrap().extract().unwrap(); let name: &str = pyty.getattr("__name__").unwrap().extract().unwrap();
let constraint_types = { let (constraint_types, is_const_generic) = {
let constraints = pyty.getattr("__constraints__").unwrap(); let constraints = pyty.getattr("__constraints__").unwrap();
let mut result: Vec<Type> = vec![]; let mut result: Vec<Type> = vec![];
let needs_defer = self.deferred_eval_store.needs_defer.load(Relaxed); let needs_defer = self.deferred_eval_store.needs_defer.load(Relaxed);
let mut is_const_generic = false;
for i in 0usize.. { for i in 0usize.. {
if let Ok(constr) = constraints.get_item(i) { if let Ok(constr) = constraints.get_item(i) {
if needs_defer { let constr_id: u64 = self.helper.id_fn.call1(py, (constr,))?.extract(py)?;
if constr_id == self.primitive_ids.const_generic_dummy {
is_const_generic = true;
continue
}
if !is_const_generic && needs_defer {
result.push(unifier.get_dummy_var().0); result.push(unifier.get_dummy_var().0);
} else { } else {
result.push({ result.push({
@ -375,17 +385,28 @@ impl InnerResolver {
break; break;
} }
} }
if needs_defer {
if !is_const_generic && needs_defer {
self.deferred_eval_store.store.write() self.deferred_eval_store.store.write()
.push((result.clone(), .push((result.clone(),
constraints.extract()?, constraints.extract()?,
pyty.getattr("__name__")?.extract::<String>()? pyty.getattr("__name__")?.extract::<String>()?
)) ))
} }
result
(result, is_const_generic)
}; };
let res =
unifier.get_fresh_var_with_range(&constraint_types, Some(name.into()), None).0; let res = if is_const_generic {
if constraint_types.len() != 1 {
return Ok(Err(format!("ConstGeneric expects 1 argument, got {}", constraint_types.len())))
}
unifier.get_fresh_const_generic_var(constraint_types[0], Some(name.into()), None).0
} else {
unifier.get_fresh_var_with_range(&constraint_types, Some(name.into()), None).0
};
Ok(Ok((res, true))) Ok(Ok((res, true)))
} else if ty_ty_id == self.primitive_ids.generic_alias.0 } else if ty_ty_id == self.primitive_ids.generic_alias.0
|| ty_ty_id == self.primitive_ids.generic_alias.1 || ty_ty_id == self.primitive_ids.generic_alias.1