From ca1d1e713963c9e8266d0ffa6986df1064b46503 Mon Sep 17 00:00:00 2001 From: Abhinav Singh Date: Wed, 16 Oct 2019 13:09:38 -0700 Subject: [PATCH] os.close only for threadless (#138) * os.close only for Threadless to avoid fd leaks * Remove os.close mock which is only called for threadless --- proxy.py | 3 +- tests.py | 84 ++++++++++++-------------------------------------------- 2 files changed, 19 insertions(+), 68 deletions(-) diff --git a/proxy.py b/proxy.py index 6631abcc..aaad662b 100755 --- a/proxy.py +++ b/proxy.py @@ -970,6 +970,7 @@ class Threadless(multiprocessing.Process): **self.kwargs) try: self.works[fileno].initialize() + os.close(fileno) except ssl.SSLError as e: logger.exception('ssl.SSLError', exc_info=e) self.cleanup(fileno) @@ -1119,7 +1120,6 @@ class Acceptor(multiprocessing.Process): family=self.family, type=socket.SOCK_STREAM ) - os.close(fileno) try: self.selector.register(self.sock, selectors.EVENT_READ) self.start_threadless_process() @@ -2524,7 +2524,6 @@ class ProtocolHandler(threading.Thread, ThreadlessWork): conn = socket.fromfd( fileno, family=socket.AF_INET if self.config.hostname.version == 4 else socket.AF_INET6, type=socket.SOCK_STREAM) - os.close(fileno) return conn def optionally_wrap_socket( diff --git a/tests.py b/tests.py index b4d0e864..1445cb90 100644 --- a/tests.py +++ b/tests.py @@ -294,7 +294,6 @@ class TestWorker(unittest.TestCase): mock_protocol_handler, config=self.protocol_config) - @mock.patch('os.close') @mock.patch('selectors.DefaultSelector') @mock.patch('socket.fromfd') @mock.patch('proxy.recv_handle') @@ -302,8 +301,7 @@ class TestWorker(unittest.TestCase): self, mock_recv_handle: mock.Mock, mock_fromfd: mock.Mock, - mock_selector: mock.Mock, - mock_os_close: mock.Mock) -> None: + mock_selector: mock.Mock) -> None: fileno = 10 conn = mock.MagicMock() addr = mock.MagicMock() @@ -319,7 +317,6 @@ class TestWorker(unittest.TestCase): sock.accept.assert_not_called() self.mock_protocol_handler.assert_not_called() - @mock.patch('os.close') @mock.patch('selectors.DefaultSelector') @mock.patch('socket.fromfd') @mock.patch('proxy.recv_handle') @@ -327,8 +324,7 @@ class TestWorker(unittest.TestCase): self, mock_recv_handle: mock.Mock, mock_fromfd: mock.Mock, - mock_selector: mock.Mock, - mock_os_close: mock.Mock) -> None: + mock_selector: mock.Mock) -> None: fileno = 10 conn = mock.MagicMock() addr = mock.MagicMock() @@ -344,7 +340,6 @@ class TestWorker(unittest.TestCase): self.mock_protocol_handler.assert_not_called() - @mock.patch('os.close') @mock.patch('selectors.DefaultSelector') @mock.patch('socket.fromfd') @mock.patch('proxy.recv_handle') @@ -352,8 +347,7 @@ class TestWorker(unittest.TestCase): self, mock_recv_handle: mock.Mock, mock_fromfd: mock.Mock, - mock_selector: mock.Mock, - mock_os_close: mock.Mock) -> None: + mock_selector: mock.Mock) -> None: fileno = 10 conn = mock.MagicMock() addr = mock.MagicMock() @@ -1008,13 +1002,11 @@ class TestWebsocketClient(unittest.TestCase): class TestHttpProtocolHandler(unittest.TestCase): - @mock.patch('os.close') @mock.patch('selectors.DefaultSelector') @mock.patch('socket.fromfd') def setUp(self, mock_fromfd: mock.Mock, - mock_selector: mock.Mock, - mock_os_close: mock.Mock) -> None: + mock_selector: mock.Mock) -> None: self.fileno = 10 self._addr = ('127.0.0.1', 54382) self._conn = mock_fromfd.return_value @@ -1027,7 +1019,6 @@ class TestHttpProtocolHandler(unittest.TestCase): self.mock_selector = mock_selector self.proxy = proxy.ProtocolHandler( self.fileno, self._addr, config=self.config) - mock_os_close.assert_called_with(self.fileno) self.proxy.initialize() @mock.patch('proxy.TcpServerConnection') @@ -1141,14 +1132,12 @@ class TestHttpProtocolHandler(unittest.TestCase): self.proxy.run_once() self.assertEqual(self.proxy.client.buffer, proxy.ProxyConnectionFailed.RESPONSE_PKT) - @mock.patch('os.close') @mock.patch('selectors.DefaultSelector') @mock.patch('socket.fromfd') def test_proxy_authentication_failed( self, mock_fromfd: mock.Mock, - mock_selector: mock.Mock, - mock_os_close: mock.Mock) -> None: + mock_selector: mock.Mock) -> None: self._conn = mock_fromfd.return_value self.mock_selector_for_client_read(mock_selector) config = proxy.ProtocolConfig( @@ -1158,7 +1147,6 @@ class TestHttpProtocolHandler(unittest.TestCase): b'proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin') self.proxy = proxy.ProtocolHandler( self.fileno, self._addr, config=config) - mock_os_close.assert_called_with(self.fileno) self.proxy.initialize() self._conn.recv.return_value = proxy.CRLF.join([ b'GET http://abhinavsingh.com HTTP/1.1', @@ -1170,15 +1158,13 @@ class TestHttpProtocolHandler(unittest.TestCase): self.proxy.client.buffer, proxy.ProxyAuthenticationFailed.RESPONSE_PKT) - @mock.patch('os.close') @mock.patch('selectors.DefaultSelector') @mock.patch('socket.fromfd') @mock.patch('proxy.TcpServerConnection') def test_authenticated_proxy_http_get( self, mock_server_connection: mock.Mock, mock_fromfd: mock.Mock, - mock_selector: mock.Mock, - mock_os_close: mock.Mock) -> None: + mock_selector: mock.Mock) -> None: self._conn = mock_fromfd.return_value self.mock_selector_for_client_read(mock_selector) @@ -1194,7 +1180,6 @@ class TestHttpProtocolHandler(unittest.TestCase): self.proxy = proxy.ProtocolHandler( self.fileno, addr=self._addr, config=config) - mock_os_close.assert_called_with(self.fileno) self.proxy.initialize() assert self.http_server_port is not None @@ -1221,15 +1206,13 @@ class TestHttpProtocolHandler(unittest.TestCase): ]) self.assert_data_queued(mock_server_connection, server) - @mock.patch('os.close') @mock.patch('selectors.DefaultSelector') @mock.patch('socket.fromfd') @mock.patch('proxy.TcpServerConnection') def test_authenticated_proxy_http_tunnel( self, mock_server_connection: mock.Mock, mock_fromfd: mock.Mock, - mock_selector: mock.Mock, - mock_os_close: mock.Mock) -> None: + mock_selector: mock.Mock) -> None: server = mock_server_connection.return_value server.connect.return_value = True server.buffer_size.return_value = 0 @@ -1244,7 +1227,6 @@ class TestHttpProtocolHandler(unittest.TestCase): self.proxy = proxy.ProtocolHandler( self.fileno, self._addr, config=config) - mock_os_close.assert_called_with(self.fileno) self.proxy.initialize() assert self.http_server_port is not None @@ -1338,10 +1320,9 @@ class TestHttpProtocolHandler(unittest.TestCase): class TestWebServerPlugin(unittest.TestCase): - @mock.patch('os.close') @mock.patch('selectors.DefaultSelector') @mock.patch('socket.fromfd') - def setUp(self, mock_fromfd: mock.Mock, mock_selector: mock.Mock, mock_os_close: mock.Mock) -> None: + def setUp(self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None: self.fileno = 10 self._addr = ('127.0.0.1', 54382) self._conn = mock_fromfd.return_value @@ -1351,20 +1332,16 @@ class TestWebServerPlugin(unittest.TestCase): b'proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin') self.proxy = proxy.ProtocolHandler( self.fileno, self._addr, config=self.config) - mock_os_close.assert_called_with(self.fileno) self.proxy.initialize() - @mock.patch('os.close') @mock.patch('selectors.DefaultSelector') @mock.patch('socket.fromfd') def test_pac_file_served_from_disk( - self, mock_fromfd: mock.Mock, mock_selector: mock.Mock, - mock_os_close: mock.Mock) -> None: + self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None: pac_file = 'proxy.pac' self._conn = mock_fromfd.return_value self.mock_selector_for_client_read(mock_selector) self.init_and_make_pac_file_request(pac_file) - mock_os_close.assert_called_with(self.fileno) self.proxy.run_once() self.assertEqual( self.proxy.request.state, @@ -1377,17 +1354,14 @@ class TestWebServerPlugin(unittest.TestCase): }, body=f.read() )) - @mock.patch('os.close') @mock.patch('selectors.DefaultSelector') @mock.patch('socket.fromfd') def test_pac_file_served_from_buffer( - self, mock_fromfd: mock.Mock, mock_selector: mock.Mock, - mock_os_close: mock.Mock) -> None: + self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None: self._conn = mock_fromfd.return_value self.mock_selector_for_client_read(mock_selector) pac_file_content = b'function FindProxyForURL(url, host) { return "PROXY localhost:8899; DIRECT"; }' self.init_and_make_pac_file_request(proxy.text_(pac_file_content)) - mock_os_close.assert_called_with(self.fileno) self.proxy.run_once() self.assertEqual( self.proxy.request.state, @@ -1399,12 +1373,10 @@ class TestWebServerPlugin(unittest.TestCase): }, body=pac_file_content )) - @mock.patch('os.close') @mock.patch('selectors.DefaultSelector') @mock.patch('socket.fromfd') def test_default_web_server_returns_404( - self, mock_fromfd: mock.Mock, mock_selector: mock.Mock, - mock_os_close: mock.Mock) -> None: + self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None: self._conn = mock_fromfd.return_value mock_selector.return_value.select.return_value = [( selectors.SelectorKey( @@ -1417,7 +1389,6 @@ class TestWebServerPlugin(unittest.TestCase): b'proxy.HttpProxyPlugin,proxy.HttpWebServerPlugin') self.proxy = proxy.ProtocolHandler( self.fileno, self._addr, config=config) - mock_os_close.assert_called_with(self.fileno) self.proxy.initialize() self._conn.recv.return_value = proxy.CRLF.join([ b'GET /hello HTTP/1.1', @@ -1431,12 +1402,10 @@ class TestWebServerPlugin(unittest.TestCase): self.proxy.client.buffer, proxy.HttpWebServerPlugin.DEFAULT_404_RESPONSE) - @mock.patch('os.close') @mock.patch('selectors.DefaultSelector') @mock.patch('socket.fromfd') def test_static_web_server_serves( - self, mock_fromfd: mock.Mock, mock_selector: mock.Mock, - mock_os_close: mock.Mock) -> None: + self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None: # Setup a static directory static_server_dir = os.path.join(tempfile.gettempdir(), 'static') index_file_path = os.path.join(static_server_dir, 'index.html') @@ -1488,14 +1457,12 @@ class TestWebServerPlugin(unittest.TestCase): body=html_file_content )) - @mock.patch('os.close') @mock.patch('selectors.DefaultSelector') @mock.patch('socket.fromfd') def test_static_web_server_serves_404( self, mock_fromfd: mock.Mock, - mock_selector: mock.Mock, - mock_os_close: mock.Mock) -> None: + mock_selector: mock.Mock) -> None: self._conn = mock_fromfd.return_value self._conn.recv.return_value = proxy.build_http_request(b'GET', b'/not-found.html') @@ -1517,7 +1484,6 @@ class TestWebServerPlugin(unittest.TestCase): self.proxy = proxy.ProtocolHandler( self.fileno, self._addr, config=config) - mock_os_close.assert_called_with(self.fileno) self.proxy.initialize() self.proxy.run_once() @@ -1528,17 +1494,15 @@ class TestWebServerPlugin(unittest.TestCase): self.assertEqual(self._conn.send.call_args[0][0], proxy.HttpWebServerPlugin.DEFAULT_404_RESPONSE) - @mock.patch('os.close') @mock.patch('socket.fromfd') def test_on_client_connection_called_on_teardown( - self, mock_fromfd: mock.Mock, mock_os_close: mock.Mock) -> None: + self, mock_fromfd: mock.Mock) -> None: config = proxy.ProtocolConfig() plugin = mock.MagicMock() config.plugins = {b'ProtocolHandlerPlugin': [plugin]} self._conn = mock_fromfd.return_value self.proxy = proxy.ProtocolHandler( self.fileno, self._addr, config=config) - mock_os_close.assert_called_with(self.fileno) self.proxy.initialize() plugin.assert_called() with mock.patch.object(self.proxy, 'run_once') as mock_run_once: @@ -1570,13 +1534,11 @@ class TestWebServerPlugin(unittest.TestCase): class TestHttpProxyPlugin(unittest.TestCase): - @mock.patch('os.close') @mock.patch('selectors.DefaultSelector') @mock.patch('socket.fromfd') def setUp(self, mock_fromfd: mock.Mock, - mock_selector: mock.Mock, - mock_os_close: mock.Mock) -> None: + mock_selector: mock.Mock) -> None: self.mock_fromfd = mock_fromfd self.mock_selector = mock_selector @@ -1591,7 +1553,6 @@ class TestHttpProxyPlugin(unittest.TestCase): self._conn = mock_fromfd.return_value self.proxy = proxy.ProtocolHandler( self.fileno, self._addr, config=self.config) - mock_os_close.assert_called_with(self.fileno) self.proxy.initialize() def test_proxy_plugin_initialized(self) -> None: @@ -1646,13 +1607,11 @@ class TestHttpProxyPlugin(unittest.TestCase): class TestHttpProxyPluginExamples(unittest.TestCase): - @mock.patch('os.close') @mock.patch('selectors.DefaultSelector') @mock.patch('socket.fromfd') def setUp(self, mock_fromfd: mock.Mock, - mock_selector: mock.Mock, - mock_os_close: mock.Mock) -> None: + mock_selector: mock.Mock) -> None: self.fileno = 10 self._addr = ('127.0.0.1', 54382) self.config = proxy.ProtocolConfig() @@ -1670,7 +1629,6 @@ class TestHttpProxyPluginExamples(unittest.TestCase): self._conn = mock_fromfd.return_value self.proxy = proxy.ProtocolHandler( self.fileno, self._addr, config=self.config) - mock_os_close.assert_called_with(self.fileno) self.proxy.initialize() @mock.patch('proxy.TcpServerConnection') @@ -1871,7 +1829,6 @@ class TestHttpProxyPluginExamples(unittest.TestCase): class TestHttpProxyTlsInterception(unittest.TestCase): - @mock.patch('os.close') @mock.patch('ssl.wrap_socket') @mock.patch('ssl.create_default_context') @mock.patch('proxy.TcpServerConnection') @@ -1885,8 +1842,7 @@ class TestHttpProxyTlsInterception(unittest.TestCase): mock_popen: mock.Mock, mock_server_conn: mock.Mock, mock_ssl_context: mock.Mock, - mock_ssl_wrap: mock.Mock, - mock_os_close: mock.Mock) -> None: + mock_ssl_wrap: mock.Mock) -> None: host, port = uuid.uuid4().hex, 443 netloc = '{0}:{1}'.format(host, port) @@ -1926,7 +1882,6 @@ class TestHttpProxyTlsInterception(unittest.TestCase): self._conn = mock_fromfd.return_value self.proxy = proxy.ProtocolHandler( self.fileno, self._addr, config=self.config) - mock_os_close.assert_called_with(self.fileno) self.proxy.initialize() self.plugin.assert_called() @@ -2005,7 +1960,6 @@ class TestHttpProxyTlsInterception(unittest.TestCase): class TestHttpProxyPluginExamplesWithTlsInterception(unittest.TestCase): - @mock.patch('os.close') @mock.patch('ssl.wrap_socket') @mock.patch('ssl.create_default_context') @mock.patch('proxy.TcpServerConnection') @@ -2018,8 +1972,7 @@ class TestHttpProxyPluginExamplesWithTlsInterception(unittest.TestCase): mock_popen: mock.Mock, mock_server_conn: mock.Mock, mock_ssl_context: mock.Mock, - mock_ssl_wrap: mock.Mock, - mock_os_close: mock.Mock) -> None: + mock_ssl_wrap: mock.Mock) -> None: self.mock_fromfd = mock_fromfd self.mock_selector = mock_selector self.mock_popen = mock_popen @@ -2045,7 +1998,6 @@ class TestHttpProxyPluginExamplesWithTlsInterception(unittest.TestCase): mock_fromfd.return_value = self._conn self.proxy = proxy.ProtocolHandler( self.fileno, self._addr, config=self.config) - mock_os_close.assert_called_with(self.fileno) self.proxy.initialize() self.server = self.mock_server_conn.return_value