streamline string handling: single-source logic by simply converting data read from wrapped stream to bytes immediately
parent
56e9bf45fd
commit
52f2c9824b
|
@ -52,6 +52,19 @@ if not _py2():
|
||||||
file = NotImplemented # pylint: disable=invalid-name
|
file = NotImplemented # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
|
def _to_bytes(data):
|
||||||
|
# type: (AnyStr) -> bytes
|
||||||
|
"""Convert input data from either string or bytes to bytes.
|
||||||
|
|
||||||
|
:param data: Data to convert
|
||||||
|
:returns: ``data`` converted to bytes
|
||||||
|
:rtype: bytes
|
||||||
|
"""
|
||||||
|
if isinstance(data, bytes):
|
||||||
|
return data
|
||||||
|
return data.encode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
class Base64IO(io.IOBase):
|
class Base64IO(io.IOBase):
|
||||||
"""Base64 stream with context manager support.
|
"""Base64 stream with context manager support.
|
||||||
|
|
||||||
|
@ -209,7 +222,7 @@ class Base64IO(io.IOBase):
|
||||||
self.write(line)
|
self.write(line)
|
||||||
|
|
||||||
def _read_additional_data_removing_whitespace(self, data, total_bytes_to_read):
|
def _read_additional_data_removing_whitespace(self, data, total_bytes_to_read):
|
||||||
# type: (AnyStr, int) -> AnyStr
|
# type: (bytes, int) -> bytes
|
||||||
"""Read additional data from wrapped stream until we reach the desired number of bytes.
|
"""Read additional data from wrapped stream until we reach the desired number of bytes.
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
|
@ -226,20 +239,20 @@ class Base64IO(io.IOBase):
|
||||||
# case the base64 module happily removes any whitespace.
|
# case the base64 module happily removes any whitespace.
|
||||||
return data
|
return data
|
||||||
|
|
||||||
_data_buffer = io.BytesIO() if isinstance(data, bytes) else io.StringIO()
|
_data_buffer = io.BytesIO()
|
||||||
join_char = b"" if isinstance(data, bytes) else u""
|
|
||||||
_data_buffer.write(join_char.join(data.split())) # type: ignore
|
_data_buffer.write(b"".join(data.split()))
|
||||||
_remaining_bytes_to_read = total_bytes_to_read - _data_buffer.tell() # type: ignore
|
_remaining_bytes_to_read = total_bytes_to_read - _data_buffer.tell()
|
||||||
|
|
||||||
while _remaining_bytes_to_read > 0:
|
while _remaining_bytes_to_read > 0:
|
||||||
_raw_additional_data = self.__wrapped.read(_remaining_bytes_to_read)
|
_raw_additional_data = _to_bytes(self.__wrapped.read(_remaining_bytes_to_read))
|
||||||
if not _raw_additional_data:
|
if not _raw_additional_data:
|
||||||
# No more data to read from wrapped stream.
|
# No more data to read from wrapped stream.
|
||||||
break
|
break
|
||||||
|
|
||||||
_data_buffer.write(join_char.join(_raw_additional_data.split())) # type: ignore
|
_data_buffer.write(b"".join(_raw_additional_data.split()))
|
||||||
_remaining_bytes_to_read = total_bytes_to_read - _data_buffer.tell() # type: ignore
|
_remaining_bytes_to_read = total_bytes_to_read - _data_buffer.tell()
|
||||||
return _data_buffer.getvalue() # type: ignore
|
return _data_buffer.getvalue()
|
||||||
|
|
||||||
def read(self, b=-1):
|
def read(self, b=-1):
|
||||||
# type: (int) -> bytes
|
# type: (int) -> bytes
|
||||||
|
@ -271,14 +284,11 @@ class Base64IO(io.IOBase):
|
||||||
_bytes_to_read += 4 - _bytes_to_read % 4
|
_bytes_to_read += 4 - _bytes_to_read % 4
|
||||||
|
|
||||||
# Read encoded bytes from wrapped stream.
|
# Read encoded bytes from wrapped stream.
|
||||||
data = self.__wrapped.read(_bytes_to_read)
|
data = _to_bytes(self.__wrapped.read(_bytes_to_read))
|
||||||
# Remove whitespace from read data and attempt to read more data to get the desired
|
# Remove whitespace from read data and attempt to read more data to get the desired
|
||||||
# number of bytes.
|
# number of bytes.
|
||||||
whitespace = (
|
|
||||||
string.whitespace.encode("utf-8") if isinstance(data, bytes) else string.whitespace
|
|
||||||
) # type: Union[bytes, str]
|
|
||||||
|
|
||||||
if any([char in data for char in whitespace]):
|
if any([char in data for char in string.whitespace.encode("utf-8")]):
|
||||||
data = self._read_additional_data_removing_whitespace(data, _bytes_to_read)
|
data = self._read_additional_data_removing_whitespace(data, _bytes_to_read)
|
||||||
|
|
||||||
results = io.BytesIO()
|
results = io.BytesIO()
|
||||||
|
|
|
@ -37,3 +37,12 @@ def test_file():
|
||||||
# If we are in Python 3, the "file" assignment should happen
|
# If we are in Python 3, the "file" assignment should happen
|
||||||
# to provide a concrete definition of the "file" name.
|
# to provide a concrete definition of the "file" name.
|
||||||
assert base64io.file is NotImplemented
|
assert base64io.file is NotImplemented
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    ("source", "expected"),
    (
        ('asdf', b'asdf'),
        (b'\x00\x01\x02\x03', b'\x00\x01\x02\x03'),
        (u'\u1111\u2222', b'\xe1\x84\x91\xe2\x88\xa2'),
    ),
)
def test_to_bytes(source, expected):
    """Check that ``_to_bytes`` returns the expected bytes for each input.

    Covers a native string literal, a bytes value, and non-ASCII text.
    """
    actual = base64io._to_bytes(source)
    assert actual == expected
|
||||||
|
|
Loading…
Reference in New Issue