Option type support #224

Merged
sb10q merged 8 commits from option into master 2022-03-26 15:09:15 +08:00
6 changed files with 157 additions and 6 deletions
Showing only changes of commit 71d92a7a18 - Show all commits

View File

@ -11,6 +11,7 @@ from embedding_map import EmbeddingMap
__all__ = [
"Kernel", "KernelInvariant", "virtual",
"Option", "Some",
"round64", "floor64", "ceil64",
"extern", "kernel", "portable", "nac3",
"rpc", "ms", "us", "ns",
@ -32,6 +33,36 @@ class KernelInvariant(Generic[T]):
class virtual(Generic[T]):
pass
class Option(Generic[T]):
_nac3_option: T
def __init__(self, v: T):
self._nac3_option = v
def is_none(self):
return self._nac3_option is None
def is_some(self):
Outdated
Review

Blank lines, __str__/__repr__

Blank lines, ``__str__``/``__repr__``

I am not fully sure the best content to put into the returned the string.. currently a simple implementation based on something like this, and __str__ will return the same thing as __repr__

I am not fully sure the best content to put into the returned the string.. currently a simple implementation based on something like [this](https://github.com/m-labs/artiq/blob/nac3/artiq/language/scan.py#L270-L274), and `__str__` will return the same thing as `__repr__`
Outdated
Review

"The str() function returns a user-friendly description of an object. The repr() method returns a developer-friendly string representation of an object."

I propose:

str: Some(x) / None (like Rust)
repr: Some(x) / Option(None) (to clarify the difference with the Python None)

"The str() function returns a user-friendly description of an object. The repr() method returns a developer-friendly string representation of an object." I propose: str: ``Some(x)`` / ``None`` (like Rust) repr: ``Some(x)`` / ``Option(None)`` (to clarify the difference with the Python ``None``)
return not self.is_none()
def unwrap(self):
return self._nac3_option
ychenfo marked this conversation as resolved Outdated
Outdated
Review

You want to raise an exception if it's None to match the device behavior.

You want to raise an exception if it's ``None`` to match the device behavior.

thanks! now it will raise a ValueError("unwrap on none")

thanks! now it will raise a `ValueError("unwrap on none")`
Outdated
Review

I would suggest a dedicated exception type.

I would suggest a dedicated exception type.

added a new dedicated exception UnwrapNoneError

added a new dedicated exception `UnwrapNoneError`
def __repr__(self) -> str:
if self.is_none():
return "Option(None)"
ychenfo marked this conversation as resolved Outdated
Outdated
Review

Just none is fine.

Just ``none`` is fine.
else:
return "Some({})".format(repr(self._nac3_option))
Outdated
Review

Python would report a stack overflow error already, and from its backtrace it should be clear enough what is happening.

Python would report a stack overflow error already, and from its backtrace it should be clear enough what is happening.
Outdated
Review

And by the way your proposal does not find all cases. For example, you could have option A containing option B, and option B containing option A, and it would overflow the stack and not print "Error(self recursion)".

Anyway those infinite recursions seem rare (you would have to try hard to shoot yourself in the foot by mutating the Option object) and produce a good enough backtrace without anything special.

You could put an underscore _nac3_option (see Python naming conventions) to highlight the fact that the user is not supposed to mutate the Option.

And by the way your proposal does not find all cases. For example, you could have option A containing option B, and option B containing option A, and it would overflow the stack and not print "Error(self recursion)". Anyway those infinite recursions seem rare (you would have to try hard to shoot yourself in the foot by mutating the ``Option`` object) and produce a good enough backtrace without anything special. You could put an underscore ``_nac3_option`` (see Python naming conventions) to highlight the fact that the user is not supposed to mutate the ``Option``.

Thanks for pointing this out! I have force pushed to update the __str__ and __repr__, remove the recursion check and use _nac3_option

Thanks for pointing this out! I have force pushed to update the `__str__` and `__repr__`, remove the recursion check and use `_nac3_option`
def __str__(self) -> str:
if self.is_none():
return "None"
else:
return "Some({})".format(str(self._nac3_option))
def Some(v: T) -> Option[T]:
return Option(v)
def round64(x):
return round(x)

View File

@ -71,6 +71,7 @@ pub struct PrimitivePythonId {
exception: u64,
generic_alias: (u64, u64),
virtual_id: u64,
option: u64,
}
type TopLevelComponent = (Stmt, String, PyObject);
@ -352,6 +353,7 @@ impl Nac3 {
let builtins_mod = PyModule::import(py, "builtins").unwrap();
let id_fn = builtins_mod.getattr("id").unwrap();
let type_fn = builtins_mod.getattr("type").unwrap();
let numpy_mod = PyModule::import(py, "numpy").unwrap();
let typing_mod = PyModule::import(py, "typing").unwrap();
let types_mod = PyModule::import(py, "types").unwrap();
@ -373,7 +375,11 @@ impl Nac3 {
get_attr_id(typing_mod, "_GenericAlias"),
get_attr_id(types_mod, "GenericAlias"),
),
none: get_attr_id(builtins_mod, "None"),
none: id_fn
.call1((type_fn.call1((builtins_mod.getattr("None").unwrap(),)).unwrap(),))
.unwrap()
.extract()
.unwrap(),
typevar: get_attr_id(typing_mod, "TypeVar"),
int: get_attr_id(builtins_mod, "int"),
int32: get_attr_id(numpy_mod, "int32"),
@ -385,6 +391,17 @@ impl Nac3 {
list: get_attr_id(builtins_mod, "list"),
tuple: get_attr_id(builtins_mod, "tuple"),
exception: get_attr_id(builtins_mod, "Exception"),
option: id_fn
.call1((builtins_mod
.getattr("globals")
.unwrap()
.call0()
.unwrap()
.get_item("Option")
.unwrap(),))
.unwrap()
.extract()
.unwrap(),
};
let working_directory = tempfile::Builder::new().prefix("nac3-").tempdir().unwrap();

View File

@ -279,6 +279,27 @@ impl InnerResolver {
} else if ty_id == self.primitive_ids.tuple {
// do not handle type var param and concrete check here
Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: vec![] }), false)))
} else if ty_id == self.primitive_ids.option {
Ok(Ok((primitives.option, false)))
} else if ty_id == self.primitive_ids.none {
if let TypeEnum::TObj { params, .. } =
unifier.get_ty_immutable(primitives.option).as_ref()
{
let var_map = params
.iter()
.map(|(id_var, ty)| {
if let TypeEnum::TVar { id, range, name, loc, .. } = &*unifier.get_ty(*ty) {
assert_eq!(*id, *id_var);
(*id, unifier.get_fresh_var_with_range(range, *name, *loc).0)
} else {
unreachable!()
}
})
.collect::<HashMap<_, _>>();
Ok(Ok((unifier.subst(primitives.option, &var_map).unwrap(), true)))
} else {
unreachable!("must be tobj")
}
} else if let Some(def_id) = self.pyid_to_def.read().get(&ty_id).cloned() {
let def = defs[def_id.0].read();
if let TopLevelDef::Class { object_id, type_vars, fields, methods, .. } = &*def {
@ -569,6 +590,34 @@ impl InnerResolver {
let types = types?;
Ok(types.map(|types| unifier.add_ty(TypeEnum::TTuple { ty: types })))
}
// special handling for option type since its class member layout in python side
// is special and cannot be mapped directly to a nac3 type as below
(TypeEnum::TObj { obj_id, params, .. }, false)
if *obj_id == primitives.option.get_obj_id(unifier) =>
{
let field_data = match obj.getattr("_nac3_option") {
Ok(d) => d,
// None should be already handled above
Err(_) => unreachable!("cannot be None")
};
let field_obj_id: u64 = self.helper.id_fn.call1(py, (field_data,))?.extract(py)?;
let zelf_obj_id: u64 = self.helper.id_fn.call1(py, (obj,))?.extract(py)?;
if field_obj_id == zelf_obj_id {
return Ok(Err("self recursive option type is not allowed".into()))
}
let ty = match self.get_obj_type(py, field_data, unifier, defs, primitives)? {
Ok(t) => t,
Err(e) => {
return Ok(Err(format!(
"error when getting type of the option object ({})",
e
)))
}
};
let new_var_map: HashMap<_, _> = params.iter().map(|(id, _)| (*id, ty)).collect();
let res = unifier.subst(extracted_ty, &new_var_map).unwrap_or(extracted_ty);
Ok(Ok(res))
}
(TypeEnum::TObj { params, fields, .. }, false) => {
let var_map = params
.iter()
@ -795,6 +844,37 @@ impl InnerResolver {
let global = ctx.module.add_global(ty, Some(AddressSpace::Generic), &id_str);
global.set_initializer(&val);
Ok(Some(global.as_pointer_value().into()))
} else if ty_id == self.primitive_ids.option {
match self
.get_obj_value(py, obj.getattr("_nac3_option").unwrap(), ctx, generator)
.map_err(|e| {
super::CompileError::new_err(format!(
"Error getting value of Option object: {}",
e
))
})? {
Some(v) => {
let global_str = format!("{}_option", id);
{
if self.global_value_ids.read().contains(&id) {
let global = ctx.module.get_global(&global_str).unwrap_or_else(|| {
ctx.module.add_global(v.get_type(), Some(AddressSpace::Generic), &global_str)
});
return Ok(Some(global.as_pointer_value().into()));
} else {
self.global_value_ids.write().insert(id);
}
}
let global = ctx.module.add_global(v.get_type(), Some(AddressSpace::Generic), &global_str);
global.set_initializer(&v);
Ok(Some(global.as_pointer_value().into()))
},
None => Ok(None),
}
} else if ty_id == self.primitive_ids.none {
// for option type, just a null ptr, whose type needs to be casted in codegen
// according to the type info attached in the ast
Ok(Some(ctx.ctx.i8_type().ptr_type(AddressSpace::Generic).const_null().into()))
} else {
let id_str = id.to_string();

View File

@ -958,16 +958,27 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator>(
let resolver = ctx.resolver.clone();
let val = resolver.get_symbol_value(*id, ctx).unwrap();
// if is tuple, need to deref it to handle tuple as value
if let (TypeEnum::TTuple { .. }, BasicValueEnum::PointerValue(ptr)) = (
// if is option, need to cast pointer to handle None
match (
&*ctx.unifier.get_ty(expr.custom.unwrap()),
resolver
.get_symbol_value(*id, ctx)
.unwrap()
.to_basic_value_enum(ctx, generator)?,
) {
ctx.builder.build_load(ptr, "tup_val").into()
} else {
val
(TypeEnum::TTuple { .. }, BasicValueEnum::PointerValue(ptr)) => {
ctx.builder.build_load(ptr, "tup_val").into()
}
(TypeEnum::TObj { obj_id, params, .. }, BasicValueEnum::PointerValue(ptr))
if *obj_id == ctx.primitives.option.get_obj_id(&ctx.unifier) => {
let actual_ptr_ty = ctx.get_llvm_type(
generator,
*params.iter().next().unwrap().1,
)
.ptr_type(AddressSpace::Generic);
ctx.builder.build_bitcast(ptr, actual_ptr_ty, "option_ptr_cast").into()
}
_ => val,
}
}
},

View File

@ -270,7 +270,7 @@ fn get_llvm_type<'ctx>(
let result = match &*ty_enum {
TObj { obj_id, fields, .. } => {
// check to avoid treating primitives other than Option as classes
if obj_id.0 <= 14 {
if obj_id.0 <= 10 {
match (unifier.get_ty(ty).as_ref(), unifier.get_ty(primitives.option).as_ref())
{
(

View File

@ -54,6 +54,18 @@ pub enum RecordKey {
Int(i32),
}
impl Type {
ychenfo marked this conversation as resolved Outdated
Outdated
Review

Could you add a comment that documents what this does?

Could you add a comment that documents what this does?

comment added

comment added
// a wrapper function for cleaner code so that we don't need to
// write this long pattern matching just to get the field `obj_id`
pub fn get_obj_id(self, unifier: &Unifier) -> DefinitionId {
if let TypeEnum::TObj { obj_id, .. } = unifier.get_ty_immutable(self).as_ref() {
*obj_id
} else {
unreachable!("expect a object type")
}
}
}
impl From<&RecordKey> for StrRef {
fn from(r: &RecordKey) -> Self {
match r {