Skip to content

Commit

Permalink
added V3 configuration
Browse files Browse the repository at this point in the history
  • Loading branch information
ersi-lightly committed Nov 17, 2023
1 parent 27794d0 commit 6de5dee
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 15 deletions.
22 changes: 11 additions & 11 deletions lightly/api/api_workflow_compute_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@
DockerWorkerConfigV3Lightly,
DockerWorkerRegistryEntryData,
DockerWorkerType,
SelectionConfig,
SelectionConfigEntry,
SelectionConfigEntryInput,
SelectionConfigEntryStrategy,
SelectionConfigV3,
SelectionConfigV3Entry,
SelectionConfigV3EntryInput,
SelectionConfigV3EntryStrategy,
TagData,
)
from lightly.openapi_generated.swagger_client.rest import ApiException
Expand Down Expand Up @@ -175,7 +175,7 @@ def create_compute_worker_config(
self,
worker_config: Optional[Dict[str, Any]] = None,
lightly_config: Optional[Dict[str, Any]] = None,
selection_config: Optional[Union[Dict[str, Any], SelectionConfig]] = None,
selection_config: Optional[Union[Dict[str, Any], SelectionConfigV3]] = None,
) -> str:
"""Creates a new configuration for a Lightly Worker run.
Expand Down Expand Up @@ -269,7 +269,7 @@ def schedule_compute_worker_run(
self,
worker_config: Optional[Dict[str, Any]] = None,
lightly_config: Optional[Dict[str, Any]] = None,
selection_config: Optional[Union[Dict[str, Any], SelectionConfig]] = None,
selection_config: Optional[Union[Dict[str, Any], SelectionConfigV3]] = None,
priority: str = DockerRunScheduledPriority.MID,
runs_on: Optional[List[str]] = None,
) -> str:
Expand Down Expand Up @@ -634,17 +634,17 @@ def get_compute_worker_run_tags(self, run_id: str) -> List[TagData]:
return tags_in_dataset


def selection_config_from_dict(cfg: Dict[str, Any]) -> SelectionConfig:
def selection_config_from_dict(cfg: Dict[str, Any]) -> SelectionConfigV3:
"""Recursively converts selection config from dict to a SelectionConfig instance."""
strategies = []
for entry in cfg.get("strategies", []):
new_entry = copy.deepcopy(entry)
new_entry["input"] = SelectionConfigEntryInput(**entry["input"])
new_entry["strategy"] = SelectionConfigEntryStrategy(**entry["strategy"])
strategies.append(SelectionConfigEntry(**new_entry))
new_entry["input"] = SelectionConfigV3EntryInput(**entry["input"])
new_entry["strategy"] = SelectionConfigV3EntryStrategy(**entry["strategy"])
strategies.append(SelectionConfigV3Entry(**new_entry))
new_cfg = copy.deepcopy(cfg)
new_cfg["strategies"] = strategies
return SelectionConfig(**new_cfg)
return SelectionConfigV3(**new_cfg)


_T = TypeVar("_T")
Expand Down
20 changes: 16 additions & 4 deletions tests/api_workflow/test_api_workflow_compute_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,18 @@
SelectionStrategyType,
TagData,
)
from lightly.openapi_generated.swagger_client.models.selection_config_v3 import (
SelectionConfigV3,
)
from lightly.openapi_generated.swagger_client.models.selection_config_v3_entry import (
SelectionConfigV3Entry,
)
from lightly.openapi_generated.swagger_client.models.selection_config_v3_entry_input import (
SelectionConfigV3EntryInput,
)
from lightly.openapi_generated.swagger_client.models.selection_config_v3_entry_strategy import (
SelectionConfigV3EntryStrategy,
)
from lightly.openapi_generated.swagger_client.rest import ApiException
from tests.api_workflow import utils
from tests.api_workflow.mocked_api_workflow_client import MockedApiWorkflowSetup
Expand Down Expand Up @@ -101,16 +113,16 @@ def test_create_compute_worker_config__selection_config_is_class(self) -> None:
"batch_size": 64,
},
},
selection_config=SelectionConfig(
selection_config=SelectionConfigV3(
n_samples=20,
strategies=[
SelectionConfigEntry(
input=SelectionConfigEntryInput(
SelectionConfigV3Entry(
input=SelectionConfigV3EntryInput(
type=SelectionInputType.EMBEDDINGS,
dataset_id=utils.generate_id(),
tag_name="some-tag-name",
),
strategy=SelectionConfigEntryStrategy(
strategy=SelectionConfigV3EntryStrategy(
type=SelectionStrategyType.SIMILARITY,
),
)
Expand Down

0 comments on commit 6de5dee

Please sign in to comment.