Option type support #224
|
@ -11,7 +11,7 @@ from embedding_map import EmbeddingMap
|
|||
|
||||
__all__ = [
|
||||
"Kernel", "KernelInvariant", "virtual",
|
||||
"Option", "Some",
|
||||
"Option", "Some", "none",
|
||||
"round64", "floor64", "ceil64",
|
||||
"extern", "kernel", "portable", "nac3",
|
||||
"rpc", "ms", "us", "ns",
|
||||
|
@ -50,19 +50,20 @@ class Option(Generic[T]):
|
|||
|
||||
def __repr__(self) -> str:
|
||||
if self.is_none():
|
||||
return "Option(None)"
|
||||
return "none"
|
||||
ychenfo marked this conversation as resolved
Outdated
|
||||
else:
|
||||
return "Some({})".format(repr(self._nac3_option))
|
||||
sb10q
commented
Python would report a stack overflow error already, and from its backtrace it should be clear enough what is happening. Python would report a stack overflow error already, and from its backtrace it should be clear enough what is happening.
sb10q
commented
And by the way your proposal does not find all cases. For example, you could have option A containing option B, and option B containing option A, and it would overflow the stack and not print "Error(self recursion)". Anyway those infinite recursions seem rare (you would have to try hard to shoot yourself in the foot by mutating the You could put an underscore And by the way your proposal does not find all cases. For example, you could have option A containing option B, and option B containing option A, and it would overflow the stack and not print "Error(self recursion)".
Anyway those infinite recursions seem rare (you would have to try hard to shoot yourself in the foot by mutating the ``Option`` object) and produce a good enough backtrace without anything special.
You could put an underscore ``_nac3_option`` (see Python naming conventions) to highlight the fact that the user is not supposed to mutate the ``Option``.
ychenfo
commented
Thanks for pointing this out! I have force pushed to update the Thanks for pointing this out! I have force pushed to update the `__str__` and `__repr__`, remove the recursion check and use `_nac3_option`
|
||||
|
||||
def __str__(self) -> str:
|
||||
if self.is_none():
|
||||
return "None"
|
||||
return "none"
|
||||
else:
|
||||
return "Some({})".format(str(self._nac3_option))
|
||||
|
||||
def Some(v: T) -> Option[T]:
|
||||
return Option(v)
|
||||
|
||||
none = Option(None)
|
||||
|
||||
def round64(x):
|
||||
return round(x)
|
||||
|
|
|
@ -353,7 +353,6 @@ impl Nac3 {
|
|||
|
||||
let builtins_mod = PyModule::import(py, "builtins").unwrap();
|
||||
let id_fn = builtins_mod.getattr("id").unwrap();
|
||||
let type_fn = builtins_mod.getattr("type").unwrap();
|
||||
let numpy_mod = PyModule::import(py, "numpy").unwrap();
|
||||
let typing_mod = PyModule::import(py, "typing").unwrap();
|
||||
let types_mod = PyModule::import(py, "types").unwrap();
|
||||
|
@ -376,7 +375,13 @@ impl Nac3 {
|
|||
get_attr_id(types_mod, "GenericAlias"),
|
||||
),
|
||||
none: id_fn
|
||||
.call1((type_fn.call1((builtins_mod.getattr("None").unwrap(),)).unwrap(),))
|
||||
.call1((builtins_mod
|
||||
.getattr("globals")
|
||||
.unwrap()
|
||||
.call0()
|
||||
.unwrap()
|
||||
.get_item("none")
|
||||
.unwrap(),))
|
||||
.unwrap()
|
||||
.extract()
|
||||
.unwrap(),
|
||||
|
|
|
@ -282,24 +282,7 @@ impl InnerResolver {
|
|||
} else if ty_id == self.primitive_ids.option {
|
||||
Ok(Ok((primitives.option, false)))
|
||||
} else if ty_id == self.primitive_ids.none {
|
||||
if let TypeEnum::TObj { params, .. } =
|
||||
unifier.get_ty_immutable(primitives.option).as_ref()
|
||||
{
|
||||
let var_map = params
|
||||
.iter()
|
||||
.map(|(id_var, ty)| {
|
||||
if let TypeEnum::TVar { id, range, name, loc, .. } = &*unifier.get_ty(*ty) {
|
||||
assert_eq!(*id, *id_var);
|
||||
(*id, unifier.get_fresh_var_with_range(range, *name, *loc).0)
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
})
|
||||
.collect::<HashMap<_, _>>();
|
||||
Ok(Ok((unifier.subst(primitives.option, &var_map).unwrap(), true)))
|
||||
} else {
|
||||
unreachable!("must be tobj")
|
||||
}
|
||||
unreachable!("none cannot be typeid")
|
||||
} else if let Some(def_id) = self.pyid_to_def.read().get(&ty_id).cloned() {
|
||||
let def = defs[def_id.0].read();
|
||||
if let TopLevelDef::Class { object_id, type_vars, fields, methods, .. } = &*def {
|
||||
|
@ -597,14 +580,32 @@ impl InnerResolver {
|
|||
{
|
||||
let field_data = match obj.getattr("_nac3_option") {
|
||||
Ok(d) => d,
|
||||
// None should be already handled above
|
||||
// we use `none = Option(None)`, so the obj always have attr `_nac3_option`
|
||||
Err(_) => unreachable!("cannot be None")
|
||||
};
|
||||
let field_obj_id: u64 = self.helper.id_fn.call1(py, (field_data,))?.extract(py)?;
|
||||
let zelf_obj_id: u64 = self.helper.id_fn.call1(py, (obj,))?.extract(py)?;
|
||||
if field_obj_id == zelf_obj_id {
|
||||
return Ok(Err("self recursive option type is not allowed".into()))
|
||||
// if is `none`
|
||||
let zelf_id: u64 = self.helper.id_fn.call1(py, (obj,))?.extract(py)?;
|
||||
if zelf_id == self.primitive_ids.none {
|
||||
if let TypeEnum::TObj { params, .. } =
|
||||
unifier.get_ty_immutable(primitives.option).as_ref()
|
||||
{
|
||||
let var_map = params
|
||||
.iter()
|
||||
.map(|(id_var, ty)| {
|
||||
if let TypeEnum::TVar { id, range, name, loc, .. } = &*unifier.get_ty(*ty) {
|
||||
assert_eq!(*id, *id_var);
|
||||
(*id, unifier.get_fresh_var_with_range(range, *name, *loc).0)
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
})
|
||||
.collect::<HashMap<_, _>>();
|
||||
return Ok(Ok(unifier.subst(primitives.option, &var_map).unwrap()))
|
||||
} else {
|
||||
unreachable!("must be tobj")
|
||||
}
|
||||
}
|
||||
|
||||
let ty = match self.get_obj_type(py, field_data, unifier, defs, primitives)? {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
|
@ -845,36 +846,38 @@ impl InnerResolver {
|
|||
global.set_initializer(&val);
|
||||
Ok(Some(global.as_pointer_value().into()))
|
||||
} else if ty_id == self.primitive_ids.option {
|
||||
match self
|
||||
.get_obj_value(py, obj.getattr("_nac3_option").unwrap(), ctx, generator)
|
||||
.map_err(|e| {
|
||||
super::CompileError::new_err(format!(
|
||||
"Error getting value of Option object: {}",
|
||||
e
|
||||
))
|
||||
})? {
|
||||
Some(v) => {
|
||||
let global_str = format!("{}_option", id);
|
||||
{
|
||||
if self.global_value_ids.read().contains(&id) {
|
||||
let global = ctx.module.get_global(&global_str).unwrap_or_else(|| {
|
||||
ctx.module.add_global(v.get_type(), Some(AddressSpace::Generic), &global_str)
|
||||
});
|
||||
return Ok(Some(global.as_pointer_value().into()));
|
||||
} else {
|
||||
self.global_value_ids.write().insert(id);
|
||||
if id == self.primitive_ids.none {
|
||||
// for option type, just a null ptr, whose type needs to be casted in codegen
|
||||
// according to the type info attached in the ast
|
||||
Ok(Some(ctx.ctx.i8_type().ptr_type(AddressSpace::Generic).const_null().into()))
|
||||
} else {
|
||||
match self
|
||||
.get_obj_value(py, obj.getattr("_nac3_option").unwrap(), ctx, generator)
|
||||
.map_err(|e| {
|
||||
super::CompileError::new_err(format!(
|
||||
"Error getting value of Option object: {}",
|
||||
e
|
||||
))
|
||||
})? {
|
||||
Some(v) => {
|
||||
let global_str = format!("{}_option", id);
|
||||
{
|
||||
if self.global_value_ids.read().contains(&id) {
|
||||
let global = ctx.module.get_global(&global_str).unwrap_or_else(|| {
|
||||
ctx.module.add_global(v.get_type(), Some(AddressSpace::Generic), &global_str)
|
||||
});
|
||||
return Ok(Some(global.as_pointer_value().into()));
|
||||
} else {
|
||||
self.global_value_ids.write().insert(id);
|
||||
}
|
||||
}
|
||||
}
|
||||
let global = ctx.module.add_global(v.get_type(), Some(AddressSpace::Generic), &global_str);
|
||||
global.set_initializer(&v);
|
||||
Ok(Some(global.as_pointer_value().into()))
|
||||
},
|
||||
None => Ok(None),
|
||||
let global = ctx.module.add_global(v.get_type(), Some(AddressSpace::Generic), &global_str);
|
||||
global.set_initializer(&v);
|
||||
Ok(Some(global.as_pointer_value().into()))
|
||||
},
|
||||
None => Ok(None),
|
||||
}
|
||||
}
|
||||
} else if ty_id == self.primitive_ids.none {
|
||||
// for option type, just a null ptr, whose type needs to be casted in codegen
|
||||
// according to the type info attached in the ast
|
||||
Ok(Some(ctx.ctx.i8_type().ptr_type(AddressSpace::Generic).const_null().into()))
|
||||
} else {
|
||||
let id_str = id.to_string();
|
||||
|
||||
|
|
|
@ -200,22 +200,6 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
|
|||
val
|
||||
}
|
||||
}
|
||||
Constant::None => {
|
||||
match (
|
||||
self.unifier.get_ty(ty).as_ref(),
|
||||
self.unifier.get_ty(self.primitives.option).as_ref(),
|
||||
) {
|
||||
(
|
||||
TypeEnum::TObj { obj_id, params, .. },
|
||||
TypeEnum::TObj { obj_id: opt_id, .. },
|
||||
) if *obj_id == *opt_id => self
|
||||
.get_llvm_type(generator, *params.iter().next().unwrap().1)
|
||||
.ptr_type(AddressSpace::Generic)
|
||||
.const_null()
|
||||
.into(),
|
||||
_ => unreachable!("must be option type"),
|
||||
}
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
@ -951,6 +935,22 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator>(
|
|||
let ty = expr.custom.unwrap();
|
||||
ctx.gen_const(generator, value, ty).into()
|
||||
}
|
||||
ExprKind::Name { id, .. } if id == &"none".into() => {
|
||||
match (
|
||||
ctx.unifier.get_ty(expr.custom.unwrap()).as_ref(),
|
||||
ctx.unifier.get_ty(ctx.primitives.option).as_ref(),
|
||||
) {
|
||||
(
|
||||
TypeEnum::TObj { obj_id, params, .. },
|
||||
TypeEnum::TObj { obj_id: opt_id, .. },
|
||||
) if *obj_id == *opt_id => ctx
|
||||
.get_llvm_type(generator, *params.iter().next().unwrap().1)
|
||||
.ptr_type(AddressSpace::Generic)
|
||||
.const_null()
|
||||
.into(),
|
||||
_ => unreachable!("must be option type"),
|
||||
}
|
||||
}
|
||||
ExprKind::Name { id, .. } => match ctx.var_assignment.get(id) {
|
||||
Some((ptr, None, _)) => ctx.builder.build_load(*ptr, "load").into(),
|
||||
Some((_, Some(static_value), _)) => ValueEnum::Static(static_value.clone()),
|
||||
|
|
|
@ -20,6 +20,8 @@ impl<'a> Inferencer<'a> {
|
|||
defined_identifiers: &mut HashSet<StrRef>,
|
||||
) -> Result<(), String> {
|
||||
match &pattern.node {
|
||||
ast::ExprKind::Name { id, .. } if id == &"none".into() =>
|
||||
Err(format!("cannot assign to a `none` (at {})", pattern.location)),
|
||||
ExprKind::Name { id, .. } => {
|
||||
if !defined_identifiers.contains(id) {
|
||||
defined_identifiers.insert(*id);
|
||||
|
@ -70,6 +72,9 @@ impl<'a> Inferencer<'a> {
|
|||
}
|
||||
match &expr.node {
|
||||
ExprKind::Name { id, .. } => {
|
||||
if id == &"none".into() {
|
||||
return Ok(());
|
||||
}
|
||||
self.should_have_value(expr)?;
|
||||
if !defined_identifiers.contains(id) {
|
||||
match self.function_data.resolver.get_symbol_type(
|
||||
|
|
|
@ -449,25 +449,47 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
|
|||
Some(self.infer_constant(value, &expr.location)?)
|
||||
}
|
||||
ast::ExprKind::Name { id, .. } => {
|
||||
if !self.defined_identifiers.contains(id) {
|
||||
match self.function_data.resolver.get_symbol_type(
|
||||
self.unifier,
|
||||
&self.top_level.definitions.read(),
|
||||
self.primitives,
|
||||
*id,
|
||||
) {
|
||||
Ok(_) => {
|
||||
self.defined_identifiers.insert(*id);
|
||||
}
|
||||
Err(e) => {
|
||||
return report_error(
|
||||
&format!("type error at identifier `{}` ({})", id, e),
|
||||
expr.location,
|
||||
);
|
||||
// the name `none` is special since it may have different types
|
||||
if id == &"none".into() {
|
||||
if let TypeEnum::TObj { params, .. } =
|
||||
self.unifier.get_ty_immutable(self.primitives.option).as_ref()
|
||||
{
|
||||
let var_map = params
|
||||
.iter()
|
||||
.map(|(id_var, ty)| {
|
||||
if let TypeEnum::TVar { id, range, name, loc, .. } = &*self.unifier.get_ty(*ty) {
|
||||
assert_eq!(*id, *id_var);
|
||||
(*id, self.unifier.get_fresh_var_with_range(range, *name, *loc).0)
|
||||
} else {
|
||||
unreachable!()
|
||||
}
|
||||
})
|
||||
.collect::<HashMap<_, _>>();
|
||||
Some(self.unifier.subst(self.primitives.option, &var_map).unwrap())
|
||||
} else {
|
||||
unreachable!("must be tobj")
|
||||
}
|
||||
} else {
|
||||
if !self.defined_identifiers.contains(id) {
|
||||
match self.function_data.resolver.get_symbol_type(
|
||||
self.unifier,
|
||||
&self.top_level.definitions.read(),
|
||||
self.primitives,
|
||||
*id,
|
||||
) {
|
||||
Ok(_) => {
|
||||
self.defined_identifiers.insert(*id);
|
||||
}
|
||||
Err(e) => {
|
||||
return report_error(
|
||||
&format!("type error at identifier `{}` ({})", id, e),
|
||||
expr.location,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(self.infer_identifier(*id)?)
|
||||
}
|
||||
Some(self.infer_identifier(*id)?)
|
||||
}
|
||||
ast::ExprKind::List { elts, .. } => Some(self.infer_list(elts)?),
|
||||
ast::ExprKind::Tuple { elts, .. } => Some(self.infer_tuple(elts)?),
|
||||
|
@ -933,17 +955,8 @@ impl<'a> Inferencer<'a> {
|
|||
Ok(self.unifier.add_ty(TypeEnum::TTuple { ty: ty? }))
|
||||
}
|
||||
ast::Constant::Str(_) => Ok(self.primitives.str),
|
||||
ast::Constant::None => {
|
||||
let option_ty = self.primitives.option;
|
||||
let new_mapping = if let TypeEnum::TObj { params, .. } = &*self.unifier.get_ty_immutable(option_ty) {
|
||||
let (id, _) = params.iter().next().unwrap();
|
||||
// None can be Option[Any]
|
||||
vec![(*id, self.unifier.get_fresh_var(None, None).0)].into_iter().collect()
|
||||
} else {
|
||||
unreachable!("option must be tobj")
|
||||
};
|
||||
Ok(self.unifier.subst(option_ty, &new_mapping).unwrap())
|
||||
}
|
||||
ast::Constant::None
|
||||
=> report_error("CPython `None` not supported (nac3 uses `none` instead)", *loc),
|
||||
_ => report_error("not supported", *loc),
|
||||
}
|
||||
}
|
||||
|
|
Just
none
is fine.