Skip to content

Commit

Permalink
Use profile specified in --profile with dbt init (#7450)
Browse files Browse the repository at this point in the history
* Use profile specified in --profile with dbt init

* Update .changes/unreleased/Fixes-20230424-161642.yaml

Co-authored-by: Doug Beatty <[email protected]>

* Refactor run() method into functions, replace exit() calls with exceptions

* Update help text for profile option

---------

Co-authored-by: Doug Beatty <[email protected]>
  • Loading branch information
ezraerb and dbeatty10 authored Sep 15, 2023
1 parent f52bd92 commit 3f5ebe8
Show file tree
Hide file tree
Showing 4 changed files with 214 additions and 24 deletions.
7 changes: 7 additions & 0 deletions .changes/unreleased/Fixes-20230424-161642.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
kind: Fixes
body: If --profile specified with dbt-init, create the project with the specified
profile
time: 2023-04-24T16:16:42.994547-04:00
custom:
Author: ezraerb
Issue: "6154"
2 changes: 1 addition & 1 deletion core/dbt/cli/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@
profile = click.option(
"--profile",
envvar=None,
help="Which profile to load. Overrides setting in dbt_project.yml.",
help="Which existing profile to load. Overrides setting in dbt_project.yml.",
)

profiles_dir = click.option(
Expand Down
82 changes: 59 additions & 23 deletions core/dbt/task/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

import dbt.config
import dbt.clients.system
from dbt.config.profile import read_profile
from dbt.exceptions import DbtRuntimeError
from dbt.flags import get_flags
from dbt.version import _get_adapter_plugin_names
from dbt.adapters.factory import load_plugin, get_include_paths
Expand Down Expand Up @@ -188,6 +190,15 @@ def create_profile_from_target(self, adapter: str, profile_name: str):
# sample_profiles.yml
self.create_profile_from_sample(adapter, profile_name)

def check_if_profile_exists(self, profile_name: str) -> bool:
"""
Validate that the specified profile exists. Can't use the regular profile validation
routine because it assumes the project file exists
"""
profiles_dir = get_flags().PROFILES_DIR
raw_profiles = read_profile(profiles_dir)
return profile_name in raw_profiles

def check_if_can_write_profile(self, profile_name: Optional[str] = None) -> bool:
"""Using either a provided profile name or that specified in dbt_project.yml,
check if the profile already exists in profiles.yml, and if so ask the
Expand Down Expand Up @@ -233,6 +244,25 @@ def ask_for_adapter_choice(self) -> str:
numeric_choice = click.prompt(prompt_msg, type=click.INT)
return available_adapters[numeric_choice - 1]

def setup_profile(self, profile_name: str) -> None:
"""Set up a new profile for a project"""
fire_event(SettingUpProfile())
if not self.check_if_can_write_profile(profile_name=profile_name):
return
# If a profile_template.yml exists in the project root, that effectively
# overrides the profile_template.yml for the given target.
profile_template_path = Path("profile_template.yml")
if profile_template_path.exists():
try:
# This relies on a valid profile_template.yml from the user,
# so use a try: except to fall back to the default on failure
self.create_profile_using_project_profile_template(profile_name)
return
except Exception:
fire_event(InvalidProfileTemplateYAML())
adapter = self.ask_for_adapter_choice()
self.create_profile_from_target(adapter, profile_name=profile_name)

def get_valid_project_name(self) -> str:
"""Returns a valid project name, either from CLI arg or user prompt."""
name = self.args.project_name
Expand All @@ -247,11 +277,11 @@ def get_valid_project_name(self) -> str:

return name

def create_new_project(self, project_name: str):
def create_new_project(self, project_name: str, profile_name: str):
self.copy_starter_repo(project_name)
os.chdir(project_name)
with open("dbt_project.yml", "r") as f:
content = f"{f.read()}".format(project_name=project_name, profile_name=project_name)
content = f"{f.read()}".format(project_name=project_name, profile_name=profile_name)
with open("dbt_project.yml", "w") as f:
f.write(content)
fire_event(
Expand All @@ -274,9 +304,18 @@ def run(self):
in_project = False

if in_project:
# If --profile was specified, it means use an existing profile, which is not
# applicable to this case
if self.args.profile:
raise DbtRuntimeError(
msg="Can not init existing project with specified profile, edit dbt_project.yml instead"
)

# When dbt init is run inside an existing project,
# just setup the user's profile.
profile_name = self.get_profile_name_from_current_project()
if not self.args.skip_profile_setup:
profile_name = self.get_profile_name_from_current_project()
self.setup_profile(profile_name)
else:
# When dbt init is run outside of an existing project,
# create a new project and set up the user's profile.
Expand All @@ -285,24 +324,21 @@ def run(self):
if project_path.exists():
fire_event(ProjectNameAlreadyExists(name=project_name))
return
self.create_new_project(project_name)
profile_name = project_name

# Ask for adapter only if skip_profile_setup flag is not provided.
if not self.args.skip_profile_setup:
fire_event(SettingUpProfile())
if not self.check_if_can_write_profile(profile_name=profile_name):
return
# If a profile_template.yml exists in the project root, that effectively
# overrides the profile_template.yml for the given target.
profile_template_path = Path("profile_template.yml")
if profile_template_path.exists():
try:
# This relies on a valid profile_template.yml from the user,
# so use a try: except to fall back to the default on failure
self.create_profile_using_project_profile_template(profile_name)
return
except Exception:
fire_event(InvalidProfileTemplateYAML())
adapter = self.ask_for_adapter_choice()
self.create_profile_from_target(adapter, profile_name=profile_name)
# If the user specified an existing profile to use, use it instead of generating a new one
user_profile_name = self.args.profile
if user_profile_name:
if not self.check_if_profile_exists(user_profile_name):
raise DbtRuntimeError(
msg="Could not find profile named '{}'".format(user_profile_name)
)
self.create_new_project(project_name, user_profile_name)
else:
profile_name = project_name
# Create the profile after creating the project to avoid leaving a random profile
# if the former fails.
self.create_new_project(project_name, profile_name)

# Ask for adapter only if skip_profile_setup flag is not provided
if not self.args.skip_profile_setup:
self.setup_profile(profile_name)
147 changes: 147 additions & 0 deletions tests/functional/init/test_init.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import click
import os
import yaml
import pytest
from pathlib import Path
from unittest import mock
from unittest.mock import Mock, call

from dbt.exceptions import DbtRuntimeError

from dbt.tests.util import run_dbt


Expand Down Expand Up @@ -84,6 +87,11 @@ def test_init_task_in_project_with_existing_profiles_yml(
"""
)

def test_init_task_in_project_specifying_profile_errors(self):
with pytest.raises(DbtRuntimeError) as error:
run_dbt(["init", "--profile", "test"], expect_pass=False)
assert "Can not init existing project with specified profile" in str(error)


class TestInitProjectWithoutExistingProfilesYml:
@mock.patch("dbt.task.init._get_adapter_plugin_names")
Expand Down Expand Up @@ -159,6 +167,20 @@ def exists_side_effect(path):
"""
)

@mock.patch.object(Path, "exists", autospec=True)
def test_init_task_in_project_without_profile_yml_specifying_profile_errors(self, exists):
def exists_side_effect(path):
# Override responses on specific files, default to 'real world' if not overriden
return {"profiles.yml": False}.get(path.name, os.path.exists(path))

exists.side_effect = exists_side_effect

# Even through no profiles.yml file exists, the init will not modify project.yml,
# so this errors
with pytest.raises(DbtRuntimeError) as error:
run_dbt(["init", "--profile", "test"], expect_pass=False)
assert "Could not find profile named test" in str(error)


class TestInitProjectWithoutExistingProfilesYmlOrTemplate:
@mock.patch("dbt.task.init._get_adapter_plugin_names")
Expand Down Expand Up @@ -708,3 +730,128 @@ def test_init_inside_project_and_skip_profile_setup(
# skip interactive profile setup
run_dbt(["init", "--skip-profile-setup"])
assert len(manager.mock_calls) == 0


class TestInitOutsideOfProjectWithSpecifiedProfile(TestInitOutsideOfProjectBase):
@mock.patch("dbt.task.init._get_adapter_plugin_names")
@mock.patch("click.prompt")
def test_init_task_outside_of_project_with_specified_profile(
self, mock_prompt, mock_get_adapter, project, project_name, unique_schema, dbt_profile_data
):
manager = Mock()
manager.attach_mock(mock_prompt, "prompt")
manager.prompt.side_effect = [
project_name,
]
mock_get_adapter.return_value = [project.adapter.type()]
run_dbt(["init", "--profile", "test"])

manager.assert_has_calls(
[
call.prompt("Enter a name for your project (letters, digits, underscore)"),
]
)

# profiles.yml is NOT overwritten, so assert that the text matches that of the
# original fixture
with open(os.path.join(project.profiles_dir, "profiles.yml"), "r") as f:
assert f.read() == yaml.safe_dump(dbt_profile_data)

with open(os.path.join(project.project_root, project_name, "dbt_project.yml"), "r") as f:
assert (
f.read()
== f"""
# Name your project! Project names should contain only lowercase characters
# and underscores. A good package name should reflect your organization's
# name or the intended use of these models
name: '{project_name}'
version: '1.0.0'
config-version: 2
# This setting configures which "profile" dbt uses for this project.
profile: 'test'
# These configurations specify where dbt should look for different types of files.
# The `model-paths` config, for example, states that models in this project can be
# found in the "models/" directory. You probably won't need to change these!
model-paths: ["models"]
analysis-paths: ["analyses"]
test-paths: ["tests"]
seed-paths: ["seeds"]
macro-paths: ["macros"]
snapshot-paths: ["snapshots"]
clean-targets: # directories to be removed by `dbt clean`
- "target"
- "dbt_packages"
# Configuring models
# Full documentation: https://docs.getdbt.com/docs/configuring-models
# In this example config, we tell dbt to build all models in the example/
# directory as views. These settings can be overridden in the individual model
# files using the `{{{{ config(...) }}}}` macro.
models:
{project_name}:
# Config indicated by + and applies to all files under models/example/
example:
+materialized: view
"""
)


class TestInitOutsideOfProjectSpecifyingInvalidProfile(TestInitOutsideOfProjectBase):
@mock.patch("dbt.task.init._get_adapter_plugin_names")
@mock.patch("click.prompt")
def test_init_task_outside_project_specifying_invalid_profile_errors(
self, mock_prompt, mock_get_adapter, project, project_name
):
manager = Mock()
manager.attach_mock(mock_prompt, "prompt")
manager.prompt.side_effect = [
project_name,
]
mock_get_adapter.return_value = [project.adapter.type()]

with pytest.raises(DbtRuntimeError) as error:
run_dbt(["init", "--profile", "invalid"], expect_pass=False)
assert "Could not find profile named invalid" in str(error)

manager.assert_has_calls(
[
call.prompt("Enter a name for your project (letters, digits, underscore)"),
]
)


class TestInitOutsideOfProjectSpecifyingProfileNoProfilesYml(TestInitOutsideOfProjectBase):
@mock.patch("dbt.task.init._get_adapter_plugin_names")
@mock.patch("click.prompt")
def test_init_task_outside_project_specifying_profile_no_profiles_yml_errors(
self, mock_prompt, mock_get_adapter, project, project_name
):
manager = Mock()
manager.attach_mock(mock_prompt, "prompt")
manager.prompt.side_effect = [
project_name,
]
mock_get_adapter.return_value = [project.adapter.type()]

# Override responses on specific files, default to 'real world' if not overriden
original_isfile = os.path.isfile
with mock.patch(
"os.path.isfile",
new=lambda path: {"profiles.yml": False}.get(
os.path.basename(path), original_isfile(path)
),
):
with pytest.raises(DbtRuntimeError) as error:
run_dbt(["init", "--profile", "test"], expect_pass=False)
assert "Could not find profile named invalid" in str(error)

manager.assert_has_calls(
[
call.prompt("Enter a name for your project (letters, digits, underscore)"),
]
)

0 comments on commit 3f5ebe8

Please sign in to comment.