compiler: Factor rpc_tag() out of llvm_ir_generator

This commit is contained in:
David Nadlinger 2020-07-28 23:45:37 +01:00
parent e77c7d1c39
commit 9af6e5747d
3 changed files with 45 additions and 47 deletions

View File

@ -36,6 +36,47 @@ class TKeyword(types.TMono):
def is_keyword(typ):
return isinstance(typ, TKeyword)
# See rpc_proto.rs and comm_kernel.py:_{send,receive}_rpc_value.
def rpc_tag(typ, error_handler):
typ = typ.find()
if types.is_tuple(typ):
assert len(typ.elts) < 256
return b"t" + bytes([len(typ.elts)]) + \
b"".join([rpc_tag(elt_type, error_handler)
for elt_type in typ.elts])
elif builtins.is_none(typ):
return b"n"
elif builtins.is_bool(typ):
return b"b"
elif builtins.is_int(typ, types.TValue(32)):
return b"i"
elif builtins.is_int(typ, types.TValue(64)):
return b"I"
elif builtins.is_float(typ):
return b"f"
elif builtins.is_str(typ):
return b"s"
elif builtins.is_bytes(typ):
return b"B"
elif builtins.is_bytearray(typ):
return b"A"
elif builtins.is_list(typ):
return b"l" + rpc_tag(builtins.get_iterable_elt(typ), error_handler)
elif builtins.is_array(typ):
return b"a" + rpc_tag(builtins.get_iterable_elt(typ), error_handler)
elif builtins.is_range(typ):
return b"r" + rpc_tag(builtins.get_iterable_elt(typ), error_handler)
elif is_keyword(typ):
return b"k" + rpc_tag(typ.params["value"], error_handler)
elif types.is_function(typ) or types.is_method(typ) or types.is_rpc(typ):
raise ValueError("RPC tag for functional value")
elif '__objectid__' in typ.attributes:
return b"O"
else:
error_handler(typ)
class Value:
"""
An SSA value that keeps track of its uses.

View File

@ -570,7 +570,7 @@ class LLVMIRGenerator:
if name == "__objectid__":
rpctag = b""
else:
rpctag = b"Os" + self._rpc_tag(typ, error_handler=rpc_tag_error) + b":n"
rpctag = b"Os" + ir.rpc_tag(typ, error_handler=rpc_tag_error) + b":n"
llrpcattrinit = ll.Constant(llrpcattrty, [
ll.Constant(lli32, offset),
@ -1310,49 +1310,6 @@ class LLVMIRGenerator:
return llfun, list(llargs)
# See session.c:{send,receive}_rpc_value and comm_generic.py:_{send,receive}_rpc_value.
def _rpc_tag(self, typ, error_handler):
typ = typ.find()
if types.is_tuple(typ):
assert len(typ.elts) < 256
return b"t" + bytes([len(typ.elts)]) + \
b"".join([self._rpc_tag(elt_type, error_handler)
for elt_type in typ.elts])
elif builtins.is_none(typ):
return b"n"
elif builtins.is_bool(typ):
return b"b"
elif builtins.is_int(typ, types.TValue(32)):
return b"i"
elif builtins.is_int(typ, types.TValue(64)):
return b"I"
elif builtins.is_float(typ):
return b"f"
elif builtins.is_str(typ):
return b"s"
elif builtins.is_bytes(typ):
return b"B"
elif builtins.is_bytearray(typ):
return b"A"
elif builtins.is_list(typ):
return b"l" + self._rpc_tag(builtins.get_iterable_elt(typ),
error_handler)
elif builtins.is_array(typ):
return b"a" + self._rpc_tag(builtins.get_iterable_elt(typ),
error_handler)
elif builtins.is_range(typ):
return b"r" + self._rpc_tag(builtins.get_iterable_elt(typ),
error_handler)
elif ir.is_keyword(typ):
return b"k" + self._rpc_tag(typ.params["value"],
error_handler)
elif types.is_function(typ) or types.is_method(typ) or types.is_rpc(typ):
raise ValueError("RPC tag for functional value")
elif '__objectid__' in typ.attributes:
return b"O"
else:
error_handler(typ)
def _build_rpc(self, fun_loc, fun_type, args, llnormalblock, llunwindblock):
llservice = ll.Constant(lli32, fun_type.service)
@ -1370,7 +1327,7 @@ class LLVMIRGenerator:
{"type": printer.name(arg.type)},
arg.loc)
self.engine.process(diag)
tag += self._rpc_tag(arg.type, arg_error_handler)
tag += ir.rpc_tag(arg.type, arg_error_handler)
tag += b":"
def ret_error_handler(typ):
@ -1384,7 +1341,7 @@ class LLVMIRGenerator:
{"type": printer.name(fun_type.ret)},
fun_loc)
self.engine.process(diag)
tag += self._rpc_tag(fun_type.ret, ret_error_handler)
tag += ir.rpc_tag(fun_type.ret, ret_error_handler)
lltag = self.llconst_of_const(ir.Constant(tag, builtins.TStr()))
lltagptr = self.llbuilder.alloca(lltag.type)

View File

@ -244,7 +244,7 @@ class CommKernel:
_rpc_sentinel = object()
# See session.c:{send,receive}_rpc_value and llvm_ir_generator.py:_rpc_tag.
# See rpc_proto.rs and compiler/ir.py:rpc_tag.
def _receive_rpc_value(self, embedding_map):
tag = chr(self._read_int8())
if tag == "\x00":