nac3artiq/core: host option object support

This commit is contained in:
ychenfo 2022-03-17 00:39:21 +08:00
parent 34b460757b
commit 29da2d1a9a
6 changed files with 157 additions and 6 deletions

View File

@ -11,6 +11,7 @@ from embedding_map import EmbeddingMap
__all__ = [ __all__ = [
"Kernel", "KernelInvariant", "virtual", "Kernel", "KernelInvariant", "virtual",
"Option", "Some",
"round64", "floor64", "ceil64", "round64", "floor64", "ceil64",
"extern", "kernel", "portable", "nac3", "extern", "kernel", "portable", "nac3",
"rpc", "ms", "us", "ns", "rpc", "ms", "us", "ns",
@ -32,6 +33,36 @@ class KernelInvariant(Generic[T]):
class virtual(Generic[T]): class virtual(Generic[T]):
pass pass
class Option(Generic[T]):
_nac3_option: T
def __init__(self, v: T):
self._nac3_option = v
def is_none(self):
return self._nac3_option is None
def is_some(self):
return not self.is_none()
def unwrap(self):
return self._nac3_option
def __repr__(self) -> str:
if self.is_none():
return "Option(None)"
else:
return "Some({})".format(repr(self._nac3_option))
def __str__(self) -> str:
if self.is_none():
return "None"
else:
return "Some({})".format(str(self._nac3_option))
def Some(v: T) -> Option[T]:
return Option(v)
def round64(x): def round64(x):
return round(x) return round(x)

View File

@ -71,6 +71,7 @@ pub struct PrimitivePythonId {
exception: u64, exception: u64,
generic_alias: (u64, u64), generic_alias: (u64, u64),
virtual_id: u64, virtual_id: u64,
option: u64,
} }
type TopLevelComponent = (Stmt, String, PyObject); type TopLevelComponent = (Stmt, String, PyObject);
@ -352,6 +353,7 @@ impl Nac3 {
let builtins_mod = PyModule::import(py, "builtins").unwrap(); let builtins_mod = PyModule::import(py, "builtins").unwrap();
let id_fn = builtins_mod.getattr("id").unwrap(); let id_fn = builtins_mod.getattr("id").unwrap();
let type_fn = builtins_mod.getattr("type").unwrap();
let numpy_mod = PyModule::import(py, "numpy").unwrap(); let numpy_mod = PyModule::import(py, "numpy").unwrap();
let typing_mod = PyModule::import(py, "typing").unwrap(); let typing_mod = PyModule::import(py, "typing").unwrap();
let types_mod = PyModule::import(py, "types").unwrap(); let types_mod = PyModule::import(py, "types").unwrap();
@ -373,7 +375,11 @@ impl Nac3 {
get_attr_id(typing_mod, "_GenericAlias"), get_attr_id(typing_mod, "_GenericAlias"),
get_attr_id(types_mod, "GenericAlias"), get_attr_id(types_mod, "GenericAlias"),
), ),
none: get_attr_id(builtins_mod, "None"), none: id_fn
.call1((type_fn.call1((builtins_mod.getattr("None").unwrap(),)).unwrap(),))
.unwrap()
.extract()
.unwrap(),
typevar: get_attr_id(typing_mod, "TypeVar"), typevar: get_attr_id(typing_mod, "TypeVar"),
int: get_attr_id(builtins_mod, "int"), int: get_attr_id(builtins_mod, "int"),
int32: get_attr_id(numpy_mod, "int32"), int32: get_attr_id(numpy_mod, "int32"),
@ -385,6 +391,17 @@ impl Nac3 {
list: get_attr_id(builtins_mod, "list"), list: get_attr_id(builtins_mod, "list"),
tuple: get_attr_id(builtins_mod, "tuple"), tuple: get_attr_id(builtins_mod, "tuple"),
exception: get_attr_id(builtins_mod, "Exception"), exception: get_attr_id(builtins_mod, "Exception"),
option: id_fn
.call1((builtins_mod
.getattr("globals")
.unwrap()
.call0()
.unwrap()
.get_item("Option")
.unwrap(),))
.unwrap()
.extract()
.unwrap(),
}; };
let working_directory = tempfile::Builder::new().prefix("nac3-").tempdir().unwrap(); let working_directory = tempfile::Builder::new().prefix("nac3-").tempdir().unwrap();

View File

@ -279,6 +279,27 @@ impl InnerResolver {
} else if ty_id == self.primitive_ids.tuple { } else if ty_id == self.primitive_ids.tuple {
// do not handle type var param and concrete check here // do not handle type var param and concrete check here
Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: vec![] }), false))) Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: vec![] }), false)))
} else if ty_id == self.primitive_ids.option {
Ok(Ok((primitives.option, false)))
} else if ty_id == self.primitive_ids.none {
if let TypeEnum::TObj { params, .. } =
unifier.get_ty_immutable(primitives.option).as_ref()
{
let var_map = params
.iter()
.map(|(id_var, ty)| {
if let TypeEnum::TVar { id, range, name, loc, .. } = &*unifier.get_ty(*ty) {
assert_eq!(*id, *id_var);
(*id, unifier.get_fresh_var_with_range(range, *name, *loc).0)
} else {
unreachable!()
}
})
.collect::<HashMap<_, _>>();
Ok(Ok((unifier.subst(primitives.option, &var_map).unwrap(), true)))
} else {
unreachable!("must be tobj")
}
} else if let Some(def_id) = self.pyid_to_def.read().get(&ty_id).cloned() { } else if let Some(def_id) = self.pyid_to_def.read().get(&ty_id).cloned() {
let def = defs[def_id.0].read(); let def = defs[def_id.0].read();
if let TopLevelDef::Class { object_id, type_vars, fields, methods, .. } = &*def { if let TopLevelDef::Class { object_id, type_vars, fields, methods, .. } = &*def {
@ -569,6 +590,34 @@ impl InnerResolver {
let types = types?; let types = types?;
Ok(types.map(|types| unifier.add_ty(TypeEnum::TTuple { ty: types }))) Ok(types.map(|types| unifier.add_ty(TypeEnum::TTuple { ty: types })))
} }
// special handling for option type since its class member layout in python side
// is special and cannot be mapped directly to a nac3 type as below
(TypeEnum::TObj { obj_id, params, .. }, false)
if *obj_id == primitives.option.get_obj_id(unifier) =>
{
let field_data = match obj.getattr("_nac3_option") {
Ok(d) => d,
// None should be already handled above
Err(_) => unreachable!("cannot be None")
};
let field_obj_id: u64 = self.helper.id_fn.call1(py, (field_data,))?.extract(py)?;
let zelf_obj_id: u64 = self.helper.id_fn.call1(py, (obj,))?.extract(py)?;
if field_obj_id == zelf_obj_id {
return Ok(Err("self recursive option type is not allowed".into()))
}
let ty = match self.get_obj_type(py, field_data, unifier, defs, primitives)? {
Ok(t) => t,
Err(e) => {
return Ok(Err(format!(
"error when getting type of the option object ({})",
e
)))
}
};
let new_var_map: HashMap<_, _> = params.iter().map(|(id, _)| (*id, ty)).collect();
let res = unifier.subst(extracted_ty, &new_var_map).unwrap_or(extracted_ty);
Ok(Ok(res))
}
(TypeEnum::TObj { params, fields, .. }, false) => { (TypeEnum::TObj { params, fields, .. }, false) => {
let var_map = params let var_map = params
.iter() .iter()
@ -795,6 +844,37 @@ impl InnerResolver {
let global = ctx.module.add_global(ty, Some(AddressSpace::Generic), &id_str); let global = ctx.module.add_global(ty, Some(AddressSpace::Generic), &id_str);
global.set_initializer(&val); global.set_initializer(&val);
Ok(Some(global.as_pointer_value().into())) Ok(Some(global.as_pointer_value().into()))
} else if ty_id == self.primitive_ids.option {
match self
.get_obj_value(py, obj.getattr("_nac3_option").unwrap(), ctx, generator)
.map_err(|e| {
super::CompileError::new_err(format!(
"Error getting value of Option object: {}",
e
))
})? {
Some(v) => {
let global_str = format!("{}_option", id);
{
if self.global_value_ids.read().contains(&id) {
let global = ctx.module.get_global(&global_str).unwrap_or_else(|| {
ctx.module.add_global(v.get_type(), Some(AddressSpace::Generic), &global_str)
});
return Ok(Some(global.as_pointer_value().into()));
} else {
self.global_value_ids.write().insert(id);
}
}
let global = ctx.module.add_global(v.get_type(), Some(AddressSpace::Generic), &global_str);
global.set_initializer(&v);
Ok(Some(global.as_pointer_value().into()))
},
None => Ok(None),
}
} else if ty_id == self.primitive_ids.none {
// for option type, just a null ptr, whose type needs to be casted in codegen
// according to the type info attached in the ast
Ok(Some(ctx.ctx.i8_type().ptr_type(AddressSpace::Generic).const_null().into()))
} else { } else {
let id_str = id.to_string(); let id_str = id.to_string();

View File

@ -958,16 +958,27 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator>(
let resolver = ctx.resolver.clone(); let resolver = ctx.resolver.clone();
let val = resolver.get_symbol_value(*id, ctx).unwrap(); let val = resolver.get_symbol_value(*id, ctx).unwrap();
// if is tuple, need to deref it to handle tuple as value // if is tuple, need to deref it to handle tuple as value
if let (TypeEnum::TTuple { .. }, BasicValueEnum::PointerValue(ptr)) = ( // if is option, need to cast pointer to handle None
match (
&*ctx.unifier.get_ty(expr.custom.unwrap()), &*ctx.unifier.get_ty(expr.custom.unwrap()),
resolver resolver
.get_symbol_value(*id, ctx) .get_symbol_value(*id, ctx)
.unwrap() .unwrap()
.to_basic_value_enum(ctx, generator)?, .to_basic_value_enum(ctx, generator)?,
) { ) {
ctx.builder.build_load(ptr, "tup_val").into() (TypeEnum::TTuple { .. }, BasicValueEnum::PointerValue(ptr)) => {
} else { ctx.builder.build_load(ptr, "tup_val").into()
val }
(TypeEnum::TObj { obj_id, params, .. }, BasicValueEnum::PointerValue(ptr))
if *obj_id == ctx.primitives.option.get_obj_id(&ctx.unifier) => {
let actual_ptr_ty = ctx.get_llvm_type(
generator,
*params.iter().next().unwrap().1,
)
.ptr_type(AddressSpace::Generic);
ctx.builder.build_bitcast(ptr, actual_ptr_ty, "option_ptr_cast").into()
}
_ => val,
} }
} }
}, },

View File

@ -270,7 +270,7 @@ fn get_llvm_type<'ctx>(
let result = match &*ty_enum { let result = match &*ty_enum {
TObj { obj_id, fields, .. } => { TObj { obj_id, fields, .. } => {
// check to avoid treating primitives other than Option as classes // check to avoid treating primitives other than Option as classes
if obj_id.0 <= 14 { if obj_id.0 <= 10 {
match (unifier.get_ty(ty).as_ref(), unifier.get_ty(primitives.option).as_ref()) match (unifier.get_ty(ty).as_ref(), unifier.get_ty(primitives.option).as_ref())
{ {
( (

View File

@ -54,6 +54,18 @@ pub enum RecordKey {
Int(i32), Int(i32),
} }
impl Type {
// a wrapper function for cleaner code so that we don't need to
// write this long pattern matching just to get the field `obj_id`
pub fn get_obj_id(self, unifier: &Unifier) -> DefinitionId {
if let TypeEnum::TObj { obj_id, .. } = unifier.get_ty_immutable(self).as_ref() {
*obj_id
} else {
unreachable!("expect a object type")
}
}
}
impl From<&RecordKey> for StrRef { impl From<&RecordKey> for StrRef {
fn from(r: &RecordKey) -> Self { fn from(r: &RecordKey) -> Self {
match r { match r {