SNOW-1778088 azure md5 (#2102)
sfc-gh-mkeller authored Nov 7, 2024
1 parent 48fba63 commit 1e4d456
Showing 4 changed files with 72 additions and 2 deletions.
3 changes: 3 additions & 0 deletions DESCRIPTION.md
@@ -8,6 +8,9 @@ Source code is also available at: https://github.com/snowflakedb/snowflake-conne

# Release Notes

- v3.12.4(TBD)
- Fixed a bug where multipart uploads to Azure would be missing their MD5 hashes.

- v3.12.3(October 25,2024)
- Improved the error message for SSL-related issues to provide clearer guidance when an SSL error occurs.
- Improved error message for SQL execution cancellations caused by timeout.
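
The release note above is about Azure's Content-MD5 blob property. Per the Azure Blob REST API, the service does not compute an MD5 for blobs committed from blocks (multipart uploads), so the client must attach one itself via the x-ms-blob-content-md5 header when committing the block list, encoded as the base64 of the raw digest. A minimal sketch of that encoding, outside the connector (the payload is made up; the header name comes from the Azure Blob REST API):

import base64
import hashlib

# Minimal sketch (not connector code): Azure stores the Content-MD5 property
# as the base64-encoded raw 16-byte digest, not as a hex string. For block
# (multipart) uploads the service does not compute it, so the client attaches
# it when committing the block list.
data = b"example blob content"  # made-up payload
content_md5 = base64.b64encode(hashlib.md5(data).digest()).decode("utf-8")
headers = {"x-ms-blob-content-md5": content_md5}
print(headers)
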
26 changes: 24 additions & 2 deletions src/snowflake/connector/azure_storage_client.py
@@ -4,6 +4,7 @@

from __future__ import annotations

import base64
import json
import os
import xml.etree.ElementTree as ET
@@ -17,6 +18,7 @@
from .constants import FileHeader, ResultStatus
from .encryption_util import EncryptionMetadata
from .storage_client import SnowflakeStorageClient
from .util_text import get_md5
from .vendored import requests

if TYPE_CHECKING: # pragma: no cover
@@ -149,7 +151,7 @@ def get_file_header(self, filename: str) -> FileHeader | None:
)
)
return FileHeader(
digest=r.headers.get("x-ms-meta-sfcdigest"),
digest=r.headers.get(SFCDIGEST),
content_length=int(r.headers.get("Content-Length")),
encryption_metadata=encryption_metadata,
)
@@ -236,7 +238,27 @@ def _complete_multipart_upload(self) -> None:
part = ET.Element("Latest")
part.text = block_id
root.append(part)
headers = {"x-ms-blob-content-encoding": "utf-8"}
# SNOW-1778088: We need to calculate the MD5 sum of this file for Azure Blob storage
new_stream = not bool(self.meta.src_stream or self.meta.intermediate_stream)
fd = (
self.meta.src_stream
or self.meta.intermediate_stream
or open(self.meta.real_src_file_name, "rb")
)
try:
if not new_stream:
# Reset position in file
fd.seek(0)
file_content = fd.read()
finally:
if new_stream:
fd.close()
headers = {
"x-ms-blob-content-encoding": "utf-8",
"x-ms-blob-content-md5": base64.b64encode(get_md5(file_content)).decode(
"utf-8"
),
}
azure_metadata = self._prepare_file_metadata()
headers.update(azure_metadata)
retry_id = "COMPLETE"
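
Note the design choice in the change above: to hash the upload, it reads the entire file or stream into memory with fd.read(), and multipart uploads are by definition large (the test below uses a roughly 342 MB file). For illustration only, a hedged alternative sketch of computing the same digest in fixed-size chunks so the whole file never sits in memory; this is not what the connector does, and the helper name and chunk size are made up:

import hashlib

def md5_of_file_chunked(path: str, chunk_size: int = 8 * 1024 * 1024) -> bytes:
    # Hash the file in 8 MiB chunks; the resulting digest is identical to
    # hashing the full contents at once.
    md5 = hashlib.md5()
    with open(path, "rb") as fd:
        while chunk := fd.read(chunk_size):
            md5.update(chunk)
    return md5.digest()

This would not work for non-seekable source streams without extra buffering, which may be why the change reads the stream in full.
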
9 changes: 9 additions & 0 deletions src/snowflake/connector/util_text.py
@@ -5,6 +5,7 @@

from __future__ import annotations

import hashlib
import logging
import random
import re
@@ -289,3 +290,11 @@ def random_string(
"""
random_part = "".join([random.Random().choice(choices) for _ in range(length)])
return "".join([prefix, random_part, suffix])


def get_md5(text: str | bytes) -> bytes:
    """Return the raw 16-byte MD5 digest of text; str input is UTF-8 encoded first."""
    if isinstance(text, str):
        text = text.encode("utf-8")
    md5 = hashlib.md5()
    md5.update(text)
    return md5.digest()
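
A quick usage sketch for the new helper, run outside the connector (the import path follows the repository layout; the assertions just restate what the function above guarantees):

import base64
import hashlib

from snowflake.connector.util_text import get_md5

digest = get_md5("hello")                        # str input is UTF-8 encoded first
assert digest == get_md5(b"hello")               # bytes input hashes identically
assert digest == hashlib.md5(b"hello").digest()  # raw 16-byte digest, not hex
# The Azure client base64-encodes this digest for x-ms-blob-content-md5:
print(base64.b64encode(digest).decode("utf-8"))
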
36 changes: 36 additions & 0 deletions test/integ/test_put_get.py
@@ -791,3 +791,39 @@ def test_get_multiple_files_with_same_name(tmp_path, conn_cnx, caplog):
# This is expected flakiness
pass
assert "Downloading multiple files with the same name" in caplog.text


@pytest.mark.skipolddriver
def test_put_md5(tmp_path, conn_cnx):
"""This test uploads a single and a multi part file and makes sure that md5 is populated."""
# Generate random files and folders
small_folder = tmp_path / "small"
big_folder = tmp_path / "big"
small_folder.mkdir()
big_folder.mkdir()
generate_k_lines_of_n_files(3, 1, tmp_dir=str(small_folder))
    # This generates a roughly 342 MB file; we want the file big enough to trigger a multipart upload
generate_k_lines_of_n_files(3_000_000, 1, tmp_dir=str(big_folder))

small_test_file = small_folder / "file0"
big_test_file = big_folder / "file0"

stage_name = random_string(5, "test_put_md5_")
with conn_cnx() as cnx:
with cnx.cursor() as cur:
cur.execute(f"create temporary stage {stage_name}")
small_filename_in_put = str(small_test_file).replace("\\", "/")
big_filename_in_put = str(big_test_file).replace("\\", "/")
cur.execute(
f"PUT 'file://{small_filename_in_put}' @{stage_name}/small AUTO_COMPRESS = FALSE"
)
cur.execute(
f"PUT 'file://{big_filename_in_put}' @{stage_name}/big AUTO_COMPRESS = FALSE"
)

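            # Column order of LS output is assumed to be (name, size, md5,
            # last_modified), so a non-None e[2] means an MD5 was recorded
            # for every uploaded file, including the multipart one.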
assert all(
map(
lambda e: e[2] is not None,
cur.execute(f"LS @{stage_name}").fetchall(),
)
)
