From df3fb3a1e96a8ea3be6b0bdbbad9005e18ba30ee Mon Sep 17 00:00:00 2001
From: Jules Bertrand <33326907+julesbertrand@users.noreply.github.com>
Date: Mon, 26 Feb 2024 18:03:51 +0100
Subject: [PATCH] fix: init command failing (#147)

---
 deployer/cli.py                  | 25 +++++++-------
 deployer/utils/config.py         |  1 -
 deployer/utils/console.py        | 18 ++++------
 tests/unit_tests/test_console.py | 58 +++++++++++++++++---------------
 4 files changed, 50 insertions(+), 52 deletions(-)

diff --git a/deployer/cli.py b/deployer/cli.py
index 5b9aac8..f68a704 100644
--- a/deployer/cli.py
+++ b/deployer/cli.py
@@ -8,7 +8,7 @@
 import typer
 from loguru import logger
 from pydantic import ValidationError
-from rich.prompt import Prompt
+from rich.prompt import Confirm, Prompt
 from typing_extensions import Annotated
 
 from deployer import constants
@@ -493,7 +493,9 @@ def create_pipeline(
     if existing_pipelines:
         raise typer.BadParameter(f"Pipelines {existing_pipelines} already exist.")
 
-    logger.info(f"Creating pipeline {pipeline_names} with config type {config_type}")
+    console.print(
+        f"Creating pipeline {pipeline_names} with config type: [bold]{config_type}[/bold]"
+    )
 
     for pipeline_name in pipeline_names:
         pipeline_filepath = deployer_settings.pipelines_root_path / f"{pipeline_name}.py"
@@ -514,7 +516,11 @@ def create_pipeline(
             pipeline_filepath.unlink()
             raise e
 
-        logger.success(f"Pipeline {pipeline_name} created with configs in {config_dirpath}")
+        console.print(
+            f"Pipeline '{pipeline_name}' created at '{pipeline_filepath}'"
+            f" with config files: {[str(p) for p in config_dirpath.glob('*')]}. :sparkles:",
+            style="blue",
+        )
 
 
 @app.command(name="init")
@@ -524,7 +530,7 @@ def init_deployer(ctx: typer.Context):  # noqa: C901
     console.print("Welcome to Vertex Deployer!", style="blue")
     console.print("This command will help you getting fired up.", style="blue")
 
-    if Prompt.ask("Do you want to configure the deployer?", choices=["y", "n"]) == "y":
+    if Confirm.ask("Do you want to configure the deployer?"):
         pyproject_toml_filepath = find_pyproject_toml(Path.cwd().resolve())
 
         if pyproject_toml_filepath is None:
@@ -542,7 +548,7 @@ def init_deployer(ctx: typer.Context):  # noqa: C901
         update_pyproject_toml(pyproject_toml_filepath, new_deployer_settings)
         console.print("Configuration saved in pyproject.toml :sparkles:", style="blue")
 
-    if Prompt.ask("Do you want to build default folder structure", choices=["y", "n"]) == "y":
+    if Confirm.ask("Do you want to build default folder structure"):
 
         def create_file_or_dir(path: Path, text: str = ""):
             """Create a file (if text is provided) or a directory at path. Warn if path exists."""
@@ -563,14 +569,13 @@ def create_file_or_dir(path: Path, text: str = ""):
             Path("./.env"), "=\n".join(VertexPipelinesSettings.model_json_schema()["required"])
         )
 
-    if Prompt.ask("Do you want to create a pipeline?", choices=["y", "n"]) == "y":
+    if Confirm.ask("Do you want to create a pipeline?"):
         wrong_name = True
         while wrong_name:
             pipeline_name = Prompt.ask("What is the name of the pipeline?")
-            pipeline_path = Path(deployer_settings.pipelines_root_path) / f"{pipeline_name}.py"
 
             try:
-                create_pipeline(pipeline_name=pipeline_name)
+                create_pipeline(ctx, pipeline_names=[pipeline_name])
             except typer.BadParameter as e:
                 console.print(e, style="red")
             except FileExistsError:
@@ -580,10 +585,6 @@ def create_file_or_dir(path: Path, text: str = ""):
                 )
             else:
                 wrong_name = False
-                console.print(
-                    f"Pipeline '{pipeline_name}' created at '{pipeline_path}'. :sparkles:",
-                    style="blue",
-                )
 
     console.print("All done :sparkles:", style="blue")
 
diff --git a/deployer/utils/config.py b/deployer/utils/config.py
index 0ebb4ef..cba1495 100644
--- a/deployer/utils/config.py
+++ b/deployer/utils/config.py
@@ -32,7 +32,6 @@ def load_vertex_settings(env_file: Optional[Path] = None) -> VertexPipelinesSett
     """Load the settings from the environment."""
     try:
         settings = VertexPipelinesSettings(_env_file=env_file, _env_file_encoding="utf-8")
-        print(settings)
     except ValidationError as e:
         msg = "Validation failed for VertexPipelinesSettings. "
         if env_file is not None:
diff --git a/deployer/utils/console.py b/deployer/utils/console.py
index 029bd97..eeda8c9 100644
--- a/deployer/utils/console.py
+++ b/deployer/utils/console.py
@@ -4,7 +4,7 @@
 
 from pydantic import BaseModel
 from rich.console import Console
-from rich.prompt import Prompt
+from rich.prompt import Confirm, Prompt
 
 console = Console()
 
@@ -21,10 +21,8 @@ def ask_user_for_model_fields(model: Type[BaseModel]) -> dict:
     set_fields = {}
     for field_name, field_info in model.model_fields.items():
         if isclass(field_info.annotation) and issubclass(field_info.annotation, BaseModel):
-            answer = Prompt.ask(
-                f"Do you want to configure command {field_name}?", choices=["y", "n"], default="n"
-            )
-            if answer == "y":
+            answer = Confirm.ask(f"Do you want to configure command {field_name}?", default=False)
+            if answer:
                 set_fields[field_name] = ask_user_for_model_fields(field_info.annotation)
 
         else:
@@ -36,13 +34,9 @@ def ask_user_for_model_fields(model: Type[BaseModel]) -> dict:
                 choices = list(annotation.__members__)
 
             if isclass(annotation) and annotation == bool:
-                choices = ["y", "n"]
-                default = "y" if field_info.default else "n"
-
-            answer = Prompt.ask(field_name, default=default, choices=choices)
-
-            if isclass(annotation) and annotation == bool:
-                answer = answer == "y"
+                answer = Confirm.ask(field_name, default=default)
+            else:
+                answer = Prompt.ask(field_name, default=default, choices=choices)
 
             if answer != field_info.default:
                 set_fields[field_name] = answer
diff --git a/tests/unit_tests/test_console.py b/tests/unit_tests/test_console.py
index 06f20c5..7ab4333 100644
--- a/tests/unit_tests/test_console.py
+++ b/tests/unit_tests/test_console.py
@@ -8,36 +8,38 @@
 
 class TestAskUserForModelFields:
     def test_ask_user_for_input_for_each_field(self):
-        # Arrange
+        # Given
         class TestModel(BaseModel):
             field1: str
             field2: int
             field3: bool = False
 
-        # Act
+        # When
         with patch("rich.prompt.Prompt.ask") as mock_prompt:
-            mock_prompt.side_effect = ["value1", 2, "y"]
-            result = ask_user_for_model_fields(TestModel)
+            with patch("rich.prompt.Confirm.ask") as mock_confirm:
+                mock_confirm.side_effect = [True]
+                mock_prompt.side_effect = ["value1", 2]
+                result = ask_user_for_model_fields(TestModel)
 
-        # Assert
+        # Then
        assert result == {"field1": "value1", "field2": 2, "field3": True}
 
     def test_ask_user_with_boolean_fields(self):
-        # Arrange
+        # Given
         class TestModel(BaseModel):
             field1: bool
             field2: bool = True
 
-        # Act
-        with patch("rich.prompt.Prompt.ask") as mock_prompt:
-            mock_prompt.side_effect = ["y", "n"]
+        # When
+        with patch("rich.prompt.Confirm.ask") as mock_confirm:
+            mock_confirm.side_effect = [True, False]
             result = ask_user_for_model_fields(TestModel)
 
-        # Assert
+        # Then
         assert result == {"field1": True, "field2": False}
 
     def test_ask_user_with_enum_fields(self):
-        # Arrange
+        # Given
         class TestEnum(Enum):
             OPTION1 = "Option 1"
             OPTION2 = "Option 2"
@@ -46,16 +48,16 @@ class TestEnum(Enum):
         class TestModel(BaseModel):
             field1: TestEnum
 
-        # Act
+        # When
         with patch("rich.prompt.Prompt.ask") as mock_prompt:
             mock_prompt.side_effect = ["Option 2"]
             result = ask_user_for_model_fields(TestModel)
 
-        # Assert
+        # Then
         assert result == {"field1": "Option 2"}
 
     def test_ask_user_with_nested_models(self):
-        # Arrange
+        # Given
         class NestedModel(BaseModel):
             nested_field1: str
             nested_field2: int
@@ -63,47 +65,49 @@ class NestedModel(BaseModel):
         class TestModel(BaseModel):
             field1: NestedModel
 
-        # Act
+        # When
         with patch("rich.prompt.Prompt.ask") as mock_prompt:
-            mock_prompt.side_effect = ["y", "value1", 2]
-            result = ask_user_for_model_fields(TestModel)
+            with patch("rich.prompt.Confirm.ask") as mock_confirm:
+                mock_confirm.side_effect = [True]
+                mock_prompt.side_effect = ["value1", 2]
+                result = ask_user_for_model_fields(TestModel)
 
-        # Assert
+        # Then
         assert result == {"field1": {"nested_field1": "value1", "nested_field2": 2}}
 
     def test_ask_user_with_no_fields(self):
-        # Arrange
+        # Given
         class TestModel(BaseModel):
             pass
 
-        # Act
+        # When
         result = ask_user_for_model_fields(TestModel)
 
-        # Assert
+        # Then
         assert result == {}
 
     def test_ask_user_with_no_default_value_and_no_valid_choices(self):
-        # Arrange
+        # Given
         class TestModel(BaseModel):
             field1: str
 
-        # Act
+        # When
         with patch("rich.prompt.Prompt.ask") as mock_prompt:
             mock_prompt.side_effect = ["value1"]
             result = ask_user_for_model_fields(TestModel)
 
-        # Assert
+        # Then
         assert result == {"field1": "value1"}
 
     def test_ask_user_with_default_value_and_no_valid_choices(self):
-        # Arrange
+        # Given
         class TestModel(BaseModel):
             field1: str = "default"
 
-        # Act
+        # When
         with patch("rich.prompt.Prompt.ask") as mock_prompt:
             mock_prompt.side_effect = ["value1"]
             result = ask_user_for_model_fields(TestModel)
 
-        # Assert
+        # Then
         assert result == {"field1": "value1"}
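
Note on the pattern this patch adopts: yes/no questions move from Prompt.ask(..., choices=["y", "n"]) == "y" to Confirm.ask(...), which returns a bool directly, while Prompt.ask keeps handling string and enumerated answers. A minimal standalone sketch of that pattern, using only the public rich.prompt API (the helper names below are illustrative, not code from this repository):

    from typing import List

    from rich.prompt import Confirm, Prompt


    def ask_yes_no(question: str, default: bool = False) -> bool:
        # Confirm.ask returns a bool, so no "y"/"n" string comparison is needed.
        return Confirm.ask(question, default=default)


    def ask_choice(question: str, choices: List[str], default: str) -> str:
        # Prompt.ask returns a str and can restrict answers to a list of choices.
        return Prompt.ask(question, choices=choices, default=default)

In tests, each prompt type is then mocked separately, e.g. patch("rich.prompt.Confirm.ask") with boolean side effects and patch("rich.prompt.Prompt.ask") with string side effects, mirroring the updated tests/unit_tests/test_console.py above.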