Allow `access_log` format override by web plugins (#733)

* Return DEFAULT_404_RESPONSE by default from static server when file doesnt exist

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix web server with proxy test

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Abhinav Singh 2021-11-13 02:59:43 +05:30 committed by GitHub
parent 684c0d4fe7
commit 094e30d31f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 72 additions and 57 deletions

View File

@ -81,13 +81,14 @@ DEFAULT_KEY_FILE = None
DEFAULT_LOG_FILE = None
DEFAULT_LOG_FORMAT = '%(asctime)s - pid:%(process)d [%(levelname)-.1s] %(module)s.%(funcName)s:%(lineno)d - %(message)s'
DEFAULT_LOG_LEVEL = 'INFO'
DEFAULT_WEB_ACCESS_LOG_FORMAT = '{client_addr} - {request_method} {request_path} - {connection_time_ms}ms'
DEFAULT_HTTP_ACCESS_LOG_FORMAT = '{client_ip}:{client_port} - ' + \
'{request_method} {server_host}:{server_port}{request_path} - ' + \
'{response_code} {response_reason} - {response_bytes} bytes - ' + \
'{connection_time_ms} ms'
'{connection_time_ms}ms'
DEFAULT_HTTPS_ACCESS_LOG_FORMAT = '{client_ip}:{client_port} - ' + \
'{request_method} {server_host}:{server_port} - ' + \
'{response_bytes} bytes - {connection_time_ms} ms'
'{response_bytes} bytes - {connection_time_ms}ms'
DEFAULT_NUM_ACCEPTORS = 0
DEFAULT_NUM_WORKERS = 0
DEFAULT_OPEN_FILE_LIMIT = 1024

View File

@ -12,7 +12,7 @@ import socket
import argparse
from abc import ABC, abstractmethod
from typing import List, Tuple
from typing import Any, Dict, List, Optional, Tuple
from uuid import UUID
from ..websocket import WebsocketFrame
@ -111,3 +111,13 @@ class HttpWebServerBasePlugin(ABC):
# def on_websocket_close(self) -> None:
# """Called when websocket connection has been closed."""
# raise NotImplementedError() # pragma: no cover
def on_access_log(self, context: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""Use this method to override default access log format (see
DEFAULT_WEB_ACCESS_LOG_FORMAT) or to add/update/modify passed context
for usage by default access logger.
Return updated log context to use for default logging format, OR
Return None if plugin has logged the request.
"""
return context

View File

@ -19,7 +19,7 @@ from typing import List, Tuple, Optional, Dict, Union, Any, Pattern
from ...common.constants import DEFAULT_STATIC_SERVER_DIR, PROXY_AGENT_HEADER_VALUE
from ...common.constants import DEFAULT_ENABLE_STATIC_SERVER, DEFAULT_ENABLE_WEB_SERVER
from ...common.constants import DEFAULT_MIN_COMPRESSION_LIMIT
from ...common.constants import DEFAULT_MIN_COMPRESSION_LIMIT, DEFAULT_WEB_ACCESS_LOG_FORMAT
from ...common.utils import bytes_, text_, build_http_response, build_websocket_handshake_response
from ...common.types import Readables, Writables
from ...common.flag import flags
@ -79,6 +79,7 @@ class HttpWebServerPlugin(HttpProtocolHandlerPlugin):
reason=b'NOT FOUND',
headers={
b'Server': PROXY_AGENT_HEADER_VALUE,
b'Content-Length': b'0',
b'Connection': b'close',
},
),
@ -90,6 +91,7 @@ class HttpWebServerPlugin(HttpProtocolHandlerPlugin):
reason=b'NOT IMPLEMENTED',
headers={
b'Server': PROXY_AGENT_HEADER_VALUE,
b'Content-Length': b'0',
b'Connection': b'close',
},
),
@ -129,29 +131,32 @@ class HttpWebServerPlugin(HttpProtocolHandlerPlugin):
@staticmethod
def read_and_build_static_file_response(path: str, min_compression_limit: int) -> memoryview:
with open(path, 'rb') as f:
content = f.read()
content_type = mimetypes.guess_type(path)[0]
if content_type is None:
content_type = 'text/plain'
headers = {
b'Content-Type': bytes_(content_type),
b'Cache-Control': b'max-age=86400',
b'Connection': b'close',
}
do_compress = len(content) > min_compression_limit
if do_compress:
headers.update({
b'Content-Encoding': b'gzip',
})
return memoryview(
build_http_response(
httpStatusCodes.OK,
reason=b'OK',
headers=headers,
body=gzip.compress(content) if do_compress else content,
),
)
try:
with open(path, 'rb') as f:
content = f.read()
content_type = mimetypes.guess_type(path)[0]
if content_type is None:
content_type = 'text/plain'
headers = {
b'Content-Type': bytes_(content_type),
b'Cache-Control': b'max-age=86400',
b'Connection': b'close',
}
do_compress = len(content) > min_compression_limit
if do_compress:
headers.update({
b'Content-Encoding': b'gzip',
})
return memoryview(
build_http_response(
httpStatusCodes.OK,
reason=b'OK',
headers=headers,
body=gzip.compress(content) if do_compress else content,
),
)
except FileNotFoundError:
return HttpWebServerPlugin.DEFAULT_404_RESPONSE
def try_upgrade(self) -> bool:
if self.request.has_header(b'connection') and \
@ -215,16 +220,13 @@ class HttpWebServerPlugin(HttpProtocolHandlerPlugin):
# No-route found, try static serving if enabled
if self.flags.enable_static_server:
path = text_(path).split('?')[0]
try:
self.client.queue(
self.read_and_build_static_file_response(
self.flags.static_server_dir + path,
self.flags.min_compression_limit,
),
)
return True
except FileNotFoundError:
pass
self.client.queue(
self.read_and_build_static_file_response(
self.flags.static_server_dir + path,
self.flags.min_compression_limit,
),
)
return True
# Catch all unhandled web server requests, return 404
self.client.queue(self.DEFAULT_404_RESPONSE)
@ -301,19 +303,26 @@ class HttpWebServerPlugin(HttpProtocolHandlerPlugin):
def on_client_connection_close(self) -> None:
if self.request.has_host():
return
context = {
'client_addr': self.client.address,
'request_method': text_(self.request.method),
'request_path': text_(self.request.path),
'connection_time_ms': '%.2f' % ((time.time() - self.start_time) * 1000),
}
log_handled = False
if self.route:
# May be merge on_client_connection_close and on_access_log???
# probably by simply deprecating on_client_connection_close in future.
self.route.on_client_connection_close()
self.access_log()
ctx = self.route.on_access_log(context)
if ctx is None:
log_handled = True
else:
context = ctx
if not log_handled:
self.access_log(context)
# TODO: Allow plugins to customize access_log, similar
# to how proxy server plugins are able to do it.
def access_log(self) -> None:
logger.info(
'%s - %s %s - %.2f ms' %
(
self.client.address,
text_(self.request.method),
text_(self.request.path),
(time.time() - self.start_time) * 1000,
),
)
def access_log(self, context: Dict[str, Any]) -> None:
logger.info(DEFAULT_WEB_ACCESS_LOG_FORMAT.format_map(context))

View File

@ -16,8 +16,9 @@ import urllib.error
from proxy import TestCase
from proxy.common.constants import DEFAULT_CLIENT_RECVBUF_SIZE, PROXY_AGENT_HEADER_VALUE
from proxy.common.utils import socket_connection, build_http_request, build_http_response
from proxy.http.parser import httpStatusCodes, httpMethods
from proxy.common.utils import socket_connection, build_http_request
from proxy.http.server import HttpWebServerPlugin
from proxy.http.parser import httpMethods
@unittest.skipIf(os.name == 'nt', 'Disabled for Windows due to weird permission issues.')
@ -44,13 +45,7 @@ class TestProxyPyEmbedded(TestCase):
response = conn.recv(DEFAULT_CLIENT_RECVBUF_SIZE)
self.assertEqual(
response,
build_http_response(
httpStatusCodes.NOT_FOUND, reason=b'NOT FOUND',
headers={
b'Server': PROXY_AGENT_HEADER_VALUE,
b'Connection': b'close',
},
),
HttpWebServerPlugin.DEFAULT_404_RESPONSE.tobytes(),
)
def test_proxy_vcr(self) -> None: