forked from M-Labs/nac3
nac3artiq/core: host option object support
This commit is contained in:
parent
d86a75bf0e
commit
7db5909f62
@ -11,6 +11,7 @@ from embedding_map import EmbeddingMap
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Kernel", "KernelInvariant", "virtual",
|
"Kernel", "KernelInvariant", "virtual",
|
||||||
|
"Option", "Some",
|
||||||
"round64", "floor64", "ceil64",
|
"round64", "floor64", "ceil64",
|
||||||
"extern", "kernel", "portable", "nac3",
|
"extern", "kernel", "portable", "nac3",
|
||||||
"rpc", "ms", "us", "ns",
|
"rpc", "ms", "us", "ns",
|
||||||
@ -32,6 +33,36 @@ class KernelInvariant(Generic[T]):
|
|||||||
class virtual(Generic[T]):
|
class virtual(Generic[T]):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
class Option(Generic[T]):
|
||||||
|
_nac3_option: T
|
||||||
|
|
||||||
|
def __init__(self, v: T):
|
||||||
|
self._nac3_option = v
|
||||||
|
|
||||||
|
def is_none(self):
|
||||||
|
return self._nac3_option is None
|
||||||
|
|
||||||
|
def is_some(self):
|
||||||
|
return not self.is_none()
|
||||||
|
|
||||||
|
def unwrap(self):
|
||||||
|
return self._nac3_option
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
if self.is_none():
|
||||||
|
return "Option(None)"
|
||||||
|
else:
|
||||||
|
return "Some({})".format(repr(self._nac3_option))
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
if self.is_none():
|
||||||
|
return "None"
|
||||||
|
else:
|
||||||
|
return "Some({})".format(str(self._nac3_option))
|
||||||
|
|
||||||
|
def Some(v: T) -> Option[T]:
|
||||||
|
return Option(v)
|
||||||
|
|
||||||
|
|
||||||
def round64(x):
|
def round64(x):
|
||||||
return round(x)
|
return round(x)
|
||||||
|
@ -71,6 +71,7 @@ pub struct PrimitivePythonId {
|
|||||||
exception: u64,
|
exception: u64,
|
||||||
generic_alias: (u64, u64),
|
generic_alias: (u64, u64),
|
||||||
virtual_id: u64,
|
virtual_id: u64,
|
||||||
|
option: u64,
|
||||||
}
|
}
|
||||||
|
|
||||||
type TopLevelComponent = (Stmt, String, PyObject);
|
type TopLevelComponent = (Stmt, String, PyObject);
|
||||||
@ -351,6 +352,7 @@ impl Nac3 {
|
|||||||
|
|
||||||
let builtins_mod = PyModule::import(py, "builtins").unwrap();
|
let builtins_mod = PyModule::import(py, "builtins").unwrap();
|
||||||
let id_fn = builtins_mod.getattr("id").unwrap();
|
let id_fn = builtins_mod.getattr("id").unwrap();
|
||||||
|
let type_fn = builtins_mod.getattr("type").unwrap();
|
||||||
let numpy_mod = PyModule::import(py, "numpy").unwrap();
|
let numpy_mod = PyModule::import(py, "numpy").unwrap();
|
||||||
let typing_mod = PyModule::import(py, "typing").unwrap();
|
let typing_mod = PyModule::import(py, "typing").unwrap();
|
||||||
let types_mod = PyModule::import(py, "types").unwrap();
|
let types_mod = PyModule::import(py, "types").unwrap();
|
||||||
@ -372,7 +374,11 @@ impl Nac3 {
|
|||||||
get_attr_id(typing_mod, "_GenericAlias"),
|
get_attr_id(typing_mod, "_GenericAlias"),
|
||||||
get_attr_id(types_mod, "GenericAlias"),
|
get_attr_id(types_mod, "GenericAlias"),
|
||||||
),
|
),
|
||||||
none: get_attr_id(builtins_mod, "None"),
|
none: id_fn
|
||||||
|
.call1((type_fn.call1((builtins_mod.getattr("None").unwrap(),)).unwrap(),))
|
||||||
|
.unwrap()
|
||||||
|
.extract()
|
||||||
|
.unwrap(),
|
||||||
typevar: get_attr_id(typing_mod, "TypeVar"),
|
typevar: get_attr_id(typing_mod, "TypeVar"),
|
||||||
int: get_attr_id(builtins_mod, "int"),
|
int: get_attr_id(builtins_mod, "int"),
|
||||||
int32: get_attr_id(numpy_mod, "int32"),
|
int32: get_attr_id(numpy_mod, "int32"),
|
||||||
@ -384,6 +390,17 @@ impl Nac3 {
|
|||||||
list: get_attr_id(builtins_mod, "list"),
|
list: get_attr_id(builtins_mod, "list"),
|
||||||
tuple: get_attr_id(builtins_mod, "tuple"),
|
tuple: get_attr_id(builtins_mod, "tuple"),
|
||||||
exception: get_attr_id(builtins_mod, "Exception"),
|
exception: get_attr_id(builtins_mod, "Exception"),
|
||||||
|
option: id_fn
|
||||||
|
.call1((builtins_mod
|
||||||
|
.getattr("globals")
|
||||||
|
.unwrap()
|
||||||
|
.call0()
|
||||||
|
.unwrap()
|
||||||
|
.get_item("Option")
|
||||||
|
.unwrap(),))
|
||||||
|
.unwrap()
|
||||||
|
.extract()
|
||||||
|
.unwrap(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let working_directory = tempfile::Builder::new().prefix("nac3-").tempdir().unwrap();
|
let working_directory = tempfile::Builder::new().prefix("nac3-").tempdir().unwrap();
|
||||||
|
@ -260,6 +260,27 @@ impl InnerResolver {
|
|||||||
} else if ty_id == self.primitive_ids.tuple {
|
} else if ty_id == self.primitive_ids.tuple {
|
||||||
// do not handle type var param and concrete check here
|
// do not handle type var param and concrete check here
|
||||||
Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: vec![] }), false)))
|
Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: vec![] }), false)))
|
||||||
|
} else if ty_id == self.primitive_ids.option {
|
||||||
|
Ok(Ok((primitives.option, false)))
|
||||||
|
} else if ty_id == self.primitive_ids.none {
|
||||||
|
if let TypeEnum::TObj { params, .. } =
|
||||||
|
unifier.get_ty_immutable(primitives.option).as_ref()
|
||||||
|
{
|
||||||
|
let var_map = params
|
||||||
|
.iter()
|
||||||
|
.map(|(id_var, ty)| {
|
||||||
|
if let TypeEnum::TVar { id, range, name, loc, .. } = &*unifier.get_ty(*ty) {
|
||||||
|
assert_eq!(*id, *id_var);
|
||||||
|
(*id, unifier.get_fresh_var_with_range(range, *name, *loc).0)
|
||||||
|
} else {
|
||||||
|
unreachable!()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect::<HashMap<_, _>>();
|
||||||
|
Ok(Ok((unifier.subst(primitives.option, &var_map).unwrap(), true)))
|
||||||
|
} else {
|
||||||
|
unreachable!("must be tobj")
|
||||||
|
}
|
||||||
} else if let Some(def_id) = self.pyid_to_def.read().get(&ty_id).cloned() {
|
} else if let Some(def_id) = self.pyid_to_def.read().get(&ty_id).cloned() {
|
||||||
let def = defs[def_id.0].read();
|
let def = defs[def_id.0].read();
|
||||||
if let TopLevelDef::Class { object_id, type_vars, fields, methods, .. } = &*def {
|
if let TopLevelDef::Class { object_id, type_vars, fields, methods, .. } = &*def {
|
||||||
@ -538,6 +559,34 @@ impl InnerResolver {
|
|||||||
let types = types?;
|
let types = types?;
|
||||||
Ok(types.map(|types| unifier.add_ty(TypeEnum::TTuple { ty: types })))
|
Ok(types.map(|types| unifier.add_ty(TypeEnum::TTuple { ty: types })))
|
||||||
}
|
}
|
||||||
|
// special handling for option type since its class member layout in python side
|
||||||
|
// is special and cannot be mapped directly to a nac3 type as below
|
||||||
|
(TypeEnum::TObj { obj_id, params, .. }, false)
|
||||||
|
if *obj_id == primitives.option.get_obj_id(unifier) =>
|
||||||
|
{
|
||||||
|
let field_data = match obj.getattr("_nac3_option") {
|
||||||
|
Ok(d) => d,
|
||||||
|
// None should be already handled above
|
||||||
|
Err(_) => unreachable!("cannot be None")
|
||||||
|
};
|
||||||
|
let field_obj_id: u64 = self.helper.id_fn.call1(py, (field_data,))?.extract(py)?;
|
||||||
|
let zelf_obj_id: u64 = self.helper.id_fn.call1(py, (obj,))?.extract(py)?;
|
||||||
|
if field_obj_id == zelf_obj_id {
|
||||||
|
return Ok(Err("self recursive option type is not allowed".into()))
|
||||||
|
}
|
||||||
|
let ty = match self.get_obj_type(py, field_data, unifier, defs, primitives)? {
|
||||||
|
Ok(t) => t,
|
||||||
|
Err(e) => {
|
||||||
|
return Ok(Err(format!(
|
||||||
|
"error when getting type of the option object ({})",
|
||||||
|
e
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let new_var_map: HashMap<_, _> = params.iter().map(|(id, _)| (*id, ty)).collect();
|
||||||
|
let res = unifier.subst(extracted_ty, &new_var_map).unwrap_or(extracted_ty);
|
||||||
|
Ok(Ok(res))
|
||||||
|
}
|
||||||
(TypeEnum::TObj { params: var_map, fields, .. }, false) => {
|
(TypeEnum::TObj { params: var_map, fields, .. }, false) => {
|
||||||
self.pyid_to_type.write().insert(ty_id, extracted_ty);
|
self.pyid_to_type.write().insert(ty_id, extracted_ty);
|
||||||
let mut instantiate_obj = || {
|
let mut instantiate_obj = || {
|
||||||
@ -756,6 +805,37 @@ impl InnerResolver {
|
|||||||
let global = ctx.module.add_global(ty, Some(AddressSpace::Generic), &id_str);
|
let global = ctx.module.add_global(ty, Some(AddressSpace::Generic), &id_str);
|
||||||
global.set_initializer(&val);
|
global.set_initializer(&val);
|
||||||
Ok(Some(global.as_pointer_value().into()))
|
Ok(Some(global.as_pointer_value().into()))
|
||||||
|
} else if ty_id == self.primitive_ids.option {
|
||||||
|
match self
|
||||||
|
.get_obj_value(py, obj.getattr("_nac3_option").unwrap(), ctx, generator)
|
||||||
|
.map_err(|e| {
|
||||||
|
super::CompileError::new_err(format!(
|
||||||
|
"Error getting value of Option object: {}",
|
||||||
|
e
|
||||||
|
))
|
||||||
|
})? {
|
||||||
|
Some(v) => {
|
||||||
|
let global_str = format!("{}_option", id);
|
||||||
|
{
|
||||||
|
if self.global_value_ids.read().contains(&id) {
|
||||||
|
let global = ctx.module.get_global(&global_str).unwrap_or_else(|| {
|
||||||
|
ctx.module.add_global(v.get_type(), Some(AddressSpace::Generic), &global_str)
|
||||||
|
});
|
||||||
|
return Ok(Some(global.as_pointer_value().into()));
|
||||||
|
} else {
|
||||||
|
self.global_value_ids.write().insert(id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let global = ctx.module.add_global(v.get_type(), Some(AddressSpace::Generic), &global_str);
|
||||||
|
global.set_initializer(&v);
|
||||||
|
Ok(Some(global.as_pointer_value().into()))
|
||||||
|
},
|
||||||
|
None => Ok(None),
|
||||||
|
}
|
||||||
|
} else if ty_id == self.primitive_ids.none {
|
||||||
|
// for option type, just a null ptr, whose type needs to be casted in codegen
|
||||||
|
// according to the type info attached in the ast
|
||||||
|
Ok(Some(ctx.ctx.i8_type().ptr_type(AddressSpace::Generic).const_null().into()))
|
||||||
} else {
|
} else {
|
||||||
let id_str = id.to_string();
|
let id_str = id.to_string();
|
||||||
|
|
||||||
@ -900,7 +980,15 @@ impl SymbolResolver for Resolver {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if let Ok(t) = sym_ty {
|
if let Ok(t) = sym_ty {
|
||||||
self.0.pyid_to_type.write().insert(*id, t);
|
// do not cache for option type since None can have same pyid but different type
|
||||||
|
match unifier.get_ty_immutable(t).as_ref() {
|
||||||
|
TypeEnum::TObj { obj_id, .. }
|
||||||
|
if *obj_id == primitives.option.get_obj_id(unifier) =>
|
||||||
|
{
|
||||||
|
None
|
||||||
|
}
|
||||||
|
_ => self.0.pyid_to_type.write().insert(*id, t),
|
||||||
|
};
|
||||||
}
|
}
|
||||||
Ok(sym_ty)
|
Ok(sym_ty)
|
||||||
})
|
})
|
||||||
|
@ -952,16 +952,27 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator>(
|
|||||||
let resolver = ctx.resolver.clone();
|
let resolver = ctx.resolver.clone();
|
||||||
let val = resolver.get_symbol_value(*id, ctx).unwrap();
|
let val = resolver.get_symbol_value(*id, ctx).unwrap();
|
||||||
// if is tuple, need to deref it to handle tuple as value
|
// if is tuple, need to deref it to handle tuple as value
|
||||||
if let (TypeEnum::TTuple { .. }, BasicValueEnum::PointerValue(ptr)) = (
|
// if is option, need to cast pointer to handle None
|
||||||
|
match (
|
||||||
&*ctx.unifier.get_ty(expr.custom.unwrap()),
|
&*ctx.unifier.get_ty(expr.custom.unwrap()),
|
||||||
resolver
|
resolver
|
||||||
.get_symbol_value(*id, ctx)
|
.get_symbol_value(*id, ctx)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.to_basic_value_enum(ctx, generator)?,
|
.to_basic_value_enum(ctx, generator)?,
|
||||||
) {
|
) {
|
||||||
ctx.builder.build_load(ptr, "tup_val").into()
|
(TypeEnum::TTuple { .. }, BasicValueEnum::PointerValue(ptr)) => {
|
||||||
} else {
|
ctx.builder.build_load(ptr, "tup_val").into()
|
||||||
val
|
}
|
||||||
|
(TypeEnum::TObj { obj_id, params, .. }, BasicValueEnum::PointerValue(ptr))
|
||||||
|
if *obj_id == ctx.primitives.option.get_obj_id(&ctx.unifier) => {
|
||||||
|
let actual_ptr_ty = ctx.get_llvm_type(
|
||||||
|
generator,
|
||||||
|
*params.iter().next().unwrap().1,
|
||||||
|
)
|
||||||
|
.ptr_type(AddressSpace::Generic);
|
||||||
|
ctx.builder.build_bitcast(ptr, actual_ptr_ty, "option_ptr_cast").into()
|
||||||
|
}
|
||||||
|
_ => val,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -270,7 +270,7 @@ fn get_llvm_type<'ctx>(
|
|||||||
let result = match &*ty_enum {
|
let result = match &*ty_enum {
|
||||||
TObj { obj_id, fields, .. } => {
|
TObj { obj_id, fields, .. } => {
|
||||||
// check to avoid treating primitives other than Option as classes
|
// check to avoid treating primitives other than Option as classes
|
||||||
if obj_id.0 <= 14 {
|
if obj_id.0 <= 10 {
|
||||||
match (unifier.get_ty(ty).as_ref(), unifier.get_ty(primitives.option).as_ref())
|
match (unifier.get_ty(ty).as_ref(), unifier.get_ty(primitives.option).as_ref())
|
||||||
{
|
{
|
||||||
(
|
(
|
||||||
|
@ -54,6 +54,18 @@ pub enum RecordKey {
|
|||||||
Int(i32),
|
Int(i32),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Type {
|
||||||
|
// a wrapper function for cleaner code so that we don't need to
|
||||||
|
// write this long pattern matching just to get the field `obj_id`
|
||||||
|
pub fn get_obj_id(self, unifier: &Unifier) -> DefinitionId {
|
||||||
|
if let TypeEnum::TObj { obj_id, .. } = unifier.get_ty_immutable(self).as_ref() {
|
||||||
|
*obj_id
|
||||||
|
} else {
|
||||||
|
unreachable!("expect a object type")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl From<&RecordKey> for StrRef {
|
impl From<&RecordKey> for StrRef {
|
||||||
fn from(r: &RecordKey) -> Self {
|
fn from(r: &RecordKey) -> Self {
|
||||||
match r {
|
match r {
|
||||||
|
Loading…
Reference in New Issue
Block a user