Skip to content

Commit

Permalink
[BUG] Always run telemetry codepath (#3275)
Browse files Browse the repository at this point in the history
Adds tests for #3270

---------

Co-authored-by: Jay Chia <[email protected]@users.noreply.github.com>
Co-authored-by: Sammy Sidhu <[email protected]>
  • Loading branch information
3 people authored Nov 13, 2024
1 parent 0d2bb2a commit 2c59675
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 10 deletions.
6 changes: 2 additions & 4 deletions daft/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,9 @@ def refresh_logger() -> None:

from daft.analytics import init_analytics

dev_build = get_build_type() == "dev"
user_opted_out = os.getenv("DAFT_ANALYTICS_ENABLED") == "0"
if not dev_build and not user_opted_out:
analytics_client = init_analytics(get_version(), get_build_type())
analytics_client.track_import()
analytics_client = init_analytics(get_version(), get_build_type(), user_opted_out)
analytics_client.track_import()

###
# Daft top-level imports
Expand Down
12 changes: 6 additions & 6 deletions daft/analytics.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,11 @@ def __init__(
self,
daft_version: str,
daft_build_type: str,
enabled: bool,
publish_payload_function: Callable[[AnalyticsClient, dict[str, Any]], None] = _post_segment_track_endpoint,
buffer_capacity: int = 100,
) -> None:
self._is_active = True
self._is_active = enabled
self._daft_version = daft_version
self._daft_build_type = daft_build_type
self._session_key = _get_session_key()
Expand All @@ -105,9 +106,6 @@ def __init__(
self._buffer_capacity = buffer_capacity
self._buffer: list[AnalyticsEvent] = []

def _enable_analytics(self) -> None:
self._is_active = True

def _append_to_log(self, event_name: str, data: dict[str, Any]) -> None:
if self._is_active:
self._buffer.append(
Expand Down Expand Up @@ -170,18 +168,20 @@ def track_fn_call(self, fn_name: str, duration_seconds: float, error: str | None
)


def init_analytics(daft_version: str, daft_build_type: str) -> AnalyticsClient:
def init_analytics(daft_version: str, daft_build_type: str, user_opted_out: bool) -> AnalyticsClient:
"""Initialize the analytics module
Returns:
AnalyticsClient: initialized singleton AnalyticsClient
"""
enabled = (not user_opted_out) and daft_build_type != "dev"

global _ANALYTICS_CLIENT

if _ANALYTICS_CLIENT is not None:
return _ANALYTICS_CLIENT

_ANALYTICS_CLIENT = AnalyticsClient(daft_version, daft_build_type)
_ANALYTICS_CLIENT = AnalyticsClient(daft_version, daft_build_type, enabled)
atexit.register(_ANALYTICS_CLIENT._flush)
return _ANALYTICS_CLIENT

Expand Down
19 changes: 19 additions & 0 deletions tests/test_analytics.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def mock_analytics() -> tuple[AnalyticsClient, MagicMock]:
client = AnalyticsClient(
daft.get_version(),
daft.get_build_type(),
True,
publish_payload_function=mock_publish,
buffer_capacity=1,
)
Expand Down Expand Up @@ -79,6 +80,7 @@ def test_analytics_client_timeout(
analytics_client = AnalyticsClient(
daft.get_version(),
daft.get_build_type(),
True,
buffer_capacity=1,
)

Expand All @@ -96,12 +98,29 @@ def test_analytics_client_timeout_2(
analytics_client = AnalyticsClient(
daft.get_version(),
daft.get_build_type(),
True,
buffer_capacity=1,
)
analytics_client.track_import()
mock_urlopen.assert_called_once()


@patch("urllib.request.urlopen")
def test_analytics_client_disabled(
mock_urlopen: MagicMock,
):
mock_urlopen.side_effect = urllib.error.URLError(socket.timeout("Timeout"))
analytics_client = AnalyticsClient(
daft.get_version(),
daft.get_build_type(),
False,
buffer_capacity=1,
)

analytics_client.track_import()
mock_urlopen.assert_not_called()


@patch("daft.analytics.datetime")
def test_analytics_client_track_dataframe_method(
mock_datetime: MagicMock, mock_analytics: tuple[AnalyticsClient, MagicMock]
Expand Down
13 changes: 13 additions & 0 deletions tests/test_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import subprocess
import sys

script_to_test = """
import daft
print(daft.context.get_context()._runner_config is None)
"""


def test_fresh_context_on_import():
"""Test that a freshly imported context doesn't have a runner config set"""
result = subprocess.run([sys.executable, "-c", script_to_test], capture_output=True)
assert result.stdout.decode().strip() == "True"

0 comments on commit 2c59675

Please sign in to comment.