Skip to content

Commit

Permalink
feat(weave): better debug HTTP logging with image payloads
Browse files Browse the repository at this point in the history
  • Loading branch information
jamie-rasmussen committed Dec 13, 2024
1 parent c964e3f commit ff3530a
Showing 1 changed file with 78 additions and 9 deletions.
87 changes: 78 additions & 9 deletions weave/trace_server/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,65 @@ def pprint_json(text: str) -> None:
console.print(Text(text, style=STYLE_ERROR))


def pprint_prepared_request(prepared_request: PreparedRequest) -> None:
class FileMagic:
"""Utility class for guessing file types.
imghdr is deprecated and we don't want to add a new dependency.
"""

mimetype: str
metadata: dict[str, Any]

@property
def is_image(self) -> bool:
return self.mimetype.startswith("image/")

def __repr__(self) -> str:
return f"<{self.mimetype} metadata={self.metadata}>"

@classmethod
def analyze(cls, data: bytes) -> "FileMagic":
result = cls()
result.mimetype = "application/octet-stream"
result.metadata = {
"length": len(data),
}

# TODO: Support other file type magic numbers as useful
# TODO: Might be nice to extract image dimensions into metadata
data_prefix = data[:8]
if data_prefix.startswith(b"\x89PNG\r\n\x1a\n"):
result.mimetype = "image/png"
elif data_prefix.startswith(b"\xff\xd8\xff"):
result.mimetype = "image/jpeg"
elif data_prefix.startswith(b"GIF87a") or data_prefix.startswith(b"GIF89a"):
result.mimetype = "image/gif"
return result


MAX_SHORT_RESPONSE_LENGTH = 1000


def get_truncated_str(s: str, max_length: int) -> str:
"""Truncate a string to a maximum length."""
n = len(s)
if n > max_length:
return f"{s[:max_length]}... (truncated, total length: {n})"
return s


def get_short_response(response_content: bytes) -> str:
"""Avoid printing extremely long strings."""
file_magic = FileMagic.analyze(response_content)
if file_magic.is_image:
return f"{file_magic}"
decoded = response_content.decode("utf-8")
return get_truncated_str(decoded, MAX_SHORT_RESPONSE_LENGTH)


def pprint_prepared_request(
prepared_request: PreparedRequest, verbose: bool = False
) -> None:
"""Pretty print a PreparedRequest."""
time_text = Text(
datetime.datetime.now().strftime("%H:%M:%S.%f"), style=STYLE_METADATA
Expand All @@ -95,7 +153,13 @@ def pprint_prepared_request(prepared_request: PreparedRequest) -> None:
if content_type == "application/json":
pprint_json(decode_str(prepared_request.body))
elif content_type and content_type.startswith("multipart/form-data"):
console.print(Text(decode_str(prepared_request.body), style=STYLE_BODY))
output = decode_str(prepared_request.body)
output = (
output
if verbose
else get_truncated_str(output, MAX_SHORT_RESPONSE_LENGTH)
)
console.print(Text(output, style=STYLE_BODY))
elif isinstance(prepared_request.body, str):
console.print(f"{prepared_request.body}", style=STYLE_BODY)
else:
Expand All @@ -105,7 +169,7 @@ def pprint_prepared_request(prepared_request: PreparedRequest) -> None:
console.print(Text(" None", style=STYLE_NONE))


def pprint_response(response: Response) -> None:
def pprint_response(response: Response, verbose: bool = False) -> None:
"""Pretty print a Response."""
status_style = STYLE_STATUS_OTHER
if 200 <= response.status_code < 300:
Expand All @@ -123,18 +187,23 @@ def pprint_response(response: Response) -> None:
console.print(Text("Body:", style=STYLE_LABEL))
if response.headers.get("Content-Type") == "application/json":
pprint_json(response.text)
elif response.text:
console.print(Text(response.text, style=STYLE_BODY))
elif response.content:
output = response.text if verbose else get_short_response(response.content)
console.print(Text(output, style=STYLE_BODY))
else:
console.print(" None", style=STYLE_NONE)


class LoggingHTTPAdapter(HTTPAdapter):
def __init__(self, verbose: bool = False) -> None:
super().__init__()
self.verbose = verbose

# Actual signature is:
# self, request, stream=False, timeout=None, verify=True, cert=None, proxies=None
def send(self, request: PreparedRequest, **kwargs: Any) -> Response: # type: ignore
console.print(Text("-" * 21, style=STYLE_DIVIDER_REQUEST))
pprint_prepared_request(request)
pprint_prepared_request(request, self.verbose)
start_time = time()
response = super().send(request, **kwargs)
elapsed_time = time() - start_time
Expand All @@ -143,13 +212,13 @@ def send(self, request: PreparedRequest, **kwargs: Any) -> Response: # type: ig
Text("Elapsed Time: ", style=STYLE_LABEL)
+ Text(f"{elapsed_time:.2f} seconds", style=STYLE_METADATA)
)
pprint_response(response)
pprint_response(response, self.verbose)
return response


session = Session()
if os.environ.get("WEAVE_DEBUG_HTTP") == "1":
adapter = LoggingHTTPAdapter()
if os.environ.get("WEAVE_DEBUG_HTTP") in ("1", "2"):
adapter = LoggingHTTPAdapter(os.environ.get("WEAVE_DEBUG_HTTP") == "2")
session.mount("http://", adapter)
session.mount("https://", adapter)

Expand Down

0 comments on commit ff3530a

Please sign in to comment.