From dcea6780c6b520cdb489b88528d7d5eb570209ae Mon Sep 17 00:00:00 2001 From: Sebastien Bourdeauducq Date: Tue, 26 Jan 2016 00:48:12 +0100 Subject: [PATCH] protocols/pipe_ipc: implement AsyncioChildComm for Unix (affected by python/asyncio#314) --- artiq/protocols/pipe_ipc.py | 46 ++++++++++++++++++++++++------------ artiq/test/pipe_ipc.py | 47 +++++++++++++++++++++++++++++++++---- 2 files changed, 74 insertions(+), 19 deletions(-) diff --git a/artiq/protocols/pipe_ipc.py b/artiq/protocols/pipe_ipc.py index 085e0eea8..a982bdf9b 100644 --- a/artiq/protocols/pipe_ipc.py +++ b/artiq/protocols/pipe_ipc.py @@ -1,5 +1,6 @@ import os import asyncio +from asyncio.streams import FlowControlMixin class _BaseIO: @@ -20,6 +21,22 @@ class _BaseIO: if os.name != "nt": + async def _fds_to_asyncio(rfd, wfd, loop): + reader = asyncio.StreamReader(loop=loop) + reader_protocol = asyncio.StreamReaderProtocol(reader, loop=loop) + + wf = open(wfd, "wb", 0) + transport, protocol = await loop.connect_write_pipe( + FlowControlMixin, wf) + writer = asyncio.StreamWriter(transport, protocol, + None, loop) + + rf = open(rfd, "rb", 0) + await loop.connect_read_pipe(lambda: reader_protocol, rf) + + return reader, writer + + class AsyncioParentComm(_BaseIO): def __init__(self): self.c_rfd, self.p_wfd = os.pipe() @@ -35,20 +52,19 @@ if os.name != "nt": os.close(self.c_rfd) os.close(self.c_wfd) - pipe = open(self.p_rfd, "rb", 0) - self.reader = asyncio.StreamReader(loop=loop) - def factory(): - return asyncio.StreamReaderProtocol(self.reader, loop=loop) - await loop.connect_read_pipe(factory, pipe) + self.reader, self.writer = await _fds_to_asyncio( + self.p_rfd, self.p_wfd, loop) - pipe = open(self.p_wfd, "wb", 0) - transport, protocol = await loop.connect_write_pipe( - asyncio.Protocol, pipe) - self.writer = asyncio.StreamWriter(transport, protocol, - None, loop) class AsyncioChildComm(_BaseIO): - pass + def __init__(self, address): + self.address = address + + async def connect(self): + rfd, wfd = self.address.split(",", maxsplit=1) + self.reader, self.writer = await _fds_to_asyncio( + int(rfd), int(wfd), asyncio.get_event_loop()) + class ChildComm: def __init__(self, address): @@ -82,10 +98,10 @@ else: # windows async def connect(self): loop = asyncio.get_event_loop() self.reader = asyncio.StreamReader(loop=loop) - def factory(): - return asyncio.StreamReaderProtocol(self.reader) - transport, protocol = await loop.create_pipe_connection(self.address, - factory) + reader_protocol = asyncio.StreamReaderProtocol( + self.reader, loop=loop) + transport, protocol = await loop.create_pipe_connection( + self.address, lambda: reader_protocol) self.writer = asyncio.StreamWriter(transport, protocol, self.reader, loop) diff --git a/artiq/test/pipe_ipc.py b/artiq/test/pipe_ipc.py index 901c49fa6..9614a7431 100644 --- a/artiq/test/pipe_ipc.py +++ b/artiq/test/pipe_ipc.py @@ -1,25 +1,31 @@ import unittest import sys import asyncio +import os from artiq.protocols import pipe_ipc class IPCCase(unittest.TestCase): def setUp(self): - self.loop = asyncio.new_event_loop() + if os.name == "nt": + self.loop = asyncio.ProactorEventLoop() + else: + self.loop = asyncio.new_event_loop() asyncio.set_event_loop(self.loop) def tearDown(self): self.loop.close() - async def _coro_test(self, child_blocking): + async def _coro_test(self, child_asyncio): ipc = pipe_ipc.AsyncioParentComm() await ipc.create_subprocess(sys.executable, sys.modules[__name__].__file__, + str(child_asyncio), ipc.get_address()) for i in range(10): ipc.write("{}\n".format(i).encode()) + await ipc.drain() s = (await ipc.readline()).decode() self.assertEqual(int(s), i+1) ipc.write(b"-1\n") @@ -27,16 +33,49 @@ class IPCCase(unittest.TestCase): ipc.close() def test_blocking(self): + self.loop.run_until_complete(self._coro_test(False)) + + def test_asyncio(self): self.loop.run_until_complete(self._coro_test(True)) -def run_child(): - child_comm = pipe_ipc.ChildComm(sys.argv[1]) +def run_child_blocking(): + child_comm = pipe_ipc.ChildComm(sys.argv[2]) while True: x = int(child_comm.readline().decode()) if x < 0: break child_comm.write((str(x+1) + "\n").encode()) + child_comm.close() + + +async def coro_child(): + child_comm = pipe_ipc.AsyncioChildComm(sys.argv[2]) + await child_comm.connect() + while True: + x = int((await child_comm.readline()).decode()) + if x < 0: + break + child_comm.write((str(x+1) + "\n").encode()) + await child_comm.drain() + child_comm.close() + + +def run_child_asyncio(): + if os.name == "nt": + loop = asyncio.ProactorEventLoop() + asyncio.set_event_loop(loop) + else: + loop = asyncio.get_event_loop() + loop.run_until_complete(coro_child()) + loop.close() + + +def run_child(): + if sys.argv[1] == "True": + run_child_asyncio() + else: + run_child_blocking() if __name__ == "__main__": run_child()