fix: init command failing (#147)
julesbertrand authored Feb 26, 2024
1 parent 5a77249 commit df3fb3a
Showing 4 changed files with 50 additions and 52 deletions.
deployer/cli.py (25 changes: 13 additions & 12 deletions)
@@ -8,7 +8,7 @@
 import typer
 from loguru import logger
 from pydantic import ValidationError
-from rich.prompt import Prompt
+from rich.prompt import Confirm, Prompt
 from typing_extensions import Annotated
 
 from deployer import constants
@@ -493,7 +493,9 @@ def create_pipeline(
     if existing_pipelines:
         raise typer.BadParameter(f"Pipelines {existing_pipelines} already exist.")
 
-    logger.info(f"Creating pipeline {pipeline_names} with config type {config_type}")
+    console.print(
+        f"Creating pipeline {pipeline_names} with config type: [bold]{config_type}[/bold]"
+    )
 
     for pipeline_name in pipeline_names:
         pipeline_filepath = deployer_settings.pipelines_root_path / f"{pipeline_name}.py"
@@ -514,7 +516,11 @@
             pipeline_filepath.unlink()
             raise e
 
-        logger.success(f"Pipeline {pipeline_name} created with configs in {config_dirpath}")
+        console.print(
+            f"Pipeline '{pipeline_name}' created at '{pipeline_filepath}'"
+            f" with config files: {[str(p) for p in config_dirpath.glob('*')]}. :sparkles:",
+            style="blue",
+        )
 
 
 @app.command(name="init")
@@ -524,7 +530,7 @@ def init_deployer(ctx: typer.Context): # noqa: C901
     console.print("Welcome to Vertex Deployer!", style="blue")
     console.print("This command will help you getting fired up.", style="blue")
 
-    if Prompt.ask("Do you want to configure the deployer?", choices=["y", "n"]) == "y":
+    if Confirm.ask("Do you want to configure the deployer?"):
         pyproject_toml_filepath = find_pyproject_toml(Path.cwd().resolve())
 
         if pyproject_toml_filepath is None:
@@ -542,7 +548,7 @@
         update_pyproject_toml(pyproject_toml_filepath, new_deployer_settings)
         console.print("Configuration saved in pyproject.toml :sparkles:", style="blue")
 
-    if Prompt.ask("Do you want to build default folder structure", choices=["y", "n"]) == "y":
+    if Confirm.ask("Do you want to build default folder structure"):
 
         def create_file_or_dir(path: Path, text: str = ""):
             """Create a file (if text is provided) or a directory at path. Warn if path exists."""
@@ -563,14 +569,13 @@ def create_file_or_dir(path: Path, text: str = ""):
             Path("./.env"), "=\n".join(VertexPipelinesSettings.model_json_schema()["required"])
         )
 
-    if Prompt.ask("Do you want to create a pipeline?", choices=["y", "n"]) == "y":
+    if Confirm.ask("Do you want to create a pipeline?"):
         wrong_name = True
         while wrong_name:
             pipeline_name = Prompt.ask("What is the name of the pipeline?")
-            pipeline_path = Path(deployer_settings.pipelines_root_path) / f"{pipeline_name}.py"
 
             try:
-                create_pipeline(pipeline_name=pipeline_name)
+                create_pipeline(ctx, pipeline_names=[pipeline_name])
             except typer.BadParameter as e:
                 console.print(e, style="red")
             except FileExistsError:
@@ -580,10 +585,6 @@ def create_file_or_dir(path: Path, text: str = ""):
                 )
             else:
                 wrong_name = False
-                console.print(
-                    f"Pipeline '{pipeline_name}' created at '{pipeline_path}'. :sparkles:",
-                    style="blue",
-                )
 
     console.print("All done :sparkles:", style="blue")

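The init failure itself comes from the call fixed in the last hunk of cli.py: init_deployer was invoking create_pipeline(pipeline_name=pipeline_name), a keyword the command no longer accepts. Below is a hedged sketch of the calling convention the fix relies on; only the parameter names ctx and pipeline_names come from the diff, while the decorator setup, command names, and the example pipeline name are illustrative.

from typing import List

import typer

app = typer.Typer()


@app.command(name="create")
def create_pipeline(ctx: typer.Context, pipeline_names: List[str]):
    """Create one pipeline file (plus its config folder) per name."""
    for pipeline_name in pipeline_names:
        ...  # scaffold pipeline_name.py and its config directory


@app.command(name="init")
def init_deployer(ctx: typer.Context):
    """Interactive setup that reuses create_pipeline as a plain function."""
    # A Typer command is still an ordinary Python function, so calling it
    # directly means matching its current signature: the context object plus
    # a list of names, not the old single `pipeline_name=` keyword.
    create_pipeline(ctx, pipeline_names=["my_pipeline"])

Because the command is reused as a plain function, the success message that init_deployer used to print can move into create_pipeline itself, which is what the earlier console.print hunk does.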
deployer/utils/config.py (1 change: 0 additions & 1 deletion)
@@ -32,7 +32,6 @@ def load_vertex_settings(env_file: Optional[Path] = None) -> VertexPipelinesSettings:
     """Load the settings from the environment."""
     try:
         settings = VertexPipelinesSettings(_env_file=env_file, _env_file_encoding="utf-8")
-        print(settings)
     except ValidationError as e:
         msg = "Validation failed for VertexPipelinesSettings. "
         if env_file is not None:
deployer/utils/console.py (18 changes: 6 additions & 12 deletions)
@@ -4,7 +4,7 @@
 
 from pydantic import BaseModel
 from rich.console import Console
-from rich.prompt import Prompt
+from rich.prompt import Confirm, Prompt
 
 console = Console()
 
@@ -21,10 +21,8 @@ def ask_user_for_model_fields(model: Type[BaseModel]) -> dict:
     set_fields = {}
     for field_name, field_info in model.model_fields.items():
         if isclass(field_info.annotation) and issubclass(field_info.annotation, BaseModel):
-            answer = Prompt.ask(
-                f"Do you want to configure command {field_name}?", choices=["y", "n"], default="n"
-            )
-            if answer == "y":
+            answer = Confirm.ask(f"Do you want to configure command {field_name}?", default=False)
+            if answer:
                 set_fields[field_name] = ask_user_for_model_fields(field_info.annotation)
 
         else:
@@ -36,13 +34,9 @@ def ask_user_for_model_fields(model: Type[BaseModel]) -> dict:
                 choices = list(annotation.__members__)
 
             if isclass(annotation) and annotation == bool:
-                choices = ["y", "n"]
-                default = "y" if field_info.default else "n"
-
-            answer = Prompt.ask(field_name, default=default, choices=choices)
-
-            if isclass(annotation) and annotation == bool:
-                answer = answer == "y"
+                answer = Confirm.ask(field_name, default=default)
+            else:
+                answer = Prompt.ask(field_name, default=default, choices=choices)
 
             if answer != field_info.default:
                 set_fields[field_name] = answer
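The refactor above collapses the old three-step y/n handling (build "y"/"n" choices, prompt, translate "y" back to a bool) into a single branch on the field type: booleans go through Confirm.ask, which returns a bool directly, and everything else keeps using Prompt.ask. A minimal standalone sketch of that branch; the helper name and its parameters are hypothetical, not part of the repository.

from inspect import isclass
from typing import Any, Optional

from rich.prompt import Confirm, Prompt


def ask_for_field(field_name: str, annotation: type, default: Any,
                  choices: Optional[list] = None) -> Any:
    """Prompt for one field value, using Confirm for booleans."""
    if isclass(annotation) and annotation is bool:
        # Confirm.ask returns True/False directly; `default` is assumed to
        # already be a bool, so no "y"/"n" string translation is needed.
        return Confirm.ask(field_name, default=default)
    # Non-boolean fields: free-form or choice-constrained prompt.
    return Prompt.ask(field_name, default=default, choices=choices)

Called as ask_for_field("force", bool, default=False), this renders a y/n confirmation and hands back a bool, which is the same split the diff applies inside ask_user_for_model_fields.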
tests/unit_tests/test_console.py (58 changes: 31 additions & 27 deletions)
@@ -8,36 +8,38 @@
 
 class TestAskUserForModelFields:
     def test_ask_user_for_input_for_each_field(self):
-        # Arrange
+        # Given
         class TestModel(BaseModel):
             field1: str
             field2: int
             field3: bool = False
 
-        # Act
+        # When
         with patch("rich.prompt.Prompt.ask") as mock_prompt:
-            mock_prompt.side_effect = ["value1", 2, "y"]
-            result = ask_user_for_model_fields(TestModel)
+            with patch("rich.prompt.Confirm.ask") as mock_confirm:
+                mock_confirm.side_effect = [True]
+                mock_prompt.side_effect = ["value1", 2]
+                result = ask_user_for_model_fields(TestModel)
 
-        # Assert
+        # Then
         assert result == {"field1": "value1", "field2": 2, "field3": True}
 
     def test_ask_user_with_boolean_fields(self):
-        # Arrange
+        # Given
         class TestModel(BaseModel):
             field1: bool
             field2: bool = True
 
-        # Act
-        with patch("rich.prompt.Prompt.ask") as mock_prompt:
-            mock_prompt.side_effect = ["y", "n"]
+        # When
+        with patch("rich.prompt.Confirm.ask") as mock_confirm:
+            mock_confirm.side_effect = [True, False]
             result = ask_user_for_model_fields(TestModel)
 
-        # Assert
+        # Then
         assert result == {"field1": True, "field2": False}
 
     def test_ask_user_with_enum_fields(self):
-        # Arrange
+        # Given
         class TestEnum(Enum):
             OPTION1 = "Option 1"
             OPTION2 = "Option 2"
@@ -46,64 +48,66 @@ class TestEnum(Enum):
         class TestModel(BaseModel):
             field1: TestEnum
 
-        # Act
+        # When
         with patch("rich.prompt.Prompt.ask") as mock_prompt:
             mock_prompt.side_effect = ["Option 2"]
             result = ask_user_for_model_fields(TestModel)
 
-        # Assert
+        # Then
         assert result == {"field1": "Option 2"}
 
     def test_ask_user_with_nested_models(self):
-        # Arrange
+        # Given
         class NestedModel(BaseModel):
             nested_field1: str
             nested_field2: int
 
         class TestModel(BaseModel):
             field1: NestedModel
 
-        # Act
+        # When
         with patch("rich.prompt.Prompt.ask") as mock_prompt:
-            mock_prompt.side_effect = ["y", "value1", 2]
-            result = ask_user_for_model_fields(TestModel)
+            with patch("rich.prompt.Confirm.ask") as mock_confirm:
+                mock_confirm.side_effect = [True]
+                mock_prompt.side_effect = ["value1", 2]
+                result = ask_user_for_model_fields(TestModel)
 
-        # Assert
+        # Then
         assert result == {"field1": {"nested_field1": "value1", "nested_field2": 2}}
 
     def test_ask_user_with_no_fields(self):
-        # Arrange
+        # Given
         class TestModel(BaseModel):
             pass
 
-        # Act
+        # When
         result = ask_user_for_model_fields(TestModel)
 
-        # Assert
+        # Then
         assert result == {}
 
     def test_ask_user_with_no_default_value_and_no_valid_choices(self):
-        # Arrange
+        # Given
         class TestModel(BaseModel):
             field1: str
 
-        # Act
+        # When
         with patch("rich.prompt.Prompt.ask") as mock_prompt:
             mock_prompt.side_effect = ["value1"]
             result = ask_user_for_model_fields(TestModel)
 
-        # Assert
+        # Then
         assert result == {"field1": "value1"}
 
     def test_ask_user_with_default_value_and_no_valid_choices(self):
-        # Arrange
+        # Given
        class TestModel(BaseModel):
            field1: str = "default"
 
-        # Act
+        # When
        with patch("rich.prompt.Prompt.ask") as mock_prompt:
            mock_prompt.side_effect = ["value1"]
            result = ask_user_for_model_fields(TestModel)
 
-        # Assert
+        # Then
        assert result == {"field1": "value1"}

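Because boolean answers no longer flow through Prompt, the updated tests patch Confirm.ask and Prompt.ask independently and feed them separate side_effect lists. A hedged, self-contained illustration of the same pattern; the import path mirrors the modules in this diff, while the model and the values are made up for the example.

from unittest.mock import patch

from pydantic import BaseModel

from deployer.utils.console import ask_user_for_model_fields


class DemoModel(BaseModel):
    """Hypothetical model used only for this example."""

    name: str
    enabled: bool = False


def test_bool_and_str_fields_use_separate_prompts():
    # Given a model with one string field and one boolean field,
    # when string answers come from Prompt and booleans from Confirm
    with patch("rich.prompt.Prompt.ask") as mock_prompt:
        with patch("rich.prompt.Confirm.ask") as mock_confirm:
            mock_prompt.side_effect = ["demo"]
            mock_confirm.side_effect = [True]
            result = ask_user_for_model_fields(DemoModel)

    # Then both answers land in the returned dict
    assert result == {"name": "demo", "enabled": True}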