diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py index 423df765300..44b42ffa115 100644 --- a/Lib/test/test_socket.py +++ b/Lib/test/test_socket.py @@ -247,6 +247,17 @@ def testInterpreterCrash(self): except socket.error: pass + def testNtoH(self): + def twice(f): + def g(x): + return f(f(x)) + return g + for i in (0, 1, 0xffff0000, 2L, (2**32L) - 1): + self.assertEqual(i, twice(socket.htonl)(i)) + self.assertEqual(i, twice(socket.ntohl)(i)) + self.assertRaises(OverflowError, socket.htonl, 2L**34) + self.assertRaises(OverflowError, socket.ntohl, 2L**34) + def testGetServByName(self): """Testing getservbyname().""" if hasattr(socket, 'getservbyname'): diff --git a/Modules/socketmodule.c b/Modules/socketmodule.c index dee042b2061..39809f43fe0 100644 --- a/Modules/socketmodule.c +++ b/Modules/socketmodule.c @@ -2469,15 +2469,34 @@ Convert a 16-bit integer from network to host byte order."); static PyObject * -socket_ntohl(PyObject *self, PyObject *args) +socket_ntohl(PyObject *self, PyObject *arg) { - int x1, x2; + unsigned long x; - if (!PyArg_ParseTuple(args, "i:ntohl", &x1)) { - return NULL; + if (PyInt_Check(arg)) { + x = PyInt_AS_LONG(arg); } - x2 = ntohl(x1); - return PyInt_FromLong(x2); + else if (PyLong_Check(arg)) { + x = PyLong_AsUnsignedLong(arg); +#if SIZEOF_LONG > 4 + { + unsigned long y; + /* only want the trailing 32 bits */ + y = x & 0xFFFFFFFFUL; + if (y ^ x) + return PyErr_Format(PyExc_OverflowError, + "long int larger than 32 bits"); + x = y; + } +#endif + } + else + return PyErr_Format(PyExc_TypeError, + "expected int/long, %s found", + arg->ob_type->tp_name); + if (x == (unsigned long) -1 && PyErr_Occurred()) + return NULL; + return PyInt_FromLong(ntohl(x)); } PyDoc_STRVAR(ntohl_doc, @@ -2489,7 +2508,7 @@ Convert a 32-bit integer from network to host byte order."); static PyObject * socket_htons(PyObject *self, PyObject *args) { - int x1, x2; + unsigned long x1, x2; if (!PyArg_ParseTuple(args, "i:htons", &x1)) { return NULL; @@ -2505,15 +2524,34 @@ Convert a 16-bit integer from host to network byte order."); static PyObject * -socket_htonl(PyObject *self, PyObject *args) +socket_htonl(PyObject *self, PyObject *arg) { - int x1, x2; + unsigned long x; - if (!PyArg_ParseTuple(args, "i:htonl", &x1)) { - return NULL; + if (PyInt_Check(arg)) { + x = PyInt_AS_LONG(arg); } - x2 = htonl(x1); - return PyInt_FromLong(x2); + else if (PyLong_Check(arg)) { + x = PyLong_AsUnsignedLong(arg); +#if SIZEOF_LONG > 4 + { + unsigned long y; + /* only want the trailing 32 bits */ + y = x & 0xFFFFFFFFUL; + if (y ^ x) + return PyErr_Format(PyExc_OverflowError, + "long int larger than 32 bits"); + x = y; + } +#endif + } + else + return PyErr_Format(PyExc_TypeError, + "expected int/long, %s found", + arg->ob_type->tp_name); + if (x == (unsigned long) -1 && PyErr_Occurred()) + return NULL; + return PyInt_FromLong(htonl(x)); } PyDoc_STRVAR(htonl_doc, @@ -2812,11 +2850,11 @@ static PyMethodDef socket_methods[] = { {"ntohs", socket_ntohs, METH_VARARGS, ntohs_doc}, {"ntohl", socket_ntohl, - METH_VARARGS, ntohl_doc}, + METH_O, ntohl_doc}, {"htons", socket_htons, METH_VARARGS, htons_doc}, {"htonl", socket_htonl, - METH_VARARGS, htonl_doc}, + METH_O, htonl_doc}, {"inet_aton", socket_inet_aton, METH_VARARGS, inet_aton_doc}, {"inet_ntoa", socket_inet_ntoa,