Skip to content

Commit

Permalink
Create one API client per thread (#135)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
jacobtomlinson and pre-commit-ci[bot] authored Aug 8, 2023
1 parent 2434932 commit 051a526
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 7 deletions.
8 changes: 6 additions & 2 deletions kr8s/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import contextlib
import json
import ssl
import threading
import weakref
from typing import Dict, List, Tuple, Union

Expand All @@ -29,7 +30,7 @@ class Api(object):
"""

_asyncio = True
_instances = weakref.WeakValueDictionary()
_instances = {}

def __init__(self, **kwargs) -> None:
if not kwargs.pop("bypass_factory", False):
Expand All @@ -48,7 +49,10 @@ def __init__(self, **kwargs) -> None:
serviceaccount=self._serviceaccount,
namespace=kwargs.get("namespace"),
)
Api._instances[frozenset(kwargs.items())] = self
thread_id = threading.get_ident()
if thread_id not in Api._instances:
Api._instances[thread_id] = weakref.WeakValueDictionary()
Api._instances[thread_id][frozenset(kwargs.items())] = self

def __await__(self):
async def f():
Expand Down
21 changes: 16 additions & 5 deletions kr8s/asyncio/_api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# SPDX-FileCopyrightText: Copyright (c) 2023, Dask Developers, NVIDIA
# SPDX-License-Identifier: BSD 3-Clause License
import threading

from kr8s._api import Api as _AsyncApi


Expand All @@ -12,7 +14,7 @@ async def api(
) -> _AsyncApi:
"""Create a :class:`kr8s.Api` object for interacting with the Kubernetes API.
If a kr8s object already exists with the same arguments, it will be returned.
If a kr8s object already exists with the same arguments in this thread, it will be returned.
"""

from kr8s import Api as _SyncApi
Expand All @@ -24,10 +26,19 @@ async def api(

async def _f(**kwargs):
key = frozenset(kwargs.items())
if key in _cls._instances:
return await _cls._instances[key]
if all(k is None for k in kwargs.values()) and list(_cls._instances.values()):
return await list(_cls._instances.values())[0]
thread_id = threading.get_ident()
if (
_cls._instances
and thread_id in _cls._instances
and key in _cls._instances[thread_id]
):
return await _cls._instances[thread_id][key]
if (
all(k is None for k in kwargs.values())
and thread_id in _cls._instances
and list(_cls._instances[thread_id].values())
):
return await list(_cls._instances[thread_id].values())[0]
return await _cls(**kwargs, bypass_factory=True)

return await _f(
Expand Down
34 changes: 34 additions & 0 deletions kr8s/tests/test_api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA
# SPDX-License-Identifier: BSD 3-Clause License
import asyncio
import queue
import threading

import anyio
import pytest

import kr8s
Expand Down Expand Up @@ -32,6 +35,37 @@ async def test_api_factory(serviceaccount):
assert p.api is not k3


def test_api_factory_threaded():
assert len(kr8s.Api._instances) == 0

q = queue.Queue()

def run_in_thread(q):
async def create_api(q):
k = await kr8s.asyncio.api()
q.put(k)

anyio.run(create_api, q)

t1 = threading.Thread(
target=run_in_thread,
args=(q,),
)
t2 = threading.Thread(
target=run_in_thread,
args=(q,),
)
t1.start()
t2.start()
t1.join()
t2.join()
k1 = q.get()
k2 = q.get()

assert k1 is not k2
assert type(k1) is type(k2)


async def test_api_factory_with_kubeconfig(k8s_cluster, serviceaccount):
k1 = await kr8s.asyncio.api(kubeconfig=k8s_cluster.kubeconfig_path)
k2 = await kr8s.asyncio.api(serviceaccount=serviceaccount)
Expand Down

0 comments on commit 051a526

Please sign in to comment.