streamline string handling: single-source logic by simply converting data read from wrapped stream to bytes immediately

development
mattsb42-aws 2018-11-02 16:09:54 -07:00
parent 56e9bf45fd
commit 52f2c9824b
2 changed files with 33 additions and 14 deletions

View File

@ -52,6 +52,19 @@ if not _py2():
file = NotImplemented # pylint: disable=invalid-name file = NotImplemented # pylint: disable=invalid-name
def _to_bytes(data):
# type: (AnyStr) -> bytes
"""Convert input data from either string or bytes to bytes.
:param data: Data to convert
:returns: ``data`` converted to bytes
:rtype: bytes
"""
if isinstance(data, bytes):
return data
return data.encode("utf-8")
class Base64IO(io.IOBase): class Base64IO(io.IOBase):
"""Base64 stream with context manager support. """Base64 stream with context manager support.
@ -209,7 +222,7 @@ class Base64IO(io.IOBase):
self.write(line) self.write(line)
def _read_additional_data_removing_whitespace(self, data, total_bytes_to_read): def _read_additional_data_removing_whitespace(self, data, total_bytes_to_read):
# type: (AnyStr, int) -> AnyStr # type: (bytes, int) -> bytes
"""Read additional data from wrapped stream until we reach the desired number of bytes. """Read additional data from wrapped stream until we reach the desired number of bytes.
.. note:: .. note::
@ -226,20 +239,20 @@ class Base64IO(io.IOBase):
# case the base64 module happily removes any whitespace. # case the base64 module happily removes any whitespace.
return data return data
_data_buffer = io.BytesIO() if isinstance(data, bytes) else io.StringIO() _data_buffer = io.BytesIO()
join_char = b"" if isinstance(data, bytes) else u""
_data_buffer.write(join_char.join(data.split())) # type: ignore _data_buffer.write(b"".join(data.split()))
_remaining_bytes_to_read = total_bytes_to_read - _data_buffer.tell() # type: ignore _remaining_bytes_to_read = total_bytes_to_read - _data_buffer.tell()
while _remaining_bytes_to_read > 0: while _remaining_bytes_to_read > 0:
_raw_additional_data = self.__wrapped.read(_remaining_bytes_to_read) _raw_additional_data = _to_bytes(self.__wrapped.read(_remaining_bytes_to_read))
if not _raw_additional_data: if not _raw_additional_data:
# No more data to read from wrapped stream. # No more data to read from wrapped stream.
break break
_data_buffer.write(join_char.join(_raw_additional_data.split())) # type: ignore _data_buffer.write(b"".join(_raw_additional_data.split()))
_remaining_bytes_to_read = total_bytes_to_read - _data_buffer.tell() # type: ignore _remaining_bytes_to_read = total_bytes_to_read - _data_buffer.tell()
return _data_buffer.getvalue() # type: ignore return _data_buffer.getvalue()
def read(self, b=-1): def read(self, b=-1):
# type: (int) -> bytes # type: (int) -> bytes
@ -271,14 +284,11 @@ class Base64IO(io.IOBase):
_bytes_to_read += 4 - _bytes_to_read % 4 _bytes_to_read += 4 - _bytes_to_read % 4
# Read encoded bytes from wrapped stream. # Read encoded bytes from wrapped stream.
data = self.__wrapped.read(_bytes_to_read) data = _to_bytes(self.__wrapped.read(_bytes_to_read))
# Remove whitespace from read data and attempt to read more data to get the desired # Remove whitespace from read data and attempt to read more data to get the desired
# number of bytes. # number of bytes.
whitespace = (
string.whitespace.encode("utf-8") if isinstance(data, bytes) else string.whitespace
) # type: Union[bytes, str]
if any([char in data for char in whitespace]): if any([char in data for char in string.whitespace.encode("utf-8")]):
data = self._read_additional_data_removing_whitespace(data, _bytes_to_read) data = self._read_additional_data_removing_whitespace(data, _bytes_to_read)
results = io.BytesIO() results = io.BytesIO()

View File

@ -37,3 +37,12 @@ def test_file():
# If we are in Python 3, the "file" assignment should happen # If we are in Python 3, the "file" assignment should happen
# to provide a concrete definition of the "file" name. # to provide a concrete definition of the "file" name.
assert base64io.file is NotImplemented assert base64io.file is NotImplemented
@pytest.mark.parametrize("source, expected", (
('asdf', b'asdf'),
(b'\x00\x01\x02\x03', b'\x00\x01\x02\x03'),
(u'\u1111\u2222', b'\xe1\x84\x91\xe2\x88\xa2')
))
def test_to_bytes(source, expected):
assert base64io._to_bytes(source) == expected