From 52f2c9824ba8940e24795c12bc6b72ecf4657960 Mon Sep 17 00:00:00 2001 From: mattsb42-aws Date: Fri, 2 Nov 2018 16:09:54 -0700 Subject: [PATCH] streamline string handling: single-source logic by simply converting data read from wrapped stream to bytes immediately --- src/base64io/__init__.py | 38 ++++++++++++++++++++++++-------------- test/unit/test_base64io.py | 9 +++++++++ 2 files changed, 33 insertions(+), 14 deletions(-) diff --git a/src/base64io/__init__.py b/src/base64io/__init__.py index 4822523..f99382d 100644 --- a/src/base64io/__init__.py +++ b/src/base64io/__init__.py @@ -52,6 +52,19 @@ if not _py2(): file = NotImplemented # pylint: disable=invalid-name +def _to_bytes(data): + # type: (AnyStr) -> bytes + """Convert input data from either string or bytes to bytes. + + :param data: Data to convert + :returns: ``data`` converted to bytes + :rtype: bytes + """ + if isinstance(data, bytes): + return data + return data.encode("utf-8") + + class Base64IO(io.IOBase): """Base64 stream with context manager support. @@ -209,7 +222,7 @@ class Base64IO(io.IOBase): self.write(line) def _read_additional_data_removing_whitespace(self, data, total_bytes_to_read): - # type: (AnyStr, int) -> AnyStr + # type: (bytes, int) -> bytes """Read additional data from wrapped stream until we reach the desired number of bytes. .. note:: @@ -226,20 +239,20 @@ class Base64IO(io.IOBase): # case the base64 module happily removes any whitespace. return data - _data_buffer = io.BytesIO() if isinstance(data, bytes) else io.StringIO() - join_char = b"" if isinstance(data, bytes) else u"" - _data_buffer.write(join_char.join(data.split())) # type: ignore - _remaining_bytes_to_read = total_bytes_to_read - _data_buffer.tell() # type: ignore + _data_buffer = io.BytesIO() + + _data_buffer.write(b"".join(data.split())) + _remaining_bytes_to_read = total_bytes_to_read - _data_buffer.tell() while _remaining_bytes_to_read > 0: - _raw_additional_data = self.__wrapped.read(_remaining_bytes_to_read) + _raw_additional_data = _to_bytes(self.__wrapped.read(_remaining_bytes_to_read)) if not _raw_additional_data: # No more data to read from wrapped stream. break - _data_buffer.write(join_char.join(_raw_additional_data.split())) # type: ignore - _remaining_bytes_to_read = total_bytes_to_read - _data_buffer.tell() # type: ignore - return _data_buffer.getvalue() # type: ignore + _data_buffer.write(b"".join(_raw_additional_data.split())) + _remaining_bytes_to_read = total_bytes_to_read - _data_buffer.tell() + return _data_buffer.getvalue() def read(self, b=-1): # type: (int) -> bytes @@ -271,14 +284,11 @@ class Base64IO(io.IOBase): _bytes_to_read += 4 - _bytes_to_read % 4 # Read encoded bytes from wrapped stream. - data = self.__wrapped.read(_bytes_to_read) + data = _to_bytes(self.__wrapped.read(_bytes_to_read)) # Remove whitespace from read data and attempt to read more data to get the desired # number of bytes. - whitespace = ( - string.whitespace.encode("utf-8") if isinstance(data, bytes) else string.whitespace - ) # type: Union[bytes, str] - if any([char in data for char in whitespace]): + if any([char in data for char in string.whitespace.encode("utf-8")]): data = self._read_additional_data_removing_whitespace(data, _bytes_to_read) results = io.BytesIO() diff --git a/test/unit/test_base64io.py b/test/unit/test_base64io.py index 07528b7..c93462d 100644 --- a/test/unit/test_base64io.py +++ b/test/unit/test_base64io.py @@ -37,3 +37,12 @@ def test_file(): # If we are in Python 3, the "file" assignment should happen # to provide a concrete definition of the "file" name. assert base64io.file is NotImplemented + + +@pytest.mark.parametrize("source, expected", ( + ('asdf', b'asdf'), + (b'\x00\x01\x02\x03', b'\x00\x01\x02\x03'), + (u'\u1111\u2222', b'\xe1\x84\x91\xe2\x88\xa2') +)) +def test_to_bytes(source, expected): + assert base64io._to_bytes(source) == expected