From ed3f53311311742b18e86464effe8d72051a267b Mon Sep 17 00:00:00 2001 From: merav-aharoni Date: Tue, 19 Sep 2023 18:51:28 +0300 Subject: [PATCH] Allow user to define a default account as an environment variable (#1018) * Allow user to define a default account as an environment variable * Fixed test * Fixed mistaken paste * Cleaned up test * Moved test to TestAccountManager * Added ability to define default channel in save_account * Cleaned up code, fixed bugs * Changed name of parameter * Added test. Cleaned up code surrounding preferences of channel selection * black and lint * Fixed bug when json file was empty * Code cleanup and documentation * Documentation * Removed channel from condition, because unnecessary * changed default_channel to default_account * Changed saving and getting default channel to default account * black * Documentation * Release notes * Reverted diff that was unnecessary --------- Co-authored-by: Kevin Tian --- qiskit_ibm_runtime/accounts/management.py | 67 ++++++- qiskit_ibm_runtime/accounts/storage.py | 22 ++- qiskit_ibm_runtime/qiskit_runtime_service.py | 6 + .../default_account-13d86d50f5b1d972.yaml | 6 + test/account.py | 3 + test/unit/test_account.py | 185 +++++++++++++++++- 6 files changed, 273 insertions(+), 16 deletions(-) create mode 100644 releasenotes/notes/default_account-13d86d50f5b1d972.yaml diff --git a/qiskit_ibm_runtime/accounts/management.py b/qiskit_ibm_runtime/accounts/management.py index 0e35df566..fd65c6e9b 100644 --- a/qiskit_ibm_runtime/accounts/management.py +++ b/qiskit_ibm_runtime/accounts/management.py @@ -48,6 +48,7 @@ def save( verify: Optional[bool] = None, overwrite: Optional[bool] = False, channel_strategy: Optional[str] = None, + set_as_default: Optional[bool] = None, ) -> None: """Save account on disk.""" channel = channel or os.getenv("QISKIT_IBM_CHANNEL") or _DEFAULT_CHANNEL_TYPE @@ -69,6 +70,7 @@ def save( ) # avoid storing invalid accounts .validate().to_saved_format(), + set_as_default=set_as_default, ) @staticmethod @@ -137,6 +139,21 @@ def get( filename: Full path of the file from which to get the account. name: Account name. channel: Channel type. + Order of precedence for selecting the account: + 1. If name is specified, get account with that name + 2. If the environment variables define an account, get that one + 3. If the channel parameter is defined, + a. get the account of this channel type defined as "is_default_account" + b. get the account of this channel type with default name + c. get any account of this channel type + 4. If the channel is defined in "QISKIT_IBM_CHANNEL" + a. get the account of this channel type defined as "is_default_account" + b. get the account of this channel type with default name + c. get any account of this channel type + 5. If a default account is defined in the json file, get that account + 6. Get any account that is defined in the json file with + preference for _DEFAULT_CHANNEL_TYPE. + Returns: Account information. @@ -157,18 +174,20 @@ def get( if env_account is not None: return env_account - if channel: - saved_account = read_config( - filename=filename, - name=cls._get_default_account_name(channel=channel), - ) - if saved_account is None: - if os.path.isfile(_QISKITRC_CONFIG_FILE): - return cls._from_qiskitrc_file() - raise AccountNotFoundError(f"No default {channel} account saved.") + all_config = read_config(filename=filename) + # Get the default account for the given channel. + # If channel == None, get the default account, for any channel, if it exists + saved_account = cls._get_default_account(all_config, channel) + + if saved_account is not None: return Account.from_saved_format(saved_account) - all_config = read_config(filename=filename) + # Get the default account from the channel defined in the environment variable + account = cls._get_default_account(all_config, channel=channel_) + if account is not None: + return Account.from_saved_format(account) + + # check for any account for channel_type in _CHANNEL_TYPES: account_name = cls._get_default_account_name(channel=channel_type) if account_name in all_config: @@ -209,6 +228,34 @@ def _from_env_variables(cls, channel: Optional[ChannelType]) -> Optional[Account channel=channel, ) + @classmethod + def _get_default_account( + cls, all_config: dict, channel: Optional[str] = None + ) -> Optional[dict]: + default_channel_account = None + any_channel_account = None + + for account_name in all_config: + account = all_config[account_name] + if channel: + if account.get("channel") == channel and account.get("is_default_account"): + return account + if account.get( + "channel" + ) == channel and account_name == cls._get_default_account_name(channel): + default_channel_account = account + if account.get("channel") == channel: + any_channel_account = account + else: + if account.get("is_default_account"): + return account + + if default_channel_account: + return default_channel_account + elif any_channel_account: + return any_channel_account + return None + @classmethod def _get_default_account_name(cls, channel: ChannelType) -> str: return ( diff --git a/qiskit_ibm_runtime/accounts/storage.py b/qiskit_ibm_runtime/accounts/storage.py index db463de27..256432997 100644 --- a/qiskit_ibm_runtime/accounts/storage.py +++ b/qiskit_ibm_runtime/accounts/storage.py @@ -22,7 +22,9 @@ logger = logging.getLogger(__name__) -def save_config(filename: str, name: str, config: dict, overwrite: bool) -> None: +def save_config( + filename: str, name: str, config: dict, overwrite: bool, set_as_default: Optional[bool] = None +) -> None: """Save configuration data in a JSON file under the given name.""" logger.debug("Save configuration data for '%s' in '%s'", name, filename) _ensure_file_exists(filename) @@ -35,8 +37,24 @@ def save_config(filename: str, name: str, config: dict, overwrite: bool) -> None f"Named account ({name}) already exists. " f"Set overwrite=True to overwrite." ) + data[name] = config + + # if set_as_default, but another account is defined as default, user must specify overwrite to change + # the default account. + if set_as_default: + data[name]["is_default_account"] = True + for account_name in data: + account = data[account_name] + if account_name != name and account.get("is_default_account"): + if overwrite: + del account["is_default_account"] + else: + raise AccountAlreadyExistsError( + f"default_account ({name}) already exists. " + f"Set overwrite=True to overwrite." + ) + with open(filename, mode="w", encoding="utf-8") as json_out: - data[name] = config json.dump(data, json_out, sort_keys=True, indent=4) diff --git a/qiskit_ibm_runtime/qiskit_runtime_service.py b/qiskit_ibm_runtime/qiskit_runtime_service.py index 76e39e562..ba22c8c9f 100644 --- a/qiskit_ibm_runtime/qiskit_runtime_service.py +++ b/qiskit_ibm_runtime/qiskit_runtime_service.py @@ -139,6 +139,7 @@ def __init__( - Account with the input `name`, if specified. - Default account for the `channel` type, if `channel` is specified but `token` is not. - Account defined by the input `channel` and `token`, if specified. + - Account defined by the `default_channel` if defined in filename - Account defined by the environment variables, if defined. - Default account for the ``ibm_cloud`` account, if one is available. - Default account for the ``ibm_quantum`` account, if one is available. @@ -287,6 +288,7 @@ def _discover_account( "'channel' is required if 'token', or 'url' is specified but 'name' is not." ) + # channel is not defined yet, get it from the AccountManager if account is None: account = AccountManager.get(filename=filename) @@ -689,6 +691,7 @@ def save_account( verify: Optional[bool] = None, overwrite: Optional[bool] = False, channel_strategy: Optional[str] = None, + set_as_default: Optional[bool] = None, ) -> None: """Save the account to disk for future use. @@ -709,6 +712,8 @@ def save_account( verify: Verify the server's TLS certificate. overwrite: ``True`` if the existing account is to be overwritten. channel_strategy: Error mitigation strategy. + set_as_default: If ``True``, the account is saved in filename, + as the default account. """ AccountManager.save( @@ -722,6 +727,7 @@ def save_account( verify=verify, overwrite=overwrite, channel_strategy=channel_strategy, + set_as_default=set_as_default, ) @staticmethod diff --git a/releasenotes/notes/default_account-13d86d50f5b1d972.yaml b/releasenotes/notes/default_account-13d86d50f5b1d972.yaml new file mode 100644 index 000000000..47d3bbe7e --- /dev/null +++ b/releasenotes/notes/default_account-13d86d50f5b1d972.yaml @@ -0,0 +1,6 @@ +--- +features: + - | + Added the option to define a default account in the account json file. + The select an account as default, define ``set_as_default=True`` in + ``QiskitRuntimeService.save_account()``. diff --git a/test/account.py b/test/account.py index 1987eb591..65da49c73 100644 --- a/test/account.py +++ b/test/account.py @@ -152,6 +152,7 @@ def get_account_config_contents( instance=None, verify=None, proxies=None, + set_default=None, ): """Generate qiskitrc content""" if instance is None: @@ -177,4 +178,6 @@ def get_account_config_contents( out[name]["verify"] = verify if proxies is not None: out[name]["proxies"] = proxies + if set_default: + out[name]["is_default_account"] = True return out diff --git a/test/unit/test_account.py b/test/unit/test_account.py index 58ec1432a..d389123f9 100644 --- a/test/unit/test_account.py +++ b/test/unit/test_account.py @@ -361,7 +361,7 @@ def test_delete(self): def test_delete_filename(self): """Test delete accounts with filename parameter.""" - filename = "~/account_to_delete.json" + filename = _TEST_FILENAME name = "key1" channel = "ibm_quantum" AccountManager.save(channel=channel, filename=filename, name=name, token="temp_token") @@ -387,6 +387,180 @@ def test_account_with_filename(self): ) self.assertEqual(account.token, dummy_token) + @temporary_account_config_file() + def test_default_env_channel(self): + """Test that if QISKIT_IBM_CHANNEL is set in the environment, this channel will be used""" + token = uuid.uuid4().hex + # unset default_channel in the environment + with temporary_account_config_file(token=token), no_envs("QISKIT_IBM_CHANNEL"): + service = FakeRuntimeService() + self.assertEqual(service.channel, "ibm_cloud") + + # set channel to default channel in the environment + subtests = ["ibm_quantum", "ibm_cloud"] + for channel in subtests: + channel_env = {"QISKIT_IBM_CHANNEL": channel} + with temporary_account_config_file(channel=channel, token=token), custom_envs( + channel_env + ): + service = FakeRuntimeService() + self.assertEqual(service.channel, channel) + + def test_save_default_account(self): + """Test that if a default_account is defined in the qiskit-ibm.json file, + this account will be used""" + AccountManager.save( + filename=_TEST_FILENAME, + name=_DEFAULT_ACCOUNT_NAME_IBM_CLOUD, + token=_TEST_IBM_CLOUD_ACCOUNT.token, + url=_TEST_IBM_CLOUD_ACCOUNT.url, + instance=_TEST_IBM_CLOUD_ACCOUNT.instance, + channel="ibm_cloud", + overwrite=True, + set_as_default=True, + ) + AccountManager.save( + filename=_TEST_FILENAME, + name=_DEFAULT_ACCOUNT_NAME_IBM_QUANTUM, + token=_TEST_IBM_QUANTUM_ACCOUNT.token, + url=_TEST_IBM_QUANTUM_ACCOUNT.url, + instance=_TEST_IBM_QUANTUM_ACCOUNT.instance, + channel="ibm_quantum", + overwrite=True, + ) + + with no_envs("QISKIT_IBM_CHANNEL"), no_envs("QISKIT_IBM_TOKEN"): + account = AccountManager.get(filename=_TEST_FILENAME) + self.assertEqual(account.channel, "ibm_cloud") + self.assertEqual(account.token, _TEST_IBM_CLOUD_ACCOUNT.token) + + AccountManager.save( + filename=_TEST_FILENAME, + name=_DEFAULT_ACCOUNT_NAME_IBM_QUANTUM, + token=_TEST_IBM_QUANTUM_ACCOUNT.token, + url=_TEST_IBM_QUANTUM_ACCOUNT.url, + instance=_TEST_IBM_QUANTUM_ACCOUNT.instance, + channel="ibm_quantum", + overwrite=True, + set_as_default=True, + ) + with no_envs("QISKIT_IBM_CHANNEL"), no_envs("QISKIT_IBM_TOKEN"): + account = AccountManager.get(filename=_TEST_FILENAME) + self.assertEqual(account.channel, "ibm_quantum") + self.assertEqual(account.token, _TEST_IBM_QUANTUM_ACCOUNT.token) + + @temporary_account_config_file() + def test_set_channel_precedence(self): + """Test the precedence of the various methods to set the account: + account name > env_variables > channel parameter default account + > default account > default account from default channel""" + cloud_token = uuid.uuid4().hex + default_token = uuid.uuid4().hex + preferred_token = uuid.uuid4().hex + any_token = uuid.uuid4().hex + channel_env = {"QISKIT_IBM_CHANNEL": "ibm_cloud"} + contents = { + _DEFAULT_ACCOUNT_NAME_IBM_CLOUD: { + "channel": "ibm_cloud", + "token": cloud_token, + "instance": "some_instance", + }, + _DEFAULT_ACCOUNT_NAME_IBM_QUANTUM: { + "channel": "ibm_quantum", + "token": default_token, + }, + "preferred-ibm-quantum": { + "channel": "ibm_quantum", + "token": preferred_token, + "is_default_account": True, + }, + "any-quantum": { + "channel": "ibm_quantum", + "token": any_token, + }, + } + + # 'name' parameter + with temporary_account_config_file(contents=contents), custom_envs(channel_env), no_envs( + "QISKIT_IBM_TOKEN" + ): + service = FakeRuntimeService(name="any-quantum") + self.assertEqual(service.channel, "ibm_quantum") + self.assertEqual(service._account.token, any_token) + + # No name or channel params, no env vars, get the account specified as "is_default_account" + with temporary_account_config_file(contents=contents), no_envs( + "QISKIT_IBM_CHANNEL" + ), no_envs("QISKIT_IBM_TOKEN"): + service = FakeRuntimeService() + self.assertEqual(service.channel, "ibm_quantum") + self.assertEqual(service._account.token, preferred_token) + + # parameter 'channel' is specified, it overrides channel in env + # account specified as "is_default_account" + with temporary_account_config_file(contents=contents), custom_envs(channel_env), no_envs( + "QISKIT_IBM_TOKEN" + ): + service = FakeRuntimeService(channel="ibm_quantum") + self.assertEqual(service.channel, "ibm_quantum") + self.assertEqual(service._account.token, preferred_token) + + # account with default name for the channel + contents["preferred-ibm-quantum"]["is_default_account"] = False + with temporary_account_config_file(contents=contents), custom_envs(channel_env), no_envs( + "QISKIT_IBM_TOKEN" + ): + service = FakeRuntimeService(channel="ibm_quantum") + self.assertEqual(service.channel, "ibm_quantum") + self.assertEqual(service._account.token, default_token) + + # any account for this channel + del contents["default-ibm-quantum"] + # channel_env = {"QISKIT_IBM_CHANNEL": "ibm_quantum"} + with temporary_account_config_file(contents=contents), custom_envs(channel_env), no_envs( + "QISKIT_IBM_TOKEN" + ): + service = FakeRuntimeService(channel="ibm_quantum") + self.assertEqual(service.channel, "ibm_quantum") + self.assertEqual(service._account.token, any_token) + + # no channel param, get account that is specified as "is_default_account" + # for channel from env + contents["preferred-ibm-quantum"]["is_default_account"] = True + with temporary_account_config_file(contents=contents), custom_envs(channel_env), no_envs( + "QISKIT_IBM_TOKEN" + ): + service = FakeRuntimeService() + self.assertEqual(service.channel, "ibm_quantum") + self.assertEqual(service._account.token, preferred_token) + + # no channel param, account with default name for the channel from env + del contents["preferred-ibm-quantum"]["is_default_account"] + contents["default-ibm-quantum"] = { + "channel": "ibm_quantum", + "token": default_token, + } + channel_env = {"QISKIT_IBM_CHANNEL": "ibm_quantum"} + with temporary_account_config_file(contents=contents), custom_envs(channel_env), no_envs( + "QISKIT_IBM_TOKEN" + ): + service = FakeRuntimeService() + self.assertEqual(service.channel, "ibm_quantum") + self.assertEqual(service._account.token, default_token) + + # no channel param, any account for the channel from env + del contents["default-ibm-quantum"] + with temporary_account_config_file(contents=contents), custom_envs(channel_env), no_envs( + "QISKIT_IBM_TOKEN" + ): + service = FakeRuntimeService() + self.assertEqual(service.channel, "ibm_quantum") + self.assertEqual(service._account.token, any_token) + # default channel + with temporary_account_config_file(contents=contents), no_envs("QISKIT_IBM_CHANNEL"): + service = FakeRuntimeService() + self.assertEqual(service.channel, "ibm_cloud") + def tearDown(self) -> None: """Test level tear down.""" super().tearDown() @@ -516,7 +690,10 @@ def test_enable_account_both_channel(self): token = uuid.uuid4().hex contents = get_account_config_contents(channel="ibm_cloud", token=token) contents.update(get_account_config_contents(channel="ibm_quantum", token=uuid.uuid4().hex)) - with temporary_account_config_file(contents=contents), no_envs(["QISKIT_IBM_TOKEN"]): + + with temporary_account_config_file(contents=contents), no_envs( + ["QISKIT_IBM_TOKEN", "QISKIT_IBM_CHANNEL"] + ): service = FakeRuntimeService() self.assertTrue(service._account) self.assertEqual(service._account.token, token) @@ -535,7 +712,7 @@ def test_enable_account_by_env_channel(self): "QISKIT_IBM_URL": url, "QISKIT_IBM_INSTANCE": "h/g/p" if channel == "ibm_quantum" else "crn:12", } - with custom_envs(envs): + with custom_envs(envs), no_envs("QISKIT_IBM_CHANNEL"): service = FakeRuntimeService(channel=channel) self.assertTrue(service._account) @@ -652,7 +829,7 @@ def test_enable_account_by_env_pref(self): "QISKIT_IBM_URL": url, "QISKIT_IBM_INSTANCE": "my_crn", } - with custom_envs(envs): + with custom_envs(envs), no_envs("QISKIT_IBM_CHANNEL"): service = FakeRuntimeService(**extra) self.assertTrue(service._account)