diff --git a/nac3artiq/demo/min_artiq.py b/nac3artiq/demo/min_artiq.py index 5302cf197..fa080f5e7 100644 --- a/nac3artiq/demo/min_artiq.py +++ b/nac3artiq/demo/min_artiq.py @@ -11,6 +11,7 @@ from embedding_map import EmbeddingMap __all__ = [ "Kernel", "KernelInvariant", "virtual", + "Option", "Some", "round64", "floor64", "ceil64", "extern", "kernel", "portable", "nac3", "rpc", "ms", "us", "ns", @@ -32,6 +33,36 @@ class KernelInvariant(Generic[T]): class virtual(Generic[T]): pass +class Option(Generic[T]): + _nac3_option: T + + def __init__(self, v: T): + self._nac3_option = v + + def is_none(self): + return self._nac3_option is None + + def is_some(self): + return not self.is_none() + + def unwrap(self): + return self._nac3_option + + def __repr__(self) -> str: + if self.is_none(): + return "Option(None)" + else: + return "Some({})".format(repr(self._nac3_option)) + + def __str__(self) -> str: + if self.is_none(): + return "None" + else: + return "Some({})".format(str(self._nac3_option)) + +def Some(v: T) -> Option[T]: + return Option(v) + def round64(x): return round(x) diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index 13bc850d7..a1e6519e9 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -71,6 +71,7 @@ pub struct PrimitivePythonId { exception: u64, generic_alias: (u64, u64), virtual_id: u64, + option: u64, } type TopLevelComponent = (Stmt, String, PyObject); @@ -352,6 +353,7 @@ impl Nac3 { let builtins_mod = PyModule::import(py, "builtins").unwrap(); let id_fn = builtins_mod.getattr("id").unwrap(); + let type_fn = builtins_mod.getattr("type").unwrap(); let numpy_mod = PyModule::import(py, "numpy").unwrap(); let typing_mod = PyModule::import(py, "typing").unwrap(); let types_mod = PyModule::import(py, "types").unwrap(); @@ -373,7 +375,11 @@ impl Nac3 { get_attr_id(typing_mod, "_GenericAlias"), get_attr_id(types_mod, "GenericAlias"), ), - none: get_attr_id(builtins_mod, "None"), + none: id_fn + .call1((type_fn.call1((builtins_mod.getattr("None").unwrap(),)).unwrap(),)) + .unwrap() + .extract() + .unwrap(), typevar: get_attr_id(typing_mod, "TypeVar"), int: get_attr_id(builtins_mod, "int"), int32: get_attr_id(numpy_mod, "int32"), @@ -385,6 +391,17 @@ impl Nac3 { list: get_attr_id(builtins_mod, "list"), tuple: get_attr_id(builtins_mod, "tuple"), exception: get_attr_id(builtins_mod, "Exception"), + option: id_fn + .call1((builtins_mod + .getattr("globals") + .unwrap() + .call0() + .unwrap() + .get_item("Option") + .unwrap(),)) + .unwrap() + .extract() + .unwrap(), }; let working_directory = tempfile::Builder::new().prefix("nac3-").tempdir().unwrap(); diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index e26435c0f..37db368f3 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -279,6 +279,27 @@ impl InnerResolver { } else if ty_id == self.primitive_ids.tuple { // do not handle type var param and concrete check here Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: vec![] }), false))) + } else if ty_id == self.primitive_ids.option { + Ok(Ok((primitives.option, false))) + } else if ty_id == self.primitive_ids.none { + if let TypeEnum::TObj { params, .. } = + unifier.get_ty_immutable(primitives.option).as_ref() + { + let var_map = params + .iter() + .map(|(id_var, ty)| { + if let TypeEnum::TVar { id, range, name, loc, .. } = &*unifier.get_ty(*ty) { + assert_eq!(*id, *id_var); + (*id, unifier.get_fresh_var_with_range(range, *name, *loc).0) + } else { + unreachable!() + } + }) + .collect::>(); + Ok(Ok((unifier.subst(primitives.option, &var_map).unwrap(), true))) + } else { + unreachable!("must be tobj") + } } else if let Some(def_id) = self.pyid_to_def.read().get(&ty_id).cloned() { let def = defs[def_id.0].read(); if let TopLevelDef::Class { object_id, type_vars, fields, methods, .. } = &*def { @@ -569,6 +590,34 @@ impl InnerResolver { let types = types?; Ok(types.map(|types| unifier.add_ty(TypeEnum::TTuple { ty: types }))) } + // special handling for option type since its class member layout in python side + // is special and cannot be mapped directly to a nac3 type as below + (TypeEnum::TObj { obj_id, params, .. }, false) + if *obj_id == primitives.option.get_obj_id(unifier) => + { + let field_data = match obj.getattr("_nac3_option") { + Ok(d) => d, + // None should be already handled above + Err(_) => unreachable!("cannot be None") + }; + let field_obj_id: u64 = self.helper.id_fn.call1(py, (field_data,))?.extract(py)?; + let zelf_obj_id: u64 = self.helper.id_fn.call1(py, (obj,))?.extract(py)?; + if field_obj_id == zelf_obj_id { + return Ok(Err("self recursive option type is not allowed".into())) + } + let ty = match self.get_obj_type(py, field_data, unifier, defs, primitives)? { + Ok(t) => t, + Err(e) => { + return Ok(Err(format!( + "error when getting type of the option object ({})", + e + ))) + } + }; + let new_var_map: HashMap<_, _> = params.iter().map(|(id, _)| (*id, ty)).collect(); + let res = unifier.subst(extracted_ty, &new_var_map).unwrap_or(extracted_ty); + Ok(Ok(res)) + } (TypeEnum::TObj { params, fields, .. }, false) => { let var_map = params .iter() @@ -795,6 +844,37 @@ impl InnerResolver { let global = ctx.module.add_global(ty, Some(AddressSpace::Generic), &id_str); global.set_initializer(&val); Ok(Some(global.as_pointer_value().into())) + } else if ty_id == self.primitive_ids.option { + match self + .get_obj_value(py, obj.getattr("_nac3_option").unwrap(), ctx, generator) + .map_err(|e| { + super::CompileError::new_err(format!( + "Error getting value of Option object: {}", + e + )) + })? { + Some(v) => { + let global_str = format!("{}_option", id); + { + if self.global_value_ids.read().contains(&id) { + let global = ctx.module.get_global(&global_str).unwrap_or_else(|| { + ctx.module.add_global(v.get_type(), Some(AddressSpace::Generic), &global_str) + }); + return Ok(Some(global.as_pointer_value().into())); + } else { + self.global_value_ids.write().insert(id); + } + } + let global = ctx.module.add_global(v.get_type(), Some(AddressSpace::Generic), &global_str); + global.set_initializer(&v); + Ok(Some(global.as_pointer_value().into())) + }, + None => Ok(None), + } + } else if ty_id == self.primitive_ids.none { + // for option type, just a null ptr, whose type needs to be casted in codegen + // according to the type info attached in the ast + Ok(Some(ctx.ctx.i8_type().ptr_type(AddressSpace::Generic).const_null().into())) } else { let id_str = id.to_string(); diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 3d4d75add..abd9eec0c 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -958,16 +958,27 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator>( let resolver = ctx.resolver.clone(); let val = resolver.get_symbol_value(*id, ctx).unwrap(); // if is tuple, need to deref it to handle tuple as value - if let (TypeEnum::TTuple { .. }, BasicValueEnum::PointerValue(ptr)) = ( + // if is option, need to cast pointer to handle None + match ( &*ctx.unifier.get_ty(expr.custom.unwrap()), resolver .get_symbol_value(*id, ctx) .unwrap() .to_basic_value_enum(ctx, generator)?, ) { - ctx.builder.build_load(ptr, "tup_val").into() - } else { - val + (TypeEnum::TTuple { .. }, BasicValueEnum::PointerValue(ptr)) => { + ctx.builder.build_load(ptr, "tup_val").into() + } + (TypeEnum::TObj { obj_id, params, .. }, BasicValueEnum::PointerValue(ptr)) + if *obj_id == ctx.primitives.option.get_obj_id(&ctx.unifier) => { + let actual_ptr_ty = ctx.get_llvm_type( + generator, + *params.iter().next().unwrap().1, + ) + .ptr_type(AddressSpace::Generic); + ctx.builder.build_bitcast(ptr, actual_ptr_ty, "option_ptr_cast").into() + } + _ => val, } } }, diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 78c866d49..bce050c2e 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -270,7 +270,7 @@ fn get_llvm_type<'ctx>( let result = match &*ty_enum { TObj { obj_id, fields, .. } => { // check to avoid treating primitives other than Option as classes - if obj_id.0 <= 14 { + if obj_id.0 <= 10 { match (unifier.get_ty(ty).as_ref(), unifier.get_ty(primitives.option).as_ref()) { ( diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index 67b89248f..9df1cd1c4 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -54,6 +54,18 @@ pub enum RecordKey { Int(i32), } +impl Type { + // a wrapper function for cleaner code so that we don't need to + // write this long pattern matching just to get the field `obj_id` + pub fn get_obj_id(self, unifier: &Unifier) -> DefinitionId { + if let TypeEnum::TObj { obj_id, .. } = unifier.get_ty_immutable(self).as_ref() { + *obj_id + } else { + unreachable!("expect a object type") + } + } +} + impl From<&RecordKey> for StrRef { fn from(r: &RecordKey) -> Self { match r {