diff --git a/.gitignore b/.gitignore index e4e5f6c..f5e5e8f 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,3 @@ -*~ \ No newline at end of file +*~ +config.ini +data \ No newline at end of file diff --git a/Makefile b/Makefile index 0cdad25..fe42686 100644 --- a/Makefile +++ b/Makefile @@ -8,7 +8,13 @@ endif SOURCES=__init__.py \ metadata.txt \ - qdeeplandia.py + qdeeplandia.py \ + config.ini \ + inferenceTask.py \ + feedback.py \ + gui \ + img \ + processing_provider ZIP_FILE=$(PLUGIN_NAME)-$(VERSION).zip diff --git a/config.ini.sample b/config.ini.sample new file mode 100644 index 0000000..b56755e --- /dev/null +++ b/config.ini.sample @@ -0,0 +1,12 @@ +[status] +status = dev + +[running] +processes = 1 + +[symlink] +aerial = /path/to/aerial/dataset/ +tanzania = /path/to/tanzania/dataset/ + +[folder] +project_folder = /path/to/static/files/ diff --git a/feedback.py b/feedback.py new file mode 100644 index 0000000..4eee0d0 --- /dev/null +++ b/feedback.py @@ -0,0 +1,20 @@ +from qgis.core import Qgis, QgsProcessingFeedback, QgsMessageLog + +class Feedback(QgsProcessingFeedback): + """To provide feedback to the message bar from the express tools""" + + def __init__(self, iface): + super().__init__() + self.iface = iface + self.fatal_errors = [] + + def reportError(self, error, fatalError=False): + QgsMessageLog.logMessage(str(error), "QDeeplandia") + if fatalError: + self.fatal_errors.append(error) + + def pushToUser(self, exception): + QgsMessageLog.logMessage(str(exception), "QDeeplandia") + self.iface.messageBar().pushMessage( + "Error", ", ".join(self.fatal_errors), level=Qgis.Critical, duration=0 + ) \ No newline at end of file diff --git a/gui/NbLabelDialog.py b/gui/NbLabelDialog.py new file mode 100644 index 0000000..5633192 --- /dev/null +++ b/gui/NbLabelDialog.py @@ -0,0 +1,23 @@ +from qgis.PyQt.QtWidgets import QDialog, QWidget, QHBoxLayout, QVBoxLayout, QLabel, QSpinBox, QDialogButtonBox + +class NbLabelDialog(QDialog): + def __init__(self,parent): + super(NbLabelDialog, self).__init__() + + self.VL = QVBoxLayout(self) + self.HL = QHBoxLayout() + self.VL.addLayout(self.HL) + + self.label = QLabel(self.tr('Number of label : ')) + self.HL.addWidget(self.label) + + self.spinbox = QSpinBox() + self.HL.addWidget(self.spinbox) + + self.buttonBox = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel) + self.buttonBox.accepted.connect(self.accept) + self.buttonBox.rejected.connect(self.reject) + self.VL.addWidget(self.buttonBox) + + def param(self): + return self.spinbox.value() \ No newline at end of file diff --git a/gui/__init__.py b/gui/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/img/load.svg b/img/load.svg index e69de29..8e90d7b 100644 --- a/img/load.svg +++ b/img/load.svg @@ -0,0 +1,70 @@ + + + + + + + + + + image/svg+xml + + + + + + + + + + diff --git a/img/run.svg b/img/run.svg index e69de29..5993738 100644 --- a/img/run.svg +++ b/img/run.svg @@ -0,0 +1,263 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + image/svg+xml + + + + Lapo Calamandrei + + + + + Play + + + play + playback + start + begin + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/inferenceTask.py b/inferenceTask.py new file mode 100644 index 0000000..0529c95 --- /dev/null +++ b/inferenceTask.py @@ -0,0 +1,38 @@ +import os +import sys + +from .feedback import Feedback +from qgis.core import Qgis, QgsTask, QgsMessageLog, QgsProcessingContext +import processing +import random, string + +from qgis.PyQt.QtCore import pyqtSignal + +from .processing_provider.inference import InferenceQDeepLandiaProcessingAlgorithm + + +class InferenceTask(QgsTask): + """InferenceTask is a QgsTask subclass""" + + terminated = pyqtSignal(str) + + def __init__(self, description, iface, layer, nb_label, model_path, extent=None): + super().__init__(description, QgsTask.CanCancel) + self.feedback = Feedback(iface) + tmp_name = processing.getTempFilename() + '.tif' + self.param = { 'INPUT' : layer.id(), 'OUTPUT' : os.path.join(tmp_name), 'LABELS' : nb_label, 'MODEL' : model_path } + if extent : + self.param['EXTENT'] = extent + + def run(self): + out = processing.run('QDeepLandia:InferenceQDeepLandia', self.param, feedback=self.feedback) + if os.path.exists(out['OUTPUT']): + self.terminated.emit(out['OUTPUT']) + return True + + def cancel(self): + QgsMessageLog.logMessage( + 'Task "{name}" was canceled'.format( + name=self.description()), "QDeeplandia") + self.terminated.emit(None) + super().cancel() diff --git a/metadata.txt b/metadata.txt index 916aad4..f3d1562 100644 --- a/metadata.txt +++ b/metadata.txt @@ -6,3 +6,4 @@ qgisMinimumVersion=3.00 qgisMaximumVersion=3.99 author=Oslandia email=infos@oslandia.com +hasProcessingProvider=yes \ No newline at end of file diff --git a/processing_provider/__init__.py b/processing_provider/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/processing_provider/datagen.py b/processing_provider/datagen.py new file mode 100644 index 0000000..86b6a52 --- /dev/null +++ b/processing_provider/datagen.py @@ -0,0 +1,184 @@ +# -*- coding: utf-8 -*- + +""" +*************************************************************************** +* * +* This program is free software; you can redistribute it and/or modify * +* it under the terms of the GNU General Public License as published by * +* the Free Software Foundation; either version 2 of the License, or * +* (at your option) any later version. * +* * +*************************************************************************** +""" + +import os +import shutil +import subprocess + +from qgis.PyQt.QtCore import QCoreApplication +from qgis.core import (QgsProcessing, + QgsFeatureSink, + QgsProcessingException, + QgsProcessingAlgorithm, + QgsProcessingParameterFolderDestination, + QgsProcessingParameterRasterLayer, + QgsProcessingParameterFile, + QgsProcessingParameterString, + QgsProcessingParameterNumber) +from qgis import processing + + + +class DatagenQDeepLandiaProcessingAlgorithm(QgsProcessingAlgorithm): + """ + """ + + # Constants used to refer to parameters and outputs. They will be + # used when calling the algorithm from another algorithm, or when + # calling from the QGIS console. + + INPUT = 'INPUT' + DATASET = 'DATASET' + SHAPE = 'SHAPE' + OUTPUT = 'OUTPUT' + + def tr(self, string): + """ + Returns a translatable string with the self.tr() function. + """ + return QCoreApplication.translate('Processing', string) + + def createInstance(self): + return DatagenQDeepLandiaProcessingAlgorithm() + + def name(self): + """ + Returns the algorithm name, used for identifying the algorithm. This + string should be fixed for the algorithm, and must not be localised. + The name should be unique within each provider. Names should contain + lowercase alphanumeric characters only and no spaces or other + formatting characters. + """ + return 'DatagenQDeepLandia' + + def displayName(self): + """ + Returns the translated algorithm name, which should be used for any + user-visible display of the algorithm name. + """ + return self.tr('Datageneration') + + def group(self): + """ + Returns the name of the group this algorithm belongs to. This string + should be localised. + """ + return self.tr('QDeepLandia') + + def groupId(self): + """ + Returns the unique ID of the group this algorithm belongs to. This + string should be fixed for the algorithm, and must not be localised. + The group id should be unique within each provider. Group id should + contain lowercase alphanumeric characters only and no spaces or other + formatting characters. + """ + return 'QDeepLandia' + + def shortHelpString(self): + """ + Returns a localised short helper string for the algorithm. This string + should provide a basic description about what the algorithm does and the + parameters and outputs associated with it.. + """ + return self.tr("Preprocess layer into predictable tiles") + + def initAlgorithm(self, config=None): + """ + Here we define the inputs and output of the algorithm, along + with some other properties. + """ + + # We add the input vector features source. It can have any kind of + # geometry. + self.addParameter( + QgsProcessingParameterRasterLayer( + self.INPUT, + self.tr('Input layer') + ) + ) + + # We add a feature sink in which to store our processed features (this + # usually takes the form of a newly created vector layer when the + # algorithm is run in QGIS). + self.addParameter( + QgsProcessingParameterString( + self.DATASET, + self.tr('Dataset name') + ) + ) + + self.addParameter( + QgsProcessingParameterNumber( + self.SHAPE, + self.tr('Number of pixel for the side of tiles'), + type = QgsProcessingParameterNumber.Integer, + defaultValue = 512, + minValue = 16 + ) + ) + + self.addParameter( + QgsProcessingParameterFolderDestination( + self.OUTPUT, + self.tr('Output folder') + ) + ) + + def processAlgorithm(self, parameters, context, feedback): + """ + Here is where the processing itself takes place. + """ + + raster_in = self.parameterAsRasterLayer( + parameters, + self.INPUT, + context + ) + + dest_path = self.parameterAsString( + parameters, + self.OUTPUT, + context + ) + + dataset = self.parameterAsString( + parameters, + self.DATASET, + context + ) + + shape = self.parameterAsInt( + parameters, + self.SHAPE, + context + ) + + path = '' + for i in [dest_path, dataset, 'input', 'testing', 'images']: + path = os.path.join(path, i) + if not os.path.exists(path): + os.mkdir(path) + + for file in os.listdir(path): + os.remove(os.path.join(path,file)) + + shutil.copy( raster_in.source(), os.path.join( path, os.path.basename( raster_in.source()))) + + output_folder = os.path.join(dest_path, dataset, 'preprocessed', str(shape), 'testing', 'images') + shutil.rmtree(os.path.join(dest_path, dataset, 'preprocessed', str(shape))) + + cmd = ['deepo', 'datagen', '-D', dataset, '-s', str(shape), '-P', dest_path, '-T', '1'] + subprocess.run(cmd) + + return {self.OUTPUT: output_folder} diff --git a/processing_provider/inference.py b/processing_provider/inference.py new file mode 100644 index 0000000..0c37a93 --- /dev/null +++ b/processing_provider/inference.py @@ -0,0 +1,264 @@ +# -*- coding: utf-8 -*- + +""" +*************************************************************************** +* * +* This program is free software; you can redistribute it and/or modify * +* it under the terms of the GNU General Public License as published by * +* the Free Software Foundation; either version 2 of the License, or * +* (at your option) any later version. * +* * +*************************************************************************** +""" + +import os +import sys +import glob +import gdal +import shutil +import numpy as np + +from qgis.PyQt.QtCore import QCoreApplication +from qgis.core import (QgsProcessing, + QgsRasterLayer, + QgsFeatureSink, + QgsProcessingException, + QgsProcessingAlgorithm, + QgsProcessingParameterFeatureSource, + QgsProcessingParameterFeatureSink, + QgsProcessingParameterRasterLayer, + QgsProcessingParameterFile, + QgsProcessingParameterNumber, + QgsProcessingParameterFileDestination, + QgsProcessingParameterExtent) +from qgis import processing + +from deeposlandia.inference import predict +from deeposlandia.postprocess import get_trained_model, extract_images, \ + extract_coordinates_from_filenames, \ + build_full_labelled_image, get_labels, \ + assign_label_colors, draw_grid + +class InferenceQDeepLandiaProcessingAlgorithm(QgsProcessingAlgorithm): + """ + """ + + # Constants used to refer to parameters and outputs. They will be + # used when calling the algorithm from another algorithm, or when + # calling from the QGIS console. + + INPUT = 'INPUT' + EXTENT = 'EXTENT' + MODEL = 'MODEL' + LABELS = 'LABELS' + OUTPUT = 'OUTPUT' + + def __init__(self, model=None): + super().__init__() + + def tr(self, string): + """ + Returns a translatable string with the self.tr() function. + """ + return QCoreApplication.translate('Processing', string) + + def createInstance(self): + return InferenceQDeepLandiaProcessingAlgorithm() + + def name(self): + """ + Returns the algorithm name, used for identifying the algorithm. This + string should be fixed for the algorithm, and must not be localised. + The name should be unique within each provider. Names should contain + lowercase alphanumeric characters only and no spaces or other + formatting characters. + """ + return 'InferenceQDeepLandia' + + def displayName(self): + """ + Returns the translated algorithm name, which should be used for any + user-visible display of the algorithm name. + """ + return self.tr('Inference') + + def group(self): + """ + Returns the name of the group this algorithm belongs to. This string + should be localised. + """ + return self.tr('QDeepLandia') + + def groupId(self): + """ + Returns the unique ID of the group this algorithm belongs to. This + string should be fixed for the algorithm, and must not be localised. + The group id should be unique within each provider. Group id should + contain lowercase alphanumeric characters only and no spaces or other + formatting characters. + """ + return 'QDeepLandia' + + def shortHelpString(self): + """ + Returns a localised short helper string for the algorithm. This string + should provide a basic description about what the algorithm does and the + parameters and outputs associated with it.. + """ + return self.tr("Do inference according to the loaded model") + + def initAlgorithm(self, config=None): + """ + Here we define the inputs and output of the algorithm, along + with some other properties. + """ + + # We add the input vector features source. It can have any kind of + # geometry. + self.addParameter( + QgsProcessingParameterRasterLayer( + self.INPUT, + self.tr('Input layer') + ) + ) + + self.addParameter( + QgsProcessingParameterExtent( + self.EXTENT, + self.tr('Input extent'), + defaultValue= None, + optional=True + ) + ) + + self.addParameter( + QgsProcessingParameterFile( + self.MODEL, + self.tr('Input model'), + extension="h5" + ) + ) + + self.addParameter( + QgsProcessingParameterNumber( + self.LABELS, + self.tr('Number of labels used for the inference'), + type = QgsProcessingParameterNumber.Integer, + defaultValue = 4, + minValue = 2, + optional = True + ) + ) + + # We add a feature sink in which to store our processed features (this + # usually takes the form of a newly created vector layer when the + # algorithm is run in QGIS). + self.addParameter( + QgsProcessingParameterFileDestination( + self.OUTPUT, + self.tr('Output file') + ) + ) + + def processAlgorithm(self, parameters, context, feedback): + """ + Here is where the processing itself takes place. + """ + + raster_in = self.parameterAsRasterLayer( + parameters, + self.INPUT, + context + ) + + output_path = self.parameterAsString( + parameters, + self.OUTPUT, + context + ) + + model_path = self.parameterAsString( + parameters, + self.MODEL, + context + ) + nb_labels = self.parameterAsInt( + parameters, + self.LABELS, + context + ) + + datapath = os.path.abspath(os.path.join(os.path.dirname(model_path), '..', '..', '..', '..')) + dataset = os.path.basename(os.path.abspath(os.path.join(os.path.dirname(model_path), '..', '..', '..'))) + image_size = os.path.splitext(os.path.basename(model_path))[0].split('-')[-1] + try : + model = get_trained_model(datapath, dataset, int(image_size), int(nb_labels)) + except: + sys.exit() + + extent = self.parameterAsExtent( + parameters, + self.EXTENT, + context + ) + + param = { 'INPUT': raster_in.id(), 'OUTPUT': datapath, 'DATASET': dataset, 'SHAPE': image_size} + if extent.xMinimum() != 0 and extent.xMaximum() != 0: + if ((extent.xMaximum() - extent.xMinimum())/raster_in.rasterUnitsPerPixelX() >= int(image_size) and \ + (extent.yMaximum() - extent.yMinimum())/raster_in.rasterUnitsPerPixelY() >= int(image_size)): + clipped = os.path.join(os.path.dirname(output_path), 'clipped.tif') + param = { 'INPUT': raster_in.id(), 'PROJWIN': extent, 'OUTPUT': clipped} + out = processing.run('gdal:cliprasterbyextent', param, feedback=feedback) + param = { 'INPUT': out['OUTPUT'], 'OUTPUT': datapath, 'DATASET': dataset, 'SHAPE': image_size} + raster_in = QgsRasterLayer(out['OUTPUT'],'clipped', 'gdal') + + out = processing.run('QDeepLandia:DatagenQDeepLandia', param, feedback=feedback) + + raster_list = glob.glob(os.path.join(out['OUTPUT'],'*.png')) + images = extract_images(raster_list) + coordinates = extract_coordinates_from_filenames(raster_list) + labels = get_labels(datapath, dataset, image_size) + + data = build_full_labelled_image( + images, + coordinates, + model, + int(image_size), + int(raster_in.width()), + int(raster_in.height()), + 128 + ) + + colored_data = assign_label_colors(data, labels) + colored_data = draw_grid( + colored_data, int(raster_in.width()), int(raster_in.height()), int(image_size) + ) + predicted_label_folder = os.path.join( + datapath, + dataset, + "output", + "semseg", + "predicted_labels" + ) + os.makedirs(predicted_label_folder, exist_ok=True) + predicted_label_file = os.path.join( + predicted_label_folder, + os.path.basename(os.path.splitext(raster_in.source())[0]) + "_" + str(image_size) + ".tif", + ) + ds = gdal.Open(raster_in.source()) + CreateGeoTiff(predicted_label_file, colored_data, ds.GetGeoTransform(), ds.GetProjection()) + shutil.copy(predicted_label_file, output_path) + return {self.OUTPUT: output_path} + +def CreateGeoTiff(outRaster, data, geo_transform, projection): + driver = gdal.GetDriverByName('GTiff') + rows, cols, no_bands = data.shape + DataSet = driver.Create(outRaster, cols, rows, no_bands, gdal.GDT_Byte) + DataSet.SetGeoTransform(geo_transform) + DataSet.SetProjection(projection) + + data = np.moveaxis(data, -1, 0) + + for i, image in enumerate(data, 1): + DataSet.GetRasterBand(i).WriteArray(image) + DataSet = None \ No newline at end of file diff --git a/processing_provider/provider.py b/processing_provider/provider.py new file mode 100644 index 0000000..ea4986d --- /dev/null +++ b/processing_provider/provider.py @@ -0,0 +1,35 @@ +from qgis.core import QgsProcessingProvider + +from .inference import InferenceQDeepLandiaProcessingAlgorithm +from .datagen import DatagenQDeepLandiaProcessingAlgorithm + + +class QDeepLandiaProvider(QgsProcessingProvider): + + def loadAlgorithms(self, *args, **kwargs): + self.addAlgorithm(InferenceQDeepLandiaProcessingAlgorithm()) + self.addAlgorithm(DatagenQDeepLandiaProcessingAlgorithm()) + # add additional algorithms here + # self.addAlgorithm(MyOtherAlgorithm()) + + def id(self, *args, **kwargs): + """The ID of your plugin, used for identifying the provider. + + This string should be a unique, short, character only string, + eg "qgis" or "gdal". This string should not be localised. + """ + return 'QDeepLandia' + + def name(self, *args, **kwargs): + """The human friendly name of your plugin in Processing. + + This string should be as short as possible (e.g. "Lastools", not + "Lastools version 1.0.1 64-bit") and localised. + """ + return self.tr('QDeepLandia') + + def icon(self): + """Should return a QIcon which is used for your provider inside + the Processing toolbox. + """ + return QgsProcessingProvider.icon(self) \ No newline at end of file diff --git a/qdeeplandia.py b/qdeeplandia.py index 9b06ad2..dcf75d5 100644 --- a/qdeeplandia.py +++ b/qdeeplandia.py @@ -17,44 +17,173 @@ import os +from qgis.core import Qgis, QgsRasterDataProvider, QgsApplication, \ + QgsProcessingFeedback, QgsMessageLog, QgsProcessingContext + +import processing + +from qgis.PyQt.QtCore import QSettings, QCoreApplication, pyqtSignal from qgis.PyQt.QtGui import QIcon -from qgis.PyQt.QtWidgets import QAction +from qgis.PyQt.QtWidgets import QAction, QFileDialog, QWidget, \ + QHBoxLayout, QVBoxLayout, QMessageBox, \ + QToolBar, QLabel, QCheckBox + +os.environ['DEEPOSL_CONFIG'] = os.path.join(os.path.dirname(__file__), 'config.ini') +from deeposlandia.postprocess import get_trained_model + +from .processing_provider.provider import QDeepLandiaProvider +from .gui.NbLabelDialog import NbLabelDialog +from .inferenceTask import InferenceTask -class QDeeplandiaPlugin: +def tr(message): + """Get the translation for a string using Qt translation API. + """ + # noinspection PyTypeChecker,PyArgumentList,PyCallByClass + return QCoreApplication.translate('@default', message) + +class QDeeplandiaPlugin(QWidget): + """ Major class of QDeeplandia plugin """ + + isready = pyqtSignal() + def __init__(self, iface): + """Constructor + + :param iface: qgis interface + :type iface:QgisInterface + """ + super(QDeeplandiaPlugin, self).__init__() self.iface = iface + self.mapCanvas = self.iface.mapCanvas() + self.model = None + self.deepOprovider = None + self.layer = self.updateLayer() + self.nb_labels = None + self.model_path = None + + locale = QSettings().value('locale/userLocale') or 'en_USA' + locale = locale[0:2] + locale_path = os.path.join( + os.path.dirname(__file__), + 'i18n', + 'thyrsis_{}.qm'.format(locale)) + + if os.path.exists(locale_path): + self.translator = QTranslator() + self.translator.load(locale_path, 'qdeeplandia') + QCoreApplication.installTranslator(self.translator) + print("TRANSLATION LOADED", locale_path) + + self.mapCanvas.currentLayerChanged.connect(self.updateLayer) + self.isready.connect(self.ready) def initGui(self): # Select a trained model on the file system - load_model_msg = "Load a trained model" + self.initProcessing() + + self.toolbar = QToolBar(tr("QDeepLandia_toolbar")) + self.toolbar.setObjectName("QDeepLandia_toolbar") + # self.toolbar.setMaximumWidth(180) + self.toolbar.addWidget(QLabel(tr("QDeeplandia"))) + self.iface.addToolBar(self.toolbar) + + # Load model process + load_model_msg = tr("Load a trained model") load_icon = QIcon(os.path.join(os.path.dirname(__file__), "img/load.svg")) self.model_loading = QAction(load_icon, load_model_msg, self.iface.mainWindow()) - self.model_loading.triggered.connect(self.load_trained_model) - self.iface.addPluginToMenu("QDeeplandia", self.model_loading) self.model_loading.triggered.connect(lambda: self.load_trained_model()) - self.iface.addToolBarIcon(self.model_loading) + self.iface.addPluginToMenu("QDeeplandia", self.model_loading) + self.toolbar.addAction(self.model_loading) + # Run-an-inference process - run_inference_msg = "Run an inference" + run_inference_msg = tr("Run an inference") run_icon = QIcon(os.path.join(os.path.dirname(__file__), "img/run.svg")) self.inference = QAction(run_icon, run_inference_msg, self.iface.mainWindow()) - self.inference.triggered.connect(self.infer) - self.iface.addPluginToMenu("QDeeplandia", self.inference) self.inference.triggered.connect(lambda: self.infer()) - self.iface.addToolBarIcon(self.inference) + self.iface.addPluginToMenu("QDeeplandia", self.inference) + self.toolbar.addAction(self.inference) + self.inference.setEnabled(False) + + # Use canvas parameters + self.canvasCheckbox = QCheckBox(tr('Use canvas extent')) + self.toolbar.addWidget(self.canvasCheckbox) + + def initProcessing(self): + self.deepOprovider = QDeepLandiaProvider() + QgsApplication.processingRegistry().addProvider(self.deepOprovider) def unload(self): # Select a trained model on the file system self.iface.removePluginMenu("QDeeplandia", self.model_loading) - self.iface.removeToolBarIcon(self.model_loading) + self.toolbar.setParent(None) self.model_loading.setParent(None) # Run-an-inference process self.iface.removePluginMenu("QDeeplandia", self.inference) - self.iface.removeToolBarIcon(self.inference) self.inference.setParent(None) + QgsApplication.processingRegistry().removeProvider(self.deepOprovider) + + def tr(message): + """Get the translation for a string using Qt translation API. + """ + # noinspection PyTypeChecker,PyArgumentList,PyCallByClass + return QCoreApplication.translate('@default', message) def load_trained_model(self): - pass + """Load a h5 model""" + self.model_path, __ = QFileDialog.getOpenFileName(None, + tr("Load best-model-*.h5 file"), + os.path.abspath("."), + tr("h5 file (*.h5)")) + + if not self.model_path : + return + + nbLabelDlg = NbLabelDialog(self) + + if nbLabelDlg.exec(): + self.nb_labels = nbLabelDlg.param() + else : + return + + self.image_size = os.path.splitext(os.path.basename(self.model_path))[0].split('-')[-1] + try : + self.model = get_trained_model(self.model_path, int(self.image_size), int(self.nb_labels)) + except ValueError as e: + self.iface.messageBar().pushMessage(tr("Critical"), + str(e), level=Qgis.Critical) + + if self.model : + self.updateLayer() def infer(self): - pass + """Launch inference on the current layer""" + extent = None + if self.canvasCheckbox.checkState() : + extent = self.mapCanvas.extent() + + def addOutput(layer): + self.inference.setEnabled(True) + if layer : + self.iface.addRasterLayer(layer) + + task = InferenceTask('Inference', self.iface, self.layer, self.nb_labels, self.model_path, extent) + task.terminated.connect(addOutput) + self.inference.setEnabled(False) + QgsApplication.taskManager().addTask(task) + + def updateLayer(self): + """Update the current layer""" + layer = self.mapCanvas.currentLayer() + if layer : + if isinstance(layer.dataProvider(), QgsRasterDataProvider): + self.layer = layer + else : + self.layer = None + else : + self.layer = None + self.isready.emit() + + def ready(self) : + if self.layer and self.model : + self.inference.setEnabled(True)