From 983f080ea7748b3ff167280a9e116e27cf49e909 Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 6 Dec 2023 18:28:44 +0800 Subject: [PATCH] artiq: Implement handling for const generic variables --- nac3artiq/demo/min_artiq.py | 8 ++++++- nac3artiq/src/codegen.rs | 4 +++- nac3artiq/src/lib.rs | 10 ++++++++ nac3artiq/src/symbol_resolver.rs | 41 ++++++++++++++++++++++++-------- 4 files changed, 51 insertions(+), 12 deletions(-) diff --git a/nac3artiq/demo/min_artiq.py b/nac3artiq/demo/min_artiq.py index 1dd5786a..bd5a8ea6 100644 --- a/nac3artiq/demo/min_artiq.py +++ b/nac3artiq/demo/min_artiq.py @@ -10,7 +10,7 @@ from embedding_map import EmbeddingMap __all__ = [ - "Kernel", "KernelInvariant", "virtual", + "Kernel", "KernelInvariant", "virtual", "ConstGeneric", "Option", "Some", "none", "UnwrapNoneError", "round64", "floor64", "ceil64", "extern", "kernel", "portable", "nac3", @@ -67,6 +67,12 @@ def Some(v: T) -> Option[T]: none = Option(None) +class _ConstGenericMarker: + pass + +def ConstGeneric(name, constraint): + return TypeVar(name, _ConstGenericMarker, constraint) + def round64(x): return round(x) diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index 12375a4c..aad05884 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -277,7 +277,9 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> { node: ExprKind::Name { id: end, ctx: name_ctx.clone() }, custom: Some(ctx.primitives.int64), }; - let end = self.gen_store_target(ctx, &end_expr, Some("end.addr"))?.unwrap(); + let end = self + .gen_store_target(ctx, &end_expr, Some("end.addr"))? + .unwrap(); ctx.builder.build_store(end, now); self.end = Some(end_expr); self.name_counter += 1; diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index bef1e73a..29b65ed9 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -75,6 +75,7 @@ pub struct PrimitivePythonId { list: u64, tuple: u64, typevar: u64, + const_generic_dummy: u64, none: u64, exception: u64, generic_alias: (u64, u64), @@ -877,6 +878,15 @@ impl Nac3 { .extract() .unwrap(), typevar: get_attr_id(typing_mod, "TypeVar"), + const_generic_dummy: id_fn + .call1(( + builtins_mod.getattr("globals") + .and_then(|v| v.call0()) + .and_then(|v| v.get_item("_ConstGenericMarker")) + .unwrap(), + )) + .and_then(|v| v.extract()) + .unwrap(), int: get_attr_id(builtins_mod, "int"), int32: get_attr_id(numpy_mod, "int32"), int64: get_attr_id(numpy_mod, "int64"), diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 9953cba3..6fd3a7e3 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -266,10 +266,12 @@ impl InnerResolver { Ok(Ok(ty)) } - // handle python objects that represent types themselves - // primitives and class types should be themselves, use `ty_id` to check, - // TypeVars and GenericAlias(`A[int, bool]`) should use `ty_ty_id` to check - // the `bool` value returned indicates whether they are instantiated or not + /// handle python objects that represent types themselves + /// + /// primitives and class types should be themselves, use `ty_id` to check, + /// TypeVars and GenericAlias(`A[int, bool]`) should use `ty_ty_id` to check + /// + /// the `bool` value returned indicates whether they are instantiated or not fn get_pyty_obj_type( &self, py: Python, @@ -345,13 +347,21 @@ impl InnerResolver { } } else if ty_ty_id == self.primitive_ids.typevar { let name: &str = pyty.getattr("__name__").unwrap().extract().unwrap(); - let constraint_types = { + let (constraint_types, is_const_generic) = { let constraints = pyty.getattr("__constraints__").unwrap(); let mut result: Vec = vec![]; let needs_defer = self.deferred_eval_store.needs_defer.load(Relaxed); + + let mut is_const_generic = false; for i in 0usize.. { if let Ok(constr) = constraints.get_item(i) { - if needs_defer { + let constr_id: u64 = self.helper.id_fn.call1(py, (constr,))?.extract(py)?; + if constr_id == self.primitive_ids.const_generic_dummy { + is_const_generic = true; + continue + } + + if !is_const_generic && needs_defer { result.push(unifier.get_dummy_var().0); } else { result.push({ @@ -375,17 +385,28 @@ impl InnerResolver { break; } } - if needs_defer { + + if !is_const_generic && needs_defer { self.deferred_eval_store.store.write() .push((result.clone(), constraints.extract()?, pyty.getattr("__name__")?.extract::()? )) } - result + + (result, is_const_generic) }; - let res = - unifier.get_fresh_var_with_range(&constraint_types, Some(name.into()), None).0; + + let res = if is_const_generic { + if constraint_types.len() != 1 { + return Ok(Err(format!("ConstGeneric expects 1 argument, got {}", constraint_types.len()))) + } + + unifier.get_fresh_const_generic_var(constraint_types[0], Some(name.into()), None).0 + } else { + unifier.get_fresh_var_with_range(&constraint_types, Some(name.into()), None).0 + }; + Ok(Ok((res, true))) } else if ty_ty_id == self.primitive_ids.generic_alias.0 || ty_ty_id == self.primitive_ids.generic_alias.1 -- 2.44.2