mirror of https://github.com/m-labs/artiq.git
protocols/pipe_ipc: implement AsyncioChildComm for Unix (affected by python/asyncio#314)
This commit is contained in:
parent
8befc6a8fc
commit
dcea6780c6
|
@ -1,5 +1,6 @@
|
|||
import os
|
||||
import asyncio
|
||||
from asyncio.streams import FlowControlMixin
|
||||
|
||||
|
||||
class _BaseIO:
|
||||
|
@ -20,6 +21,22 @@ class _BaseIO:
|
|||
|
||||
|
||||
if os.name != "nt":
|
||||
async def _fds_to_asyncio(rfd, wfd, loop):
|
||||
reader = asyncio.StreamReader(loop=loop)
|
||||
reader_protocol = asyncio.StreamReaderProtocol(reader, loop=loop)
|
||||
|
||||
wf = open(wfd, "wb", 0)
|
||||
transport, protocol = await loop.connect_write_pipe(
|
||||
FlowControlMixin, wf)
|
||||
writer = asyncio.StreamWriter(transport, protocol,
|
||||
None, loop)
|
||||
|
||||
rf = open(rfd, "rb", 0)
|
||||
await loop.connect_read_pipe(lambda: reader_protocol, rf)
|
||||
|
||||
return reader, writer
|
||||
|
||||
|
||||
class AsyncioParentComm(_BaseIO):
|
||||
def __init__(self):
|
||||
self.c_rfd, self.p_wfd = os.pipe()
|
||||
|
@ -35,20 +52,19 @@ if os.name != "nt":
|
|||
os.close(self.c_rfd)
|
||||
os.close(self.c_wfd)
|
||||
|
||||
pipe = open(self.p_rfd, "rb", 0)
|
||||
self.reader = asyncio.StreamReader(loop=loop)
|
||||
def factory():
|
||||
return asyncio.StreamReaderProtocol(self.reader, loop=loop)
|
||||
await loop.connect_read_pipe(factory, pipe)
|
||||
self.reader, self.writer = await _fds_to_asyncio(
|
||||
self.p_rfd, self.p_wfd, loop)
|
||||
|
||||
pipe = open(self.p_wfd, "wb", 0)
|
||||
transport, protocol = await loop.connect_write_pipe(
|
||||
asyncio.Protocol, pipe)
|
||||
self.writer = asyncio.StreamWriter(transport, protocol,
|
||||
None, loop)
|
||||
|
||||
class AsyncioChildComm(_BaseIO):
|
||||
pass
|
||||
def __init__(self, address):
|
||||
self.address = address
|
||||
|
||||
async def connect(self):
|
||||
rfd, wfd = self.address.split(",", maxsplit=1)
|
||||
self.reader, self.writer = await _fds_to_asyncio(
|
||||
int(rfd), int(wfd), asyncio.get_event_loop())
|
||||
|
||||
|
||||
class ChildComm:
|
||||
def __init__(self, address):
|
||||
|
@ -82,10 +98,10 @@ else: # windows
|
|||
async def connect(self):
|
||||
loop = asyncio.get_event_loop()
|
||||
self.reader = asyncio.StreamReader(loop=loop)
|
||||
def factory():
|
||||
return asyncio.StreamReaderProtocol(self.reader)
|
||||
transport, protocol = await loop.create_pipe_connection(self.address,
|
||||
factory)
|
||||
reader_protocol = asyncio.StreamReaderProtocol(
|
||||
self.reader, loop=loop)
|
||||
transport, protocol = await loop.create_pipe_connection(
|
||||
self.address, lambda: reader_protocol)
|
||||
self.writer = asyncio.StreamWriter(transport, protocol,
|
||||
self.reader, loop)
|
||||
|
||||
|
|
|
@ -1,25 +1,31 @@
|
|||
import unittest
|
||||
import sys
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from artiq.protocols import pipe_ipc
|
||||
|
||||
|
||||
class IPCCase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.loop = asyncio.new_event_loop()
|
||||
if os.name == "nt":
|
||||
self.loop = asyncio.ProactorEventLoop()
|
||||
else:
|
||||
self.loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self.loop)
|
||||
|
||||
def tearDown(self):
|
||||
self.loop.close()
|
||||
|
||||
async def _coro_test(self, child_blocking):
|
||||
async def _coro_test(self, child_asyncio):
|
||||
ipc = pipe_ipc.AsyncioParentComm()
|
||||
await ipc.create_subprocess(sys.executable,
|
||||
sys.modules[__name__].__file__,
|
||||
str(child_asyncio),
|
||||
ipc.get_address())
|
||||
for i in range(10):
|
||||
ipc.write("{}\n".format(i).encode())
|
||||
await ipc.drain()
|
||||
s = (await ipc.readline()).decode()
|
||||
self.assertEqual(int(s), i+1)
|
||||
ipc.write(b"-1\n")
|
||||
|
@ -27,16 +33,49 @@ class IPCCase(unittest.TestCase):
|
|||
ipc.close()
|
||||
|
||||
def test_blocking(self):
|
||||
self.loop.run_until_complete(self._coro_test(False))
|
||||
|
||||
def test_asyncio(self):
|
||||
self.loop.run_until_complete(self._coro_test(True))
|
||||
|
||||
|
||||
def run_child():
|
||||
child_comm = pipe_ipc.ChildComm(sys.argv[1])
|
||||
def run_child_blocking():
|
||||
child_comm = pipe_ipc.ChildComm(sys.argv[2])
|
||||
while True:
|
||||
x = int(child_comm.readline().decode())
|
||||
if x < 0:
|
||||
break
|
||||
child_comm.write((str(x+1) + "\n").encode())
|
||||
child_comm.close()
|
||||
|
||||
|
||||
async def coro_child():
|
||||
child_comm = pipe_ipc.AsyncioChildComm(sys.argv[2])
|
||||
await child_comm.connect()
|
||||
while True:
|
||||
x = int((await child_comm.readline()).decode())
|
||||
if x < 0:
|
||||
break
|
||||
child_comm.write((str(x+1) + "\n").encode())
|
||||
await child_comm.drain()
|
||||
child_comm.close()
|
||||
|
||||
|
||||
def run_child_asyncio():
|
||||
if os.name == "nt":
|
||||
loop = asyncio.ProactorEventLoop()
|
||||
asyncio.set_event_loop(loop)
|
||||
else:
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.run_until_complete(coro_child())
|
||||
loop.close()
|
||||
|
||||
|
||||
def run_child():
|
||||
if sys.argv[1] == "True":
|
||||
run_child_asyncio()
|
||||
else:
|
||||
run_child_blocking()
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_child()
|
||||
|
|
Loading…
Reference in New Issue