protocols/pipe_ipc: implement AsyncioChildComm for Unix (affected by python/asyncio#314)

This commit is contained in:
Sebastien Bourdeauducq 2016-01-26 00:48:12 +01:00
parent 8befc6a8fc
commit dcea6780c6
2 changed files with 74 additions and 19 deletions

View File

@ -1,5 +1,6 @@
import os import os
import asyncio import asyncio
from asyncio.streams import FlowControlMixin
class _BaseIO: class _BaseIO:
@ -20,6 +21,22 @@ class _BaseIO:
if os.name != "nt": if os.name != "nt":
async def _fds_to_asyncio(rfd, wfd, loop):
reader = asyncio.StreamReader(loop=loop)
reader_protocol = asyncio.StreamReaderProtocol(reader, loop=loop)
wf = open(wfd, "wb", 0)
transport, protocol = await loop.connect_write_pipe(
FlowControlMixin, wf)
writer = asyncio.StreamWriter(transport, protocol,
None, loop)
rf = open(rfd, "rb", 0)
await loop.connect_read_pipe(lambda: reader_protocol, rf)
return reader, writer
class AsyncioParentComm(_BaseIO): class AsyncioParentComm(_BaseIO):
def __init__(self): def __init__(self):
self.c_rfd, self.p_wfd = os.pipe() self.c_rfd, self.p_wfd = os.pipe()
@ -35,20 +52,19 @@ if os.name != "nt":
os.close(self.c_rfd) os.close(self.c_rfd)
os.close(self.c_wfd) os.close(self.c_wfd)
pipe = open(self.p_rfd, "rb", 0) self.reader, self.writer = await _fds_to_asyncio(
self.reader = asyncio.StreamReader(loop=loop) self.p_rfd, self.p_wfd, loop)
def factory():
return asyncio.StreamReaderProtocol(self.reader, loop=loop)
await loop.connect_read_pipe(factory, pipe)
pipe = open(self.p_wfd, "wb", 0)
transport, protocol = await loop.connect_write_pipe(
asyncio.Protocol, pipe)
self.writer = asyncio.StreamWriter(transport, protocol,
None, loop)
class AsyncioChildComm(_BaseIO): class AsyncioChildComm(_BaseIO):
pass def __init__(self, address):
self.address = address
async def connect(self):
rfd, wfd = self.address.split(",", maxsplit=1)
self.reader, self.writer = await _fds_to_asyncio(
int(rfd), int(wfd), asyncio.get_event_loop())
class ChildComm: class ChildComm:
def __init__(self, address): def __init__(self, address):
@ -82,10 +98,10 @@ else: # windows
async def connect(self): async def connect(self):
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
self.reader = asyncio.StreamReader(loop=loop) self.reader = asyncio.StreamReader(loop=loop)
def factory(): reader_protocol = asyncio.StreamReaderProtocol(
return asyncio.StreamReaderProtocol(self.reader) self.reader, loop=loop)
transport, protocol = await loop.create_pipe_connection(self.address, transport, protocol = await loop.create_pipe_connection(
factory) self.address, lambda: reader_protocol)
self.writer = asyncio.StreamWriter(transport, protocol, self.writer = asyncio.StreamWriter(transport, protocol,
self.reader, loop) self.reader, loop)

View File

@ -1,25 +1,31 @@
import unittest import unittest
import sys import sys
import asyncio import asyncio
import os
from artiq.protocols import pipe_ipc from artiq.protocols import pipe_ipc
class IPCCase(unittest.TestCase): class IPCCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.loop = asyncio.new_event_loop() if os.name == "nt":
self.loop = asyncio.ProactorEventLoop()
else:
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop) asyncio.set_event_loop(self.loop)
def tearDown(self): def tearDown(self):
self.loop.close() self.loop.close()
async def _coro_test(self, child_blocking): async def _coro_test(self, child_asyncio):
ipc = pipe_ipc.AsyncioParentComm() ipc = pipe_ipc.AsyncioParentComm()
await ipc.create_subprocess(sys.executable, await ipc.create_subprocess(sys.executable,
sys.modules[__name__].__file__, sys.modules[__name__].__file__,
str(child_asyncio),
ipc.get_address()) ipc.get_address())
for i in range(10): for i in range(10):
ipc.write("{}\n".format(i).encode()) ipc.write("{}\n".format(i).encode())
await ipc.drain()
s = (await ipc.readline()).decode() s = (await ipc.readline()).decode()
self.assertEqual(int(s), i+1) self.assertEqual(int(s), i+1)
ipc.write(b"-1\n") ipc.write(b"-1\n")
@ -27,16 +33,49 @@ class IPCCase(unittest.TestCase):
ipc.close() ipc.close()
def test_blocking(self): def test_blocking(self):
self.loop.run_until_complete(self._coro_test(False))
def test_asyncio(self):
self.loop.run_until_complete(self._coro_test(True)) self.loop.run_until_complete(self._coro_test(True))
def run_child(): def run_child_blocking():
child_comm = pipe_ipc.ChildComm(sys.argv[1]) child_comm = pipe_ipc.ChildComm(sys.argv[2])
while True: while True:
x = int(child_comm.readline().decode()) x = int(child_comm.readline().decode())
if x < 0: if x < 0:
break break
child_comm.write((str(x+1) + "\n").encode()) child_comm.write((str(x+1) + "\n").encode())
child_comm.close()
async def coro_child():
child_comm = pipe_ipc.AsyncioChildComm(sys.argv[2])
await child_comm.connect()
while True:
x = int((await child_comm.readline()).decode())
if x < 0:
break
child_comm.write((str(x+1) + "\n").encode())
await child_comm.drain()
child_comm.close()
def run_child_asyncio():
if os.name == "nt":
loop = asyncio.ProactorEventLoop()
asyncio.set_event_loop(loop)
else:
loop = asyncio.get_event_loop()
loop.run_until_complete(coro_child())
loop.close()
def run_child():
if sys.argv[1] == "True":
run_child_asyncio()
else:
run_child_blocking()
if __name__ == "__main__": if __name__ == "__main__":
run_child() run_child()