Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 3 additions & 28 deletions Doc/library/multiprocessing.shared_memory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -311,34 +311,9 @@ finishes execution.
existing :class:`!ShareableList`, specify its shared memory block's unique
name while leaving *sequence* set to ``None``.

.. note::

A known issue exists for :class:`bytes` and :class:`str` values.
If they end with ``\x00`` nul bytes or characters, those may be
*silently stripped* when fetching them by index from the
:class:`!ShareableList`. This ``.rstrip(b'\x00')`` behavior is
considered a bug and may go away in the future. See :gh:`106939`.

For applications where rstripping of trailing nulls is a problem,
work around it by always unconditionally appending an extra non-0
byte to the end of such values when storing and unconditionally
removing it when fetching:

.. doctest::

>>> from multiprocessing import shared_memory
>>> nul_bug_demo = shared_memory.ShareableList(['?\x00', b'\x03\x02\x01\x00\x00\x00'])
>>> nul_bug_demo[0]
'?'
>>> nul_bug_demo[1]
b'\x03\x02\x01'
>>> nul_bug_demo.shm.unlink()
>>> padded = shared_memory.ShareableList(['?\x00\x07', b'\x03\x02\x01\x00\x00\x00\x07'])
>>> padded[0][:-1]
'?\x00'
>>> padded[1][:-1]
b'\x03\x02\x01\x00\x00\x00'
>>> padded.shm.unlink()
.. versionchanged:: next
Trailing null bytes are now preserved for :class:`bytes` and :class:`str`
values. Previously they were silently stripped when an item was fetched
by index. See :gh:`106939`.

.. method:: count(value)

Expand Down
28 changes: 18 additions & 10 deletions Lib/multiprocessing/shared_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,8 +286,8 @@ class ShareableList:
_alignment = 8
_back_transforms_mapping = {
0: lambda value: value, # int, float, bool
1: lambda value: value.rstrip(b'\x00').decode(_encoding), # str
2: lambda value: value.rstrip(b'\x00'), # bytes
1: lambda value: value.decode(_encoding), # str
2: lambda value: value, # bytes
3: lambda _value: None, # None
}

Expand Down Expand Up @@ -326,6 +326,15 @@ def __init__(self, sequence=None, *, name=None):
for fmt in _formats:
offset += self._alignment if fmt[-1] != "s" else int(fmt[:-1])
self._allocated_offsets.append(offset)
_stored_formats = []
for item, fmt in zip(sequence, _formats):
if isinstance(item, (str, bytes)):
encoded = (item.encode(_encoding)
if isinstance(item, str) else item)
_stored_formats.append("%ds" % len(encoded))
else:
_stored_formats.append(fmt)

_recreation_codes = [
self._extract_recreation_code(item) for item in sequence
]
Expand Down Expand Up @@ -359,7 +368,7 @@ def __init__(self, sequence=None, *, name=None):
self._format_packing_metainfo,
self.shm.buf,
self._offset_packing_formats,
*(v.encode(_enc) for v in _formats)
*(v.encode(_enc) for v in _stored_formats)
)
struct.pack_into(
self._format_back_transform_codes,
Expand Down Expand Up @@ -459,6 +468,7 @@ def __setitem__(self, position, value):

if not isinstance(value, (str, bytes)):
new_format = self._types_mapping[type(value)]
pack_format = new_format
encoded_value = value
else:
allocated_length = self._allocated_offsets[position + 1] - item_offset
Expand All @@ -467,19 +477,17 @@ def __setitem__(self, position, value):
if isinstance(value, str) else value)
if len(encoded_value) > allocated_length:
raise ValueError("bytes/str item exceeds available storage")
if current_format[-1] == "s":
new_format = current_format
else:
new_format = self._types_mapping[str] % (
allocated_length,
)
# Allocated-length format: used for packing, so struct null-pads the
# unused remainder of the slot.
pack_format = "%ds" % allocated_length
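# Actual-length format: recorded as this position's packing format, so
# the value's own trailing nulls are distinguishable from padding.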
new_format = "%ds" % len(encoded_value)

self._set_packing_format_and_transform(
position,
new_format,
value
)
struct.pack_into(new_format, self.shm.buf, offset, encoded_value)
struct.pack_into(pack_format, self.shm.buf, offset, encoded_value)

def __reduce__(self):
return partial(self.__class__, name=self.shm.name), ()
Expand Down
32 changes: 30 additions & 2 deletions Lib/test/_test_multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4757,7 +4757,7 @@ def test_shared_memory_ShareableList_basics(self):
self.assertEqual(current_format, sl._get_packing_format(0))

# Verify attributes are readable.
self.assertEqual(sl.format, '8s8sdqxxxxxx?xxxxxxxx?q')
self.assertEqual(sl.format, '5s5sdqxxxxxx?xxxxxxxx?q')

# Exercise len().
self.assertEqual(len(sl), 7)
Expand Down Expand Up @@ -4785,7 +4785,7 @@ def test_shared_memory_ShareableList_basics(self):
self.assertEqual(sl[3], 42)
sl[4] = 'some' # Change type at a given position.
self.assertEqual(sl[4], 'some')
self.assertEqual(sl.format, '8s8sdq8sxxxxxxx?q')
self.assertEqual(sl.format, '5s5sdq4sxxxxxxx?q')
with self.assertRaisesRegex(ValueError,
"exceeds available storage"):
sl[4] = 'far too many'
Expand Down Expand Up @@ -4887,6 +4887,34 @@ def test_shared_memory_ShareableList_pickling_dead_object(self):
with self.assertRaises(FileNotFoundError):
pickle.loads(serialized_sl)

def test_shared_memory_ShareableList_trailing_nulls(self):
# gh-106939: ShareableList should preserve trailing null bytes
# in bytes and str values.
# Cases covered: bytes ending in nulls, str ending in a null,
# all-null bytes, empty bytes, and a control value without nulls.
sl = shared_memory.ShareableList([
b'\x03\x02\x01\x00\x00\x00',
'?\x00',
b'\x00\x00\x00',
b'',
b'no nulls',
])
# Cleanups run in reverse order of registration, so close() is
# called before unlink().
self.addCleanup(sl.shm.unlink)
self.addCleanup(sl.shm.close)

# Values read back through the instance that created the block.
self.assertEqual(sl[0], b'\x03\x02\x01\x00\x00\x00')
self.assertEqual(sl[1], '?\x00')
self.assertEqual(sl[2], b'\x00\x00\x00')
self.assertEqual(sl[3], b'')
self.assertEqual(sl[4], b'no nulls')

# Attach a second instance to the same block by name and check that
# it reads the same values.
sl2 = shared_memory.ShareableList(name=sl.shm.name)
self.addCleanup(sl2.shm.close)
self.assertEqual(sl2[0], b'\x03\x02\x01\x00\x00\x00')
self.assertEqual(sl2[1], '?\x00')
self.assertEqual(sl2[2], b'\x00\x00\x00')
self.assertEqual(sl2[3], b'')
self.assertEqual(sl2[4], b'no nulls')


def test_shared_memory_cleaned_after_process_termination(self):
cmd = '''if 1:
import os, time, sys
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
:class:`~multiprocessing.shared_memory.ShareableList` now preserves trailing
null bytes in :class:`bytes` and :class:`str` values. Previously they were
silently stripped.
Loading