From b91f25a193a7a5bbf5ec6dd7d418a9a0831d044f Mon Sep 17 00:00:00 2001 From: Prodesire Date: Tue, 27 Mar 2018 20:41:01 +0800 Subject: [PATCH] add request.update_query_params which update query params of given url and return new url --- pydu/compat.py | 2 ++ pydu/request.py | 14 +++++++++++++- tests/test_request.py | 8 +++++++- 3 files changed, 22 insertions(+), 2 deletions(-) diff --git a/pydu/compat.py b/pydu/compat.py index 3cf3a09..2986958 100644 --- a/pydu/compat.py +++ b/pydu/compat.py @@ -17,10 +17,12 @@ if PY2: import urllib as urlib import urlparse from urlparse import urljoin + from urllib import urlencode else: import urllib.request as urlib import urllib.parse as urlparse from urllib.parse import urljoin + from urllib.parse import urlencode # Dictionary iteration if PY2: diff --git a/pydu/request.py b/pydu/request.py index e2e552a..75acb88 100644 --- a/pydu/request.py +++ b/pydu/request.py @@ -5,7 +5,7 @@ import socket from . import logger from .string import safeunicode -from .compat import PY2, string_types, urlparse, urlib +from .compat import PY2, string_types, urlparse, urlib, urlencode class FileName(object): @@ -125,3 +125,15 @@ def check_connect(ip, port, retry=1, timeout=0.5): finally: retry -= 1 return None + + +def update_query_params(url, params): + """ + Update query params of given url and return new url. + """ + parts = list(urlparse.urlparse(url)) + query = dict(urlparse.parse_qsl(parts[4])) + query.update(params) + parts[4] = urlencode(query) + new_url = urlparse.urlunparse(parts) + return new_url diff --git a/tests/test_request.py b/tests/test_request.py index 3cffc1d..59e6865 100644 --- a/tests/test_request.py +++ b/tests/test_request.py @@ -2,7 +2,7 @@ import socket from .testing import mockserver import pydu.request from pydu.network import get_free_port -from pydu.request import FileName, check_connect +from pydu.request import FileName, check_connect, update_query_params def test_filename_from_url(): @@ -52,3 +52,9 @@ def test_check_connect(port=None): pydu.request.socket.socket = mock_socket assert not check_connect('127.0.0.1', port=port, timeout=0.01) + + +def test_update_query_params(): + assert update_query_params('http://example.com/', {'foo': 1}) == 'http://example.com/?foo=1' + assert update_query_params('http://example.com/?foo=1', {'foo': 2}) == 'http://example.com/?foo=2' + assert update_query_params('http://example.com/?foo=1', {'foo': 2, 'bar': 3}) == 'http://example.com/?foo=2&bar=3'