From 3c168f7f79d1da2323d35dcf88c2d3c8730e5df6 Mon Sep 17 00:00:00 2001 From: Kumar Aditya Date: Thu, 19 Dec 2024 17:08:32 +0530 Subject: [PATCH] gh-128013: fix data race in `PyUnicode_AsUTF8AndSize` on free-threading (#128021) --- Lib/test/test_capi/test_unicode.py | 20 +++++++++++- Objects/unicodeobject.c | 49 +++++++++++++++++++----------- 2 files changed, 51 insertions(+), 18 deletions(-) diff --git a/Lib/test/test_capi/test_unicode.py b/Lib/test/test_capi/test_unicode.py index 65d8242ad3f..3408c10f426 100644 --- a/Lib/test/test_capi/test_unicode.py +++ b/Lib/test/test_capi/test_unicode.py @@ -1,7 +1,7 @@ import unittest import sys from test import support -from test.support import import_helper +from test.support import threading_helper try: import _testcapi @@ -1005,6 +1005,24 @@ def test_asutf8(self): self.assertRaises(TypeError, unicode_asutf8, [], 0) # CRASHES unicode_asutf8(NULL, 0) + @unittest.skipIf(_testcapi is None, 'need _testcapi module') + @threading_helper.requires_working_threading() + def test_asutf8_race(self): + """Test that there's no race condition in PyUnicode_AsUTF8()""" + unicode_asutf8 = _testcapi.unicode_asutf8 + from threading import Thread + + data = "😊" + + def worker(): + for _ in range(1000): + self.assertEqual(unicode_asutf8(data, 5), b'\xf0\x9f\x98\x8a\0') + + threads = [Thread(target=worker) for _ in range(10)] + with threading_helper.start_threads(threads): + pass + + @support.cpython_only @unittest.skipIf(_testlimitedcapi is None, 'need _testlimitedcapi module') def test_asutf8andsize(self): diff --git a/Objects/unicodeobject.c b/Objects/unicodeobject.c index 53be6f5b901..1aab9cf3776 100644 --- a/Objects/unicodeobject.c +++ b/Objects/unicodeobject.c @@ -114,7 +114,7 @@ NOTE: In the interpreter's initialization phase, some globals are currently static inline char* _PyUnicode_UTF8(PyObject *op) { - return (_PyCompactUnicodeObject_CAST(op)->utf8); + return FT_ATOMIC_LOAD_PTR_ACQUIRE(_PyCompactUnicodeObject_CAST(op)->utf8); } static inline char* PyUnicode_UTF8(PyObject *op) @@ -130,7 +130,7 @@ static inline char* PyUnicode_UTF8(PyObject *op) static inline void PyUnicode_SET_UTF8(PyObject *op, char *utf8) { - _PyCompactUnicodeObject_CAST(op)->utf8 = utf8; + FT_ATOMIC_STORE_PTR_RELEASE(_PyCompactUnicodeObject_CAST(op)->utf8, utf8); } static inline Py_ssize_t PyUnicode_UTF8_LENGTH(PyObject *op) @@ -700,16 +700,17 @@ _PyUnicode_CheckConsistency(PyObject *op, int check_content) CHECK(ascii->state.compact == 0); CHECK(data != NULL); if (ascii->state.ascii) { - CHECK(compact->utf8 == data); + CHECK(_PyUnicode_UTF8(op) == data); CHECK(compact->utf8_length == ascii->length); } else { - CHECK(compact->utf8 != data); + CHECK(_PyUnicode_UTF8(op) != data); } } - - if (compact->utf8 == NULL) +#ifndef Py_GIL_DISABLED + if (_PyUnicode_UTF8(op) == NULL) CHECK(compact->utf8_length == 0); +#endif } /* check that the best kind is used: O(n) operation */ @@ -1156,8 +1157,8 @@ resize_compact(PyObject *unicode, Py_ssize_t length) if (_PyUnicode_HAS_UTF8_MEMORY(unicode)) { PyMem_Free(_PyUnicode_UTF8(unicode)); - PyUnicode_SET_UTF8(unicode, NULL); PyUnicode_SET_UTF8_LENGTH(unicode, 0); + PyUnicode_SET_UTF8(unicode, NULL); } #ifdef Py_TRACE_REFS _Py_ForgetReference(unicode); @@ -1210,8 +1211,8 @@ resize_inplace(PyObject *unicode, Py_ssize_t length) if (!share_utf8 && _PyUnicode_HAS_UTF8_MEMORY(unicode)) { PyMem_Free(_PyUnicode_UTF8(unicode)); - PyUnicode_SET_UTF8(unicode, NULL); PyUnicode_SET_UTF8_LENGTH(unicode, 0); + PyUnicode_SET_UTF8(unicode, NULL); } data = (PyObject *)PyObject_Realloc(data, new_size); @@ -1221,8 +1222,8 @@ resize_inplace(PyObject *unicode, Py_ssize_t length) } _PyUnicode_DATA_ANY(unicode) = data; if (share_utf8) { - PyUnicode_SET_UTF8(unicode, data); PyUnicode_SET_UTF8_LENGTH(unicode, length); + PyUnicode_SET_UTF8(unicode, data); } _PyUnicode_LENGTH(unicode) = length; PyUnicode_WRITE(PyUnicode_KIND(unicode), data, length, 0); @@ -4216,6 +4217,21 @@ PyUnicode_FSDecoder(PyObject* arg, void* addr) static int unicode_fill_utf8(PyObject *unicode); + +static int +unicode_ensure_utf8(PyObject *unicode) +{ + int err = 0; + if (PyUnicode_UTF8(unicode) == NULL) { + Py_BEGIN_CRITICAL_SECTION(unicode); + if (PyUnicode_UTF8(unicode) == NULL) { + err = unicode_fill_utf8(unicode); + } + Py_END_CRITICAL_SECTION(); + } + return err; +} + const char * PyUnicode_AsUTF8AndSize(PyObject *unicode, Py_ssize_t *psize) { @@ -4227,13 +4243,11 @@ PyUnicode_AsUTF8AndSize(PyObject *unicode, Py_ssize_t *psize) return NULL; } - if (PyUnicode_UTF8(unicode) == NULL) { - if (unicode_fill_utf8(unicode) == -1) { - if (psize) { - *psize = -1; - } - return NULL; + if (unicode_ensure_utf8(unicode) == -1) { + if (psize) { + *psize = -1; } + return NULL; } if (psize) { @@ -5854,6 +5868,7 @@ unicode_encode_utf8(PyObject *unicode, _Py_error_handler error_handler, static int unicode_fill_utf8(PyObject *unicode) { + _Py_CRITICAL_SECTION_ASSERT_OBJECT_LOCKED(unicode); /* the string cannot be ASCII, or PyUnicode_UTF8() would be set */ assert(!PyUnicode_IS_ASCII(unicode)); @@ -5895,10 +5910,10 @@ unicode_fill_utf8(PyObject *unicode) PyErr_NoMemory(); return -1; } - PyUnicode_SET_UTF8(unicode, cache); - PyUnicode_SET_UTF8_LENGTH(unicode, len); memcpy(cache, start, len); cache[len] = '\0'; + PyUnicode_SET_UTF8_LENGTH(unicode, len); + PyUnicode_SET_UTF8(unicode, cache); _PyBytesWriter_Dealloc(&writer); return 0; }