diff --git a/activity_browser/ui/wizards/db_import_wizard.py b/activity_browser/ui/wizards/db_import_wizard.py index 2ac185050..78ea9510a 100644 --- a/activity_browser/ui/wizards/db_import_wizard.py +++ b/activity_browser/ui/wizards/db_import_wizard.py @@ -1,20 +1,32 @@ # -*- coding: utf-8 -*- import io +import os.path +import shutil +import typing +from functools import lru_cache import subprocess import tempfile import zipfile from pathlib import Path -import eidl +import bw2data.errors +import ecoinvent_interface import requests +from bw2data.subclass_mapping import DATABASE_BACKEND_MAPPING from bw2io import BW2Package, SingleOutputEcospold2Importer from bw2io.extractors import Ecospold2DataExtractor from PySide2 import QtCore, QtWidgets from PySide2.QtCore import Signal, Slot +from bw2io.importers import Ecospold2BiosphereImporter +from py7zr import py7zr -from activity_browser import log +from activity_browser import log, project_settings from activity_browser.bwutils import errors from activity_browser.mod import bw2data as bd +from activity_browser.mod.bw2data import databases +from activity_browser.bwutils.ecoinvent_biosphere_versions.ecospold2biosphereimporter import ( + ABEcospold2BiosphereImporter, +) from ...bwutils.importers import ABExcelImporter, ABPackage from ...info import __ei_versions__ @@ -22,9 +34,7 @@ from ..style import style_group_box from ..threading import ABThread from ..widgets import DatabaseLinkingDialog - -# TODO: Rework the entire import wizard, the amount of different classes -# and interwoven connections makes the entire thing nearly incomprehensible. +from ..widgets.biosphere_update import UpdateBiosphereThread class DatabaseImportWizard(QtWidgets.QWizard): @@ -33,13 +43,14 @@ class DatabaseImportWizard(QtWidgets.QWizard): LOCAL_TYPE = 3 EI_LOGIN = 4 EI_VERSION = 5 - ARCHIVE = 6 - DIR = 7 - LOCAL = 8 - EXCEL = 9 - DB_NAME = 10 - CONFIRM = 11 - IMPORT = 12 + DB_BIOSPHERE_CREATION = 6 + ARCHIVE = 7 + DIR = 8 + LOCAL = 9 + EXCEL = 10 + DB_NAME = 11 + CONFIRM = 12 + IMPORT = 13 def __init__(self, parent=None): super().__init__(parent) @@ -56,6 +67,7 @@ def __init__(self, parent=None): self.local_page = LocalImportPage(self) self.ecoinvent_login_page = EcoinventLoginPage(self) self.ecoinvent_version_page = EcoinventVersionPage(self) + self.biosphere_database_setup = BiosphereDatabaseSetup(self) self.archive_page = Choose7zArchivePage(self) self.choose_dir_page = ChooseDirPage(self) self.local_import_page = LocalDatabaseImportPage(self) @@ -68,6 +80,7 @@ def __init__(self, parent=None): self.setPage(self.LOCAL_TYPE, self.local_page) self.setPage(self.EI_LOGIN, self.ecoinvent_login_page) self.setPage(self.EI_VERSION, self.ecoinvent_version_page) + self.setPage(self.DB_BIOSPHERE_CREATION, self.biosphere_database_setup) self.setPage(self.ARCHIVE, self.archive_page) self.setPage(self.DIR, self.choose_dir_page) self.setPage(self.LOCAL, self.local_import_page) @@ -93,9 +106,14 @@ def version(self): def system_model(self): return self.ecoinvent_version_page.system_model_combobox.currentText() + @property + def release_type(self): + return ecoinvent_interface.ReleaseType.ecospold + def update_downloader(self): self.downloader.version = self.version self.downloader.system_model = self.system_model + self.downloader.release_type = self.release_type def done(self, result: int): """ @@ -157,6 +175,12 @@ def cleanup(self): self.import_page.complete = False self.reject() + def has_existing_remote_credentials(self) -> bool: + return ( + self.downloader.username is not None + and self.downloader.password is not None + ) + @Slot(tuple, name="showMessage") def show_info(self, info: tuple) -> None: title, message = info @@ -214,6 +238,7 @@ def __init__(self, parent=None): self.wizard = parent self.radio_buttons = [QtWidgets.QRadioButton(o[0]) for o in self.OPTIONS] self.radio_buttons[0].setChecked(True) + self.has_valid_remote_creds = False layout = QtWidgets.QVBoxLayout() box = QtWidgets.QGroupBox("Data source:") @@ -225,10 +250,21 @@ def __init__(self, parent=None): layout.addWidget(box) self.setLayout(layout) + def validatePage(self): + if ( + self.wizard.has_existing_remote_credentials() + and self.radio_buttons[0].isChecked() + ): + self.has_valid_remote_creds, _ = self.wizard.downloader.login() + return True + def nextId(self): option_id = [b.isChecked() for b in self.radio_buttons].index(True) self.wizard.import_type = self.OPTIONS[option_id][1] - return self.OPTIONS[option_id][2] + next_id = self.OPTIONS[option_id][2] + if next_id == DatabaseImportWizard.EI_LOGIN and self.has_valid_remote_creds: + return DatabaseImportWizard.EI_VERSION + return next_id class LocalImportPage(QtWidgets.QWizardPage): @@ -344,13 +380,25 @@ def __init__(self, parent=None): self.setLayout(layout) def initializePage(self): - self.stored_dbs = eidl.eidlstorage.stored_dbs + self.stored_dbs = ecoinvent_interface.CachedStorage() self.stored_combobox.clear() - self.stored_combobox.addItems(sorted(self.stored_dbs.keys())) + self.stored_combobox.addItems( + sorted( + [ + key + for key, value in self.stored_dbs.catalogue.items() + if value["extracted"] == False + and value["kind"] == "release" + and key.partition(value["system_model"])[2] == "_ecoSpold02.7z" + ] + ) + ) @Slot(int, name="updateSelectedIndex") def update_stored(self, index: int) -> None: - self.path_edit.setText(self.stored_dbs[self.stored_combobox.currentText()]) + self.path_edit.setText( + self.stored_dbs.catalogue[self.stored_combobox.currentText()]["path"] + ) @Slot(name="getArchiveFile") def get_archive(self) -> None: @@ -476,8 +524,12 @@ def initializePage(self): ) else: self.path_label.setText( - "Ecoinvent version: {}
Ecoinvent system model: {}".format( - self.wizard.version, self.wizard.system_model + "Ecoinvent version: {}
" + "Ecoinvent system model: {}
" + "Dependent Database: {}".format( + self.wizard.version, + self.wizard.system_model, + bd.preferences["biosphere_database"], ) ) @@ -738,7 +790,7 @@ def report_failed_unarchive(self, file: str) -> None: class MainWorkerThread(ABThread): - def __init__(self, downloader, parent=None): + def __init__(self, downloader: "ABEcoinventDownloader", parent=None): super().__init__(parent) self.downloader = downloader self.forwast_url = ( @@ -784,19 +836,19 @@ def run_safely(self): def run_ecoinvent(self) -> None: """Run the ecoinvent downloader from start to finish.""" - self.downloader.outdir = eidl.eidlstorage.eidl_dir - if self.downloader.check_stored(): - import_signals.download_complete.emit() - else: - self.run_download() + archive_file = self.run_download() - with tempfile.TemporaryDirectory() as tempdir: - temp_dir = Path(tempdir) - if not import_signals.cancel_sentinel: - self.run_extract(temp_dir) - if not import_signals.cancel_sentinel: - dataset_dir = temp_dir.joinpath("datasets") - self.run_import(dataset_dir) + if os.path.isdir(archive_file): + import_signals.unarchive_finished.emit() + self.run_import(archive_file.joinpath("datasets")) + else: + with tempfile.TemporaryDirectory() as tempdir: + temp_dir = Path(tempdir) + if not import_signals.cancel_sentinel: + self.run_extract(archive_file, temp_dir) + if not import_signals.cancel_sentinel: + dataset_dir = temp_dir.joinpath("datasets") + self.run_import(dataset_dir) def run_forwast(self) -> None: """Adapted from pjamesjoyce/lcopt.""" @@ -823,17 +875,23 @@ def run_forwast(self) -> None: else: self.delete_canceled_db() - def run_download(self) -> None: + def run_download(self) -> Path: """Use the connected ecoinvent downloader.""" - self.downloader.download() + filepath = self.downloader.download() import_signals.download_complete.emit() + return filepath - def run_extract(self, temp_dir: Path) -> None: + def run_extract(self, archive_file: Path, temp_dir: Path) -> None: """Use the connected ecoinvent downloader to extract the downloaded 7zip file. """ - self.downloader.extract(target_dir=temp_dir) - import_signals.unarchive_finished.emit() + try: + self.downloader.extract(archive_file, temp_dir) + except Exception: + import_signals.cancel_sentinel = True + import_signals.unarchive_failed.emit(temp_dir) + else: + import_signals.unarchive_finished.emit() def run_extract_import(self) -> None: """Combine the extract and import steps when beginning from a selected @@ -866,6 +924,7 @@ def run_import(self, import_dir: Path) -> None: signal=import_signals.strategy_progress, ) importer.apply_strategies() + # backend is a custom implementation that wraps sqlite database importer.write_database(backend="activitybrowser") if not import_signals.cancel_sentinel: import_signals.finished.emit() @@ -873,7 +932,7 @@ def run_import(self, import_dir: Path) -> None: self.delete_canceled_db() except errors.ImportCanceledError: self.delete_canceled_db() - except errors.InvalidExchange: + except bw2data.errors.InvalidExchange: # Likely caused by new version of ecoinvent not finding required # biosphere flows. self.delete_canceled_db() @@ -970,11 +1029,20 @@ def __init__(self, parent=None): super().__init__(parent) self.wizard = parent self.complete = False + eco_settings = ecoinvent_interface.Settings() self.username_edit = QtWidgets.QLineEdit() - self.username_edit.setPlaceholderText("ecoinvent username") + if eco_settings.username: + self.username_edit.setText(eco_settings.username) + else: + self.username_edit.setPlaceholderText("ecoinvent username") self.password_edit = QtWidgets.QLineEdit() - self.password_edit.setPlaceholderText("ecoinvent password"), self.password_edit.setEchoMode(QtWidgets.QLineEdit.Password) + if eco_settings.password: + self.password_edit.setText(eco_settings.password) + else: + self.password_edit.setPlaceholderText("ecoinvent password") + self.save_creds = QtWidgets.QPushButton("Save Credentials") + self.save_creds.clicked.connect(self.save_credentials) self.login_button = QtWidgets.QPushButton("login") self.login_button.clicked.connect(self.login) self.password_edit.returnPressed.connect(self.login_button.click) @@ -989,6 +1057,7 @@ def __init__(self, parent=None): box_layout.addWidget(self.password_edit) hlay = QtWidgets.QHBoxLayout() hlay.addWidget(self.login_button) + hlay.addWidget(self.save_creds) hlay.addStretch(1) box_layout.addLayout(hlay) box_layout.addWidget(self.success_label) @@ -1018,6 +1087,13 @@ def login(self) -> None: self.login_thread.update(self.username, self.password) self.login_thread.start() + @Slot(name="SaveEiCredentials") + def save_credentials(self): + self.success_label.setText("Saving Credentials") + ecoinvent_interface.permanent_setting("username", self.username) + ecoinvent_interface.permanent_setting("password", self.password) + self.success_label.setText("Saved Credentials") + @Slot(bool, name="handleLoginResponse") def login_response(self, success: bool) -> None: if not success: @@ -1039,7 +1115,7 @@ def nextId(self): class LoginThread(QtCore.QThread): - def __init__(self, downloader, parent=None): + def __init__(self, downloader: "ABEcoinventDownloader", parent=None): super().__init__(parent) self.downloader = downloader @@ -1048,18 +1124,36 @@ def update(self, username: str, password: str) -> None: self.downloader.password = password def run(self): - self.downloader.login() + error_message = None + try: + login_success, error_message = self.downloader.login() + except Exception as e: + log.error(str(e), exc_info=True) + import_signals.login_success.emit(False) + msg = str(e) + cs = ecoinvent_interface.CachedStorage() + if len(cs.catalogue) > 0: + msg += ( + "\n\nIf you work offline you can use your previously downloaded databases" + + " via the archive option of the import wizard." + ) + import_signals.connection_problem.emit(("Unexpected error", msg)) + else: + import_signals.login_success.emit(login_success) + finally: + if error_message: + import_signals.connection_problem.emit(error_message) class EcoinventVersionPage(QtWidgets.QWizardPage): def __init__(self, parent=None): super().__init__(parent) - self.wizard = self.parent() + self.wizard: "DatabaseImportWizard" = self.parent() self.description_label = QtWidgets.QLabel( "Choose ecoinvent version and system model:" ) self.db_dict = None - self.system_models = {} + self.requires_database_creation = False self.version_combobox = QtWidgets.QComboBox() self.version_combobox.currentTextChanged.connect( self.update_system_model_combobox @@ -1075,27 +1169,17 @@ def __init__(self, parent=None): self.setLayout(layout) def initializePage(self): - if self.db_dict is None: - self.wizard.downloader.db_dict = ( - self.wizard.downloader.get_available_files() - ) - self.db_dict = self.wizard.downloader.db_dict - self.system_models = { - version: sorted( - {k[1] for k in self.db_dict.keys() if k[0] == version}, reverse=True - ) - for version in sorted( - {k[0] for k in self.db_dict.keys() if k[0] in __ei_versions__}, - reverse=True, - ) - } + available_versions = self.wizard.downloader.list_versions() + shown_versions = set( + [version for version in available_versions if version in __ei_versions__] + ) # Catch for incorrect 'universal' key presence # (introduced in version 3.6 of ecoinvent) - if "universal" in self.system_models: - del self.system_models["universal"] + if "universal" in shown_versions: + shown_versions.remove("universal") self.version_combobox.clear() self.system_model_combobox.clear() - versions = sort_semantic_versions(self.system_models.keys()) + versions = sort_semantic_versions(shown_versions) self.version_combobox.addItems(versions) if bool(self.version_combobox.count()): # Adding the items will cause system_model_combobox to update @@ -1111,7 +1195,17 @@ def initializePage(self): ) self.wizard.back() + def validatePage(self): + version = self.version_combobox.currentText() + bd.preferences["biosphere_database"] = "ecoinvent-{}-biosphere".format(version) + bd.preferences.flush() + if bd.preferences["biosphere_database"] not in databases: + self.requires_database_creation = True + return True + def nextId(self): + if self.requires_database_creation: + return DatabaseImportWizard.DB_BIOSPHERE_CREATION return DatabaseImportWizard.DB_NAME @Slot(str) @@ -1120,7 +1214,103 @@ def update_system_model_combobox(self, version: str) -> None: different ecoinvent version. """ self.system_model_combobox.clear() - self.system_model_combobox.addItems(self.system_models[version]) + items = self.wizard.downloader.list_system_models(version) + items = sorted(items, reverse=True) + self.system_model_combobox.addItems(items) + + +class VersionedBiosphereThread(UpdateBiosphereThread): + update = Signal(int, str) + + def __init__(self, version, parent=None): + # reduce biosphere update list up to the selected version + sorted_versions = sort_semantic_versions( + __ei_versions__, highest_to_lowest=False + ) + ei_versions = sorted_versions[: sorted_versions.index(version) + 1] + super().__init__(ei_versions, parent=parent) + self.version = version + + def run_safely(self): + project = f"{bd.projects.current}" + if bd.preferences["biosphere_database"] not in bd.databases: + self.update.emit( + 0, + "Creating {} database for {}".format( + bd.preferences["biosphere_database"], project + ), + ) + self.create_biosphere3_database() + project_settings.add_db(bd.preferences["biosphere_database"]) + + self.update.emit( + 1, + "Updating biosphere database", + ) + super().run_safely() + + def create_biosphere3_database(self): + if self.version == sort_semantic_versions(__ei_versions__)[0][:3]: + eb = Ecospold2BiosphereImporter( + name=bd.preferences["biosphere_database"], version=self.version + ) + else: + eb = ABEcospold2BiosphereImporter( + name=bd.preferences["biosphere_database"], version=self.version + ) + eb.apply_strategies() + eb.write_database() + + +class BiosphereDatabaseSetup(QtWidgets.QWizardPage): + def __init__(self, parent=None): + super().__init__(parent=parent) + self.wizard: "DatabaseImportWizard" = self.parent() + self.update_label = QtWidgets.QLabel() + self.progressbar = QtWidgets.QProgressBar() + self.progressbar.setRange(0, 2) + self.complete = False + + box = QtWidgets.QGroupBox("Creating biosphere database") + box_layout = QtWidgets.QVBoxLayout() + box_layout.addWidget(self.progressbar) + box_layout.addWidget(self.update_label) + box.setLayout(box_layout) + box.setStyleSheet(style_group_box.border_title) + layout = QtWidgets.QVBoxLayout() + layout.addWidget(box) + self.setLayout(layout) + + def isComplete(self): + return self.complete + + def initializePage(self): + self.biosphere_thread = VersionedBiosphereThread(self.wizard.version, self) + self.biosphere_thread.update.connect(self.update_progress) + self.biosphere_thread.finished.connect(self.thread_finished) + self.biosphere_thread.start() + + def validatePage(self): + return self.biosphere_thread.isFinished() + + @Slot(int, str, name="updateThread") + def update_progress(self, current: int, text: str) -> None: + self.progressbar.setValue(current) + self.update_label.setText(text) + + @Slot(int, name="threadFinished") + def thread_finished(self, result: int = None) -> None: + self.progressbar.setMaximum(1) + self.progressbar.setValue(1) + if result and result != 0: + self.update_label.setText("Something went wrong...") + else: + self.update_label.setText("All Done") + self.complete = True + self.completeChanged.emit() + + def nextId(self): + return DatabaseImportWizard.DB_NAME class LocalDatabaseImportPage(QtWidgets.QWizardPage): @@ -1276,22 +1466,28 @@ def extract(cls, dirpath: str, db_name: str, *args, **kwargs): class ActivityBrowserBackend(bd.backends.SQLiteBackend): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + self._ab_current_index = 0 + self._ab_total = 0 def _efficient_write_many_data(self, *args, **kwargs): data = args[0] - self.total = len(data) + self._ab_total = len(data) super()._efficient_write_many_data(*args, **kwargs) def _efficient_write_dataset(self, *args, **kwargs): - index = args[0] if import_signals.cancel_sentinel: - log.info(f"\nWriting canceled at position {index}!") + log.info(f"\nWriting canceled at position {self._ab_current_index}!") raise errors.ImportCanceledError - import_signals.db_progress.emit(index + 1, self.total) + self._ab_current_index += 1 + import_signals.db_progress.emit(self._ab_current_index, self._ab_total) return super()._efficient_write_dataset(*args, **kwargs) bd.config.backends["activitybrowser"] = ActivityBrowserBackend +# config is no longer enough to provide an additional backend +# database chooser, specifically looks at DATABASE_BACKEND_MAPPING +# to get the class implementation +DATABASE_BACKEND_MAPPING.update({"activitybrowser": ActivityBrowserBackend}) class ImportSignals(QtCore.QObject): @@ -1316,27 +1512,126 @@ class ImportSignals(QtCore.QObject): import_signals = ImportSignals() -class ABEcoinventDownloader(eidl.EcoinventDownloader): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.extraction_process = None +class ABEcoinventDownloader: + def __init__( + self, + version: typing.Optional[str] = None, + system_model: typing.Optional[str] = None, + release_type: typing.Optional[ecoinvent_interface.ReleaseType] = None, + ): + self.version = version + self.system_model = system_model + self._release_type = release_type + self._settings = ecoinvent_interface.Settings() + self.update_ecoinvent_release() + + def update_ecoinvent_release(self): + try: + self._release = ecoinvent_interface.EcoinventRelease(self._settings) + except ValueError: + self._release = None - def login_success(self, success): - import_signals.login_success.emit(success) + @property + def release(self) -> ecoinvent_interface.EcoinventRelease: + if self._release is None: + raise ValueError("ecoinvent release has not been initialized properly") + return self._release - def extract(self, target_dir): - """Override extract method to redirect the stdout to dev null.""" - code = super().extract(target_dir=target_dir, stdout=subprocess.DEVNULL) - if code != 0: - # The archive was corrupted in some way. - import_signals.cancel_sentinel = True - import_signals.unarchive_failed.emit(self.out_path) - - def handle_connection_timeout(self): - msg = "The request timed out, please check your internet connection!" - if eidl.eidlstorage.stored_dbs: - msg += ( - "\n\nIf you work offline you can use your previously downloaded databases" - + " via the archive option of the import wizard." + @property + def username(self) -> typing.Optional[str]: + return self._settings.username + + @username.setter + def username(self, value: str): + self._settings.username = value + self.update_ecoinvent_release() + + @property + def password(self) -> typing.Optional[str]: + return self._settings.password + + @password.setter + def password(self, value: str): + self._settings.password = value + self.update_ecoinvent_release() + + @property + def release_type(self): + return self._release_type + + @release_type.setter + def release_type(self, value: typing.Union[str, ecoinvent_interface.ReleaseType]): + if isinstance(value, ecoinvent_interface.ReleaseType): + self._release_type = value + return + + if isinstance(value, str): + self._release_type = ecoinvent_interface.ReleaseType[value] + return + + raise ValueError("invalid value provided for release_type") + + def login(self) -> (bool, typing.Optional[typing.Tuple[str, str]]): + release = ecoinvent_interface.EcoinventRelease(self._settings) + error_message = None + try: + release.login() + login_success = True + except ( + requests.ConnectTimeout, + requests.ReadTimeout, + requests.ConnectionError, + ) as e: + login_success = False + error_message = ( + "Connection Problem", + "The request timed out, please check your internet connection!", ) - import_signals.connection_problem.emit(("Connection problem", msg)) + except requests.exceptions.HTTPError as e: + login_success = False + error_message = None + if e.response.status_code != 401: + log.error( + "Unexpected status code (%d) received when trying to list ecoinvent_versions, response: %s", + e.response.status_code, + e.response.text, + ) + error_message = ( + "Unexpected Problem", + "An unexpected error occurred, please try again status code %d" + % e.response.status_code, + ) + + return login_success, error_message + + @lru_cache(maxsize=1) + def list_versions(self): + return self._release.list_versions() + + @lru_cache(maxsize=100) + def list_system_models(self, version: str): + if version == "": + return [] + return self._release.list_system_models(version) + + def download(self) -> Path: + return self.release.get_release( + version=self.version, + system_model=self.system_model, + release_type=self.release_type, + extract=True, + ) + + @staticmethod + def extract(filepath: Path, out_dir: Path = None): + """ + Extract archive + """ + if filepath.suffix.lower() == ".7z": + with py7zr.SevenZipFile(filepath, "r") as archive: + directory = out_dir or (filepath.parent / filepath.stem) + if directory.exists(): + shutil.rmtree(directory) + archive.extractall(path=directory) + else: + raise ValueError("Unsupported archive format")