diff --git a/docs/changelog-fragments/638.bugfix.rst b/docs/changelog-fragments/638.bugfix.rst new file mode 100644 index 000000000..ac1ff6077 --- /dev/null +++ b/docs/changelog-fragments/638.bugfix.rst @@ -0,0 +1,5 @@ +Fixed reading files over SFTP that go over the pre-defined chunk size. + +Prior to this change, the files could end up being corrupted, ending up with the last read chunk written to the file instead of the entire payload. + +-- by :user:`Jakuje` diff --git a/src/pylibsshext/sftp.pyx b/src/pylibsshext/sftp.pyx index 6331c529d..6220ad8ae 100644 --- a/src/pylibsshext/sftp.pyx +++ b/src/pylibsshext/sftp.pyx @@ -91,16 +91,16 @@ cdef class SFTP: if rf is NULL: raise LibsshSFTPException("Opening remote file [%s] for read failed with error [%s]" % (remote_file, self._get_sftp_error_str())) - while True: - file_data = sftp.sftp_read(rf, read_buffer, sizeof(char) * 1024) - if file_data == 0: - break - elif file_data < 0: - sftp.sftp_close(rf) - raise LibsshSFTPException("Reading data from remote file [%s] failed with error [%s]" - % (remote_file, self._get_sftp_error_str())) - - with open(local_file, 'wb+') as f: + with open(local_file, 'wb') as f: + while True: + file_data = sftp.sftp_read(rf, read_buffer, sizeof(char) * 1024) + if file_data == 0: + break + elif file_data < 0: + sftp.sftp_close(rf) + raise LibsshSFTPException("Reading data from remote file [%s] failed with error [%s]" + % (remote_file, self._get_sftp_error_str())) + bytes_written = f.write(read_buffer[:file_data]) if bytes_written and file_data != bytes_written: sftp.sftp_close(rf) diff --git a/tests/unit/sftp_test.py b/tests/unit/sftp_test.py index fc70c4512..ce1a28ffe 100644 --- a/tests/unit/sftp_test.py +++ b/tests/unit/sftp_test.py @@ -2,6 +2,8 @@ """Tests suite for sftp.""" +import random +import string import uuid import pytest @@ -18,11 +20,21 @@ def sftp_session(ssh_client_session): del sftp_sess # noqa: WPS420 -@pytest.fixture -def transmit_payload(): - """Generate a binary test payload.""" - uuid_name = uuid.uuid4() - return 'Hello, {name!s}'.format(name=uuid_name).encode() +@pytest.fixture( + params=(32, 1024 + 1), + ids=('small-payload', 'large-payload'), +) +def transmit_payload(request: pytest.FixtureRequest) -> bytes: + """Generate binary test payloads of assorted sizes. + + The choice 32 is arbitrary small value. + + The choice 1024 + 1 is meant to be 1B larger than the chunk size used in + :file:`sftp.pyx` to make sure we excercise at least two rounds of reading/writing. + """ + payload_len = request.param + random_bytes = [ord(random.choice(string.printable)) for _ in range(payload_len)] + return bytes(random_bytes) @pytest.fixture @@ -48,6 +60,21 @@ def dst_path(file_paths_pair): return path +@pytest.fixture +def other_payload(): + """Generate a binary test payload.""" + uuid_name = uuid.uuid4() + return 'Original content: {name!s}'.format(name=uuid_name).encode() + + +@pytest.fixture +def pre_existing_dst_path(dst_path, other_payload): + """Return a data destination path.""" + dst_path.write_bytes(other_payload) + assert dst_path.exists() + return dst_path + + def test_make_sftp(sftp_session): """Smoke-test SFTP instance creation.""" assert sftp_session @@ -63,3 +90,15 @@ def test_get(dst_path, src_path, sftp_session, transmit_payload): """Check that SFTP file download works.""" sftp_session.get(str(src_path), str(dst_path)) assert dst_path.read_bytes() == transmit_payload + + +def test_get_existing(pre_existing_dst_path, src_path, sftp_session, transmit_payload): + """Check that SFTP file download works when target file exists.""" + sftp_session.get(str(src_path), str(pre_existing_dst_path)) + assert pre_existing_dst_path.read_bytes() == transmit_payload + + +def test_put_existing(pre_existing_dst_path, src_path, sftp_session, transmit_payload): + """Check that SFTP file upload works when target file exists.""" + sftp_session.put(str(src_path), str(pre_existing_dst_path)) + assert pre_existing_dst_path.read_bytes() == transmit_payload