Option type support #224
@ -11,6 +11,7 @@ from embedding_map import EmbeddingMap
|
||||
|
||||
__all__ = [
|
||||
"Kernel", "KernelInvariant", "virtual",
|
||||
"Option", "Some",
|
||||
"round64", "floor64", "ceil64",
|
||||
"extern", "kernel", "portable", "nac3",
|
||||
"rpc", "ms", "us", "ns",
|
||||
@ -32,6 +33,36 @@ class KernelInvariant(Generic[T]):
|
||||
class virtual(Generic[T]):
|
||||
pass
|
||||
|
||||
class Option(Generic[T]):
|
||||
_nac3_option: T
|
||||
|
||||
def __init__(self, v: T):
|
||||
self._nac3_option = v
|
||||
|
||||
def is_none(self):
|
||||
return self._nac3_option is None
|
||||
|
||||
def is_some(self):
|
||||
|
||||
return not self.is_none()
|
||||
|
||||
def unwrap(self):
|
||||
return self._nac3_option
|
||||
ychenfo marked this conversation as resolved
Outdated
sb10q
commented
You want to raise an exception if it's You want to raise an exception if it's ``None`` to match the device behavior.
ychenfo
commented
thanks! now it will raise a thanks! now it will raise a `ValueError("unwrap on none")`
sb10q
commented
I would suggest a dedicated exception type. I would suggest a dedicated exception type.
ychenfo
commented
added a new dedicated exception added a new dedicated exception `UnwrapNoneError`
|
||||
|
||||
def __repr__(self) -> str:
|
||||
if self.is_none():
|
||||
return "Option(None)"
|
||||
ychenfo marked this conversation as resolved
Outdated
sb10q
commented
Just Just ``none`` is fine.
|
||||
else:
|
||||
return "Some({})".format(repr(self._nac3_option))
|
||||
sb10q
commented
Python would report a stack overflow error already, and from its backtrace it should be clear enough what is happening. Python would report a stack overflow error already, and from its backtrace it should be clear enough what is happening.
sb10q
commented
And by the way your proposal does not find all cases. For example, you could have option A containing option B, and option B containing option A, and it would overflow the stack and not print "Error(self recursion)". Anyway those infinite recursions seem rare (you would have to try hard to shoot yourself in the foot by mutating the You could put an underscore And by the way your proposal does not find all cases. For example, you could have option A containing option B, and option B containing option A, and it would overflow the stack and not print "Error(self recursion)".
Anyway those infinite recursions seem rare (you would have to try hard to shoot yourself in the foot by mutating the ``Option`` object) and produce a good enough backtrace without anything special.
You could put an underscore ``_nac3_option`` (see Python naming conventions) to highlight the fact that the user is not supposed to mutate the ``Option``.
ychenfo
commented
Thanks for pointing this out! I have force pushed to update the Thanks for pointing this out! I have force pushed to update the `__str__` and `__repr__`, remove the recursion check and use `_nac3_option`
|
||||
|
||||
def __str__(self) -> str:
|
||||
if self.is_none():
|
||||
return "None"
|
||||
else:
|
||||
return "Some({})".format(str(self._nac3_option))
|
||||
|
||||
def Some(v: T) -> Option[T]:
|
||||
return Option(v)
|
||||
|
||||
|
||||
def round64(x):
|
||||
return round(x)
|
||||
|
@ -71,6 +71,7 @@ pub struct PrimitivePythonId {
|
||||
exception: u64,
|
||||
generic_alias: (u64, u64),
|
||||
virtual_id: u64,
|
||||
option: u64,
|
||||
}
|
||||
|
||||
type TopLevelComponent = (Stmt, String, PyObject);
|
||||
@ -352,6 +353,7 @@ impl Nac3 {
|
||||
|
||||
let builtins_mod = PyModule::import(py, "builtins").unwrap();
|
||||
let id_fn = builtins_mod.getattr("id").unwrap();
|
||||
let type_fn = builtins_mod.getattr("type").unwrap();
|
||||
let numpy_mod = PyModule::import(py, "numpy").unwrap();
|
||||
let typing_mod = PyModule::import(py, "typing").unwrap();
|
||||
let types_mod = PyModule::import(py, "types").unwrap();
|
||||
@ -373,7 +375,11 @@ impl Nac3 {
|
||||
get_attr_id(typing_mod, "_GenericAlias"),
|
||||
get_attr_id(types_mod, "GenericAlias"),
|
||||
),
|
||||
none: get_attr_id(builtins_mod, "None"),
|
||||
none: id_fn
|
||||
.call1((type_fn.call1((builtins_mod.getattr("None").unwrap(),)).unwrap(),))
|
||||
.unwrap()
|
||||
.extract()
|
||||
.unwrap(),
|
||||
typevar: get_attr_id(typing_mod, "TypeVar"),
|
||||
int: get_attr_id(builtins_mod, "int"),
|
||||
int32: get_attr_id(numpy_mod, "int32"),
|
||||
@ -385,6 +391,17 @@ impl Nac3 {
|
||||
list: get_attr_id(builtins_mod, "list"),
|
||||
tuple: get_attr_id(builtins_mod, "tuple"),
|
||||
exception: get_attr_id(builtins_mod, "Exception"),
|
||||
option: id_fn
|
||||
.call1((builtins_mod
|
||||
.getattr("globals")
|
||||
.unwrap()
|
||||
.call0()
|
||||
.unwrap()
|
||||
.get_item("Option")
|
||||
.unwrap(),))
|
||||
.unwrap()
|
||||
.extract()
|
||||
.unwrap(),
|
||||
};
|
||||
|
||||
let working_directory = tempfile::Builder::new().prefix("nac3-").tempdir().unwrap();
|
||||
|
@ -279,6 +279,27 @@ impl InnerResolver {
|
||||
} else if ty_id == self.primitive_ids.tuple {
|
||||
// do not handle type var param and concrete check here
|
||||
Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: vec![] }), false)))
|
||||
} else if ty_id == self.primitive_ids.option {
|
||||
Ok(Ok((primitives.option, false)))
|
||||
} else if ty_id == self.primitive_ids.none {
|
||||
if let TypeEnum::TObj { params, .. } =
|
||||
unifier.get_ty_immutable(primitives.option).as_ref()
|
||||
{
|
||||
let var_map = params
|
||||
.iter()
|
||||
.map(|(id_var, ty)| {
|
||||
if let TypeEnum::TVar { id, range, name, loc, .. } = &*unifier.get_ty(*ty) {
|
||||
assert_eq!(*id, *id_var);
|
||||
(*id, unifier.get_fresh_var_with_range(range, *name, *loc).0)
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
})
|
||||
.collect::<HashMap<_, _>>();
|
||||
Ok(Ok((unifier.subst(primitives.option, &var_map).unwrap(), true)))
|
||||
} else {
|
||||
unreachable!("must be tobj")
|
||||
}
|
||||
} else if let Some(def_id) = self.pyid_to_def.read().get(&ty_id).cloned() {
|
||||
let def = defs[def_id.0].read();
|
||||
if let TopLevelDef::Class { object_id, type_vars, fields, methods, .. } = &*def {
|
||||
@ -569,6 +590,34 @@ impl InnerResolver {
|
||||
let types = types?;
|
||||
Ok(types.map(|types| unifier.add_ty(TypeEnum::TTuple { ty: types })))
|
||||
}
|
||||
// special handling for option type since its class member layout in python side
|
||||
// is special and cannot be mapped directly to a nac3 type as below
|
||||
(TypeEnum::TObj { obj_id, params, .. }, false)
|
||||
if *obj_id == primitives.option.get_obj_id(unifier) =>
|
||||
{
|
||||
let field_data = match obj.getattr("_nac3_option") {
|
||||
Ok(d) => d,
|
||||
// None should be already handled above
|
||||
Err(_) => unreachable!("cannot be None")
|
||||
};
|
||||
let field_obj_id: u64 = self.helper.id_fn.call1(py, (field_data,))?.extract(py)?;
|
||||
let zelf_obj_id: u64 = self.helper.id_fn.call1(py, (obj,))?.extract(py)?;
|
||||
if field_obj_id == zelf_obj_id {
|
||||
return Ok(Err("self recursive option type is not allowed".into()))
|
||||
}
|
||||
let ty = match self.get_obj_type(py, field_data, unifier, defs, primitives)? {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
return Ok(Err(format!(
|
||||
"error when getting type of the option object ({})",
|
||||
e
|
||||
)))
|
||||
}
|
||||
};
|
||||
let new_var_map: HashMap<_, _> = params.iter().map(|(id, _)| (*id, ty)).collect();
|
||||
let res = unifier.subst(extracted_ty, &new_var_map).unwrap_or(extracted_ty);
|
||||
Ok(Ok(res))
|
||||
}
|
||||
(TypeEnum::TObj { params, fields, .. }, false) => {
|
||||
let var_map = params
|
||||
.iter()
|
||||
@ -795,6 +844,37 @@ impl InnerResolver {
|
||||
let global = ctx.module.add_global(ty, Some(AddressSpace::Generic), &id_str);
|
||||
global.set_initializer(&val);
|
||||
Ok(Some(global.as_pointer_value().into()))
|
||||
} else if ty_id == self.primitive_ids.option {
|
||||
match self
|
||||
.get_obj_value(py, obj.getattr("_nac3_option").unwrap(), ctx, generator)
|
||||
.map_err(|e| {
|
||||
super::CompileError::new_err(format!(
|
||||
"Error getting value of Option object: {}",
|
||||
e
|
||||
))
|
||||
})? {
|
||||
Some(v) => {
|
||||
let global_str = format!("{}_option", id);
|
||||
{
|
||||
if self.global_value_ids.read().contains(&id) {
|
||||
let global = ctx.module.get_global(&global_str).unwrap_or_else(|| {
|
||||
ctx.module.add_global(v.get_type(), Some(AddressSpace::Generic), &global_str)
|
||||
});
|
||||
return Ok(Some(global.as_pointer_value().into()));
|
||||
} else {
|
||||
self.global_value_ids.write().insert(id);
|
||||
}
|
||||
}
|
||||
let global = ctx.module.add_global(v.get_type(), Some(AddressSpace::Generic), &global_str);
|
||||
global.set_initializer(&v);
|
||||
Ok(Some(global.as_pointer_value().into()))
|
||||
},
|
||||
None => Ok(None),
|
||||
}
|
||||
} else if ty_id == self.primitive_ids.none {
|
||||
// for option type, just a null ptr, whose type needs to be casted in codegen
|
||||
// according to the type info attached in the ast
|
||||
Ok(Some(ctx.ctx.i8_type().ptr_type(AddressSpace::Generic).const_null().into()))
|
||||
} else {
|
||||
let id_str = id.to_string();
|
||||
|
||||
|
@ -958,16 +958,27 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator>(
|
||||
let resolver = ctx.resolver.clone();
|
||||
let val = resolver.get_symbol_value(*id, ctx).unwrap();
|
||||
// if is tuple, need to deref it to handle tuple as value
|
||||
if let (TypeEnum::TTuple { .. }, BasicValueEnum::PointerValue(ptr)) = (
|
||||
// if is option, need to cast pointer to handle None
|
||||
match (
|
||||
&*ctx.unifier.get_ty(expr.custom.unwrap()),
|
||||
resolver
|
||||
.get_symbol_value(*id, ctx)
|
||||
.unwrap()
|
||||
.to_basic_value_enum(ctx, generator)?,
|
||||
) {
|
||||
ctx.builder.build_load(ptr, "tup_val").into()
|
||||
} else {
|
||||
val
|
||||
(TypeEnum::TTuple { .. }, BasicValueEnum::PointerValue(ptr)) => {
|
||||
ctx.builder.build_load(ptr, "tup_val").into()
|
||||
}
|
||||
(TypeEnum::TObj { obj_id, params, .. }, BasicValueEnum::PointerValue(ptr))
|
||||
if *obj_id == ctx.primitives.option.get_obj_id(&ctx.unifier) => {
|
||||
let actual_ptr_ty = ctx.get_llvm_type(
|
||||
generator,
|
||||
*params.iter().next().unwrap().1,
|
||||
)
|
||||
.ptr_type(AddressSpace::Generic);
|
||||
ctx.builder.build_bitcast(ptr, actual_ptr_ty, "option_ptr_cast").into()
|
||||
}
|
||||
_ => val,
|
||||
}
|
||||
}
|
||||
},
|
||||
|
@ -270,7 +270,7 @@ fn get_llvm_type<'ctx>(
|
||||
let result = match &*ty_enum {
|
||||
TObj { obj_id, fields, .. } => {
|
||||
// check to avoid treating primitives other than Option as classes
|
||||
if obj_id.0 <= 14 {
|
||||
if obj_id.0 <= 10 {
|
||||
match (unifier.get_ty(ty).as_ref(), unifier.get_ty(primitives.option).as_ref())
|
||||
{
|
||||
(
|
||||
|
@ -54,6 +54,18 @@ pub enum RecordKey {
|
||||
Int(i32),
|
||||
}
|
||||
|
||||
impl Type {
|
||||
ychenfo marked this conversation as resolved
Outdated
sb10q
commented
Could you add a comment that documents what this does? Could you add a comment that documents what this does?
ychenfo
commented
comment added comment added
|
||||
// a wrapper function for cleaner code so that we don't need to
|
||||
// write this long pattern matching just to get the field `obj_id`
|
||||
pub fn get_obj_id(self, unifier: &Unifier) -> DefinitionId {
|
||||
if let TypeEnum::TObj { obj_id, .. } = unifier.get_ty_immutable(self).as_ref() {
|
||||
*obj_id
|
||||
} else {
|
||||
unreachable!("expect a object type")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&RecordKey> for StrRef {
|
||||
fn from(r: &RecordKey) -> Self {
|
||||
match r {
|
||||
|
Blank lines,
__str__
/__repr__
I am not fully sure the best content to put into the returned the string.. currently a simple implementation based on something like this, and
__str__
will return the same thing as__repr__
"The str() function returns a user-friendly description of an object. The repr() method returns a developer-friendly string representation of an object."
I propose:
str:
Some(x)
/None
(like Rust)repr:
Some(x)
/Option(None)
(to clarify the difference with the PythonNone
)