protocols/pipe_ipc: implement AsyncioChildComm for Unix (affected by python/asyncio#314)

This commit is contained in:
Sebastien Bourdeauducq 2016-01-26 00:48:12 +01:00
parent 8befc6a8fc
commit dcea6780c6
2 changed files with 74 additions and 19 deletions

View File

@ -1,5 +1,6 @@
import os
import asyncio
from asyncio.streams import FlowControlMixin
class _BaseIO:
@ -20,6 +21,22 @@ class _BaseIO:
if os.name != "nt":
async def _fds_to_asyncio(rfd, wfd, loop):
reader = asyncio.StreamReader(loop=loop)
reader_protocol = asyncio.StreamReaderProtocol(reader, loop=loop)
wf = open(wfd, "wb", 0)
transport, protocol = await loop.connect_write_pipe(
FlowControlMixin, wf)
writer = asyncio.StreamWriter(transport, protocol,
None, loop)
rf = open(rfd, "rb", 0)
await loop.connect_read_pipe(lambda: reader_protocol, rf)
return reader, writer
class AsyncioParentComm(_BaseIO):
def __init__(self):
self.c_rfd, self.p_wfd = os.pipe()
@ -35,20 +52,19 @@ if os.name != "nt":
os.close(self.c_rfd)
os.close(self.c_wfd)
pipe = open(self.p_rfd, "rb", 0)
self.reader = asyncio.StreamReader(loop=loop)
def factory():
return asyncio.StreamReaderProtocol(self.reader, loop=loop)
await loop.connect_read_pipe(factory, pipe)
self.reader, self.writer = await _fds_to_asyncio(
self.p_rfd, self.p_wfd, loop)
pipe = open(self.p_wfd, "wb", 0)
transport, protocol = await loop.connect_write_pipe(
asyncio.Protocol, pipe)
self.writer = asyncio.StreamWriter(transport, protocol,
None, loop)
class AsyncioChildComm(_BaseIO):
pass
def __init__(self, address):
self.address = address
async def connect(self):
rfd, wfd = self.address.split(",", maxsplit=1)
self.reader, self.writer = await _fds_to_asyncio(
int(rfd), int(wfd), asyncio.get_event_loop())
class ChildComm:
def __init__(self, address):
@ -82,10 +98,10 @@ else: # windows
async def connect(self):
loop = asyncio.get_event_loop()
self.reader = asyncio.StreamReader(loop=loop)
def factory():
return asyncio.StreamReaderProtocol(self.reader)
transport, protocol = await loop.create_pipe_connection(self.address,
factory)
reader_protocol = asyncio.StreamReaderProtocol(
self.reader, loop=loop)
transport, protocol = await loop.create_pipe_connection(
self.address, lambda: reader_protocol)
self.writer = asyncio.StreamWriter(transport, protocol,
self.reader, loop)

View File

@ -1,25 +1,31 @@
import unittest
import sys
import asyncio
import os
from artiq.protocols import pipe_ipc
class IPCCase(unittest.TestCase):
def setUp(self):
self.loop = asyncio.new_event_loop()
if os.name == "nt":
self.loop = asyncio.ProactorEventLoop()
else:
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
def tearDown(self):
self.loop.close()
async def _coro_test(self, child_blocking):
async def _coro_test(self, child_asyncio):
ipc = pipe_ipc.AsyncioParentComm()
await ipc.create_subprocess(sys.executable,
sys.modules[__name__].__file__,
str(child_asyncio),
ipc.get_address())
for i in range(10):
ipc.write("{}\n".format(i).encode())
await ipc.drain()
s = (await ipc.readline()).decode()
self.assertEqual(int(s), i+1)
ipc.write(b"-1\n")
@ -27,16 +33,49 @@ class IPCCase(unittest.TestCase):
ipc.close()
def test_blocking(self):
self.loop.run_until_complete(self._coro_test(False))
def test_asyncio(self):
self.loop.run_until_complete(self._coro_test(True))
def run_child():
child_comm = pipe_ipc.ChildComm(sys.argv[1])
def run_child_blocking():
child_comm = pipe_ipc.ChildComm(sys.argv[2])
while True:
x = int(child_comm.readline().decode())
if x < 0:
break
child_comm.write((str(x+1) + "\n").encode())
child_comm.close()
async def coro_child():
child_comm = pipe_ipc.AsyncioChildComm(sys.argv[2])
await child_comm.connect()
while True:
x = int((await child_comm.readline()).decode())
if x < 0:
break
child_comm.write((str(x+1) + "\n").encode())
await child_comm.drain()
child_comm.close()
def run_child_asyncio():
if os.name == "nt":
loop = asyncio.ProactorEventLoop()
asyncio.set_event_loop(loop)
else:
loop = asyncio.get_event_loop()
loop.run_until_complete(coro_child())
loop.close()
def run_child():
if sys.argv[1] == "True":
run_child_asyncio()
else:
run_child_blocking()
if __name__ == "__main__":
run_child()