diff --git a/.gitignore b/.gitignore
index e4e5f6c..f5e5e8f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,3 @@
-*~
\ No newline at end of file
+*~
+config.ini
+data
\ No newline at end of file
diff --git a/Makefile b/Makefile
index 0cdad25..fe42686 100644
--- a/Makefile
+++ b/Makefile
@@ -8,7 +8,13 @@ endif
SOURCES=__init__.py \
metadata.txt \
- qdeeplandia.py
+ qdeeplandia.py \
+ config.ini \
+ inferenceTask.py \
+ feedback.py \
+ gui \
+ img \
+ processing_provider
ZIP_FILE=$(PLUGIN_NAME)-$(VERSION).zip
diff --git a/config.ini.sample b/config.ini.sample
new file mode 100644
index 0000000..b56755e
--- /dev/null
+++ b/config.ini.sample
@@ -0,0 +1,12 @@
+[status]
+status = dev
+
+[running]
+processes = 1
+
+[symlink]
+aerial = /path/to/aerial/dataset/
+tanzania = /path/to/tanzania/dataset/
+
+[folder]
+project_folder = /path/to/static/files/
diff --git a/feedback.py b/feedback.py
new file mode 100644
index 0000000..4eee0d0
--- /dev/null
+++ b/feedback.py
@@ -0,0 +1,20 @@
+from qgis.core import Qgis, QgsProcessingFeedback, QgsMessageLog
+
+class Feedback(QgsProcessingFeedback):
+ """To provide feedback to the message bar from the express tools"""
+
+ def __init__(self, iface):
+ super().__init__()
+ self.iface = iface
+ self.fatal_errors = []
+
+ def reportError(self, error, fatalError=False):
+ QgsMessageLog.logMessage(str(error), "QDeeplandia")
+ if fatalError:
+ self.fatal_errors.append(error)
+
+ def pushToUser(self, exception):
+ QgsMessageLog.logMessage(str(exception), "QDeeplandia")
+ self.iface.messageBar().pushMessage(
+ "Error", ", ".join(self.fatal_errors), level=Qgis.Critical, duration=0
+ )
\ No newline at end of file
diff --git a/gui/NbLabelDialog.py b/gui/NbLabelDialog.py
new file mode 100644
index 0000000..5633192
--- /dev/null
+++ b/gui/NbLabelDialog.py
@@ -0,0 +1,23 @@
+from qgis.PyQt.QtWidgets import QDialog, QWidget, QHBoxLayout, QVBoxLayout, QLabel, QSpinBox, QDialogButtonBox
+
+class NbLabelDialog(QDialog):
+ def __init__(self,parent):
+ super(NbLabelDialog, self).__init__()
+
+ self.VL = QVBoxLayout(self)
+ self.HL = QHBoxLayout()
+ self.VL.addLayout(self.HL)
+
+ self.label = QLabel(self.tr('Number of label : '))
+ self.HL.addWidget(self.label)
+
+ self.spinbox = QSpinBox()
+ self.HL.addWidget(self.spinbox)
+
+ self.buttonBox = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel)
+ self.buttonBox.accepted.connect(self.accept)
+ self.buttonBox.rejected.connect(self.reject)
+ self.VL.addWidget(self.buttonBox)
+
+ def param(self):
+ return self.spinbox.value()
\ No newline at end of file
diff --git a/gui/__init__.py b/gui/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/img/load.svg b/img/load.svg
index e69de29..8e90d7b 100644
--- a/img/load.svg
+++ b/img/load.svg
@@ -0,0 +1,70 @@
+
+
+
+
diff --git a/img/run.svg b/img/run.svg
index e69de29..5993738 100644
--- a/img/run.svg
+++ b/img/run.svg
@@ -0,0 +1,263 @@
+
+
+
+
diff --git a/inferenceTask.py b/inferenceTask.py
new file mode 100644
index 0000000..0529c95
--- /dev/null
+++ b/inferenceTask.py
@@ -0,0 +1,38 @@
+import os
+import sys
+
+from .feedback import Feedback
+from qgis.core import Qgis, QgsTask, QgsMessageLog, QgsProcessingContext
+import processing
+import random, string
+
+from qgis.PyQt.QtCore import pyqtSignal
+
+from .processing_provider.inference import InferenceQDeepLandiaProcessingAlgorithm
+
+
+class InferenceTask(QgsTask):
+ """InferenceTask is a QgsTask subclass"""
+
+ terminated = pyqtSignal(str)
+
+ def __init__(self, description, iface, layer, nb_label, model_path, extent=None):
+ super().__init__(description, QgsTask.CanCancel)
+ self.feedback = Feedback(iface)
+ tmp_name = processing.getTempFilename() + '.tif'
+ self.param = { 'INPUT' : layer.id(), 'OUTPUT' : os.path.join(tmp_name), 'LABELS' : nb_label, 'MODEL' : model_path }
+ if extent :
+ self.param['EXTENT'] = extent
+
+ def run(self):
+ out = processing.run('QDeepLandia:InferenceQDeepLandia', self.param, feedback=self.feedback)
+ if os.path.exists(out['OUTPUT']):
+ self.terminated.emit(out['OUTPUT'])
+ return True
+
+ def cancel(self):
+ QgsMessageLog.logMessage(
+ 'Task "{name}" was canceled'.format(
+ name=self.description()), "QDeeplandia")
+ self.terminated.emit(None)
+ super().cancel()
diff --git a/metadata.txt b/metadata.txt
index 916aad4..f3d1562 100644
--- a/metadata.txt
+++ b/metadata.txt
@@ -6,3 +6,4 @@ qgisMinimumVersion=3.00
qgisMaximumVersion=3.99
author=Oslandia
email=infos@oslandia.com
+hasProcessingProvider=yes
\ No newline at end of file
diff --git a/processing_provider/__init__.py b/processing_provider/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/processing_provider/datagen.py b/processing_provider/datagen.py
new file mode 100644
index 0000000..86b6a52
--- /dev/null
+++ b/processing_provider/datagen.py
@@ -0,0 +1,184 @@
+# -*- coding: utf-8 -*-
+
+"""
+***************************************************************************
+* *
+* This program is free software; you can redistribute it and/or modify *
+* it under the terms of the GNU General Public License as published by *
+* the Free Software Foundation; either version 2 of the License, or *
+* (at your option) any later version. *
+* *
+***************************************************************************
+"""
+
+import os
+import shutil
+import subprocess
+
+from qgis.PyQt.QtCore import QCoreApplication
+from qgis.core import (QgsProcessing,
+ QgsFeatureSink,
+ QgsProcessingException,
+ QgsProcessingAlgorithm,
+ QgsProcessingParameterFolderDestination,
+ QgsProcessingParameterRasterLayer,
+ QgsProcessingParameterFile,
+ QgsProcessingParameterString,
+ QgsProcessingParameterNumber)
+from qgis import processing
+
+
+
+class DatagenQDeepLandiaProcessingAlgorithm(QgsProcessingAlgorithm):
+ """
+ """
+
+ # Constants used to refer to parameters and outputs. They will be
+ # used when calling the algorithm from another algorithm, or when
+ # calling from the QGIS console.
+
+ INPUT = 'INPUT'
+ DATASET = 'DATASET'
+ SHAPE = 'SHAPE'
+ OUTPUT = 'OUTPUT'
+
+ def tr(self, string):
+ """
+ Returns a translatable string with the self.tr() function.
+ """
+ return QCoreApplication.translate('Processing', string)
+
+ def createInstance(self):
+ return DatagenQDeepLandiaProcessingAlgorithm()
+
+ def name(self):
+ """
+ Returns the algorithm name, used for identifying the algorithm. This
+ string should be fixed for the algorithm, and must not be localised.
+ The name should be unique within each provider. Names should contain
+ lowercase alphanumeric characters only and no spaces or other
+ formatting characters.
+ """
+ return 'DatagenQDeepLandia'
+
+ def displayName(self):
+ """
+ Returns the translated algorithm name, which should be used for any
+ user-visible display of the algorithm name.
+ """
+ return self.tr('Datageneration')
+
+ def group(self):
+ """
+ Returns the name of the group this algorithm belongs to. This string
+ should be localised.
+ """
+ return self.tr('QDeepLandia')
+
+ def groupId(self):
+ """
+ Returns the unique ID of the group this algorithm belongs to. This
+ string should be fixed for the algorithm, and must not be localised.
+ The group id should be unique within each provider. Group id should
+ contain lowercase alphanumeric characters only and no spaces or other
+ formatting characters.
+ """
+ return 'QDeepLandia'
+
+ def shortHelpString(self):
+ """
+ Returns a localised short helper string for the algorithm. This string
+ should provide a basic description about what the algorithm does and the
+ parameters and outputs associated with it..
+ """
+ return self.tr("Preprocess layer into predictable tiles")
+
+ def initAlgorithm(self, config=None):
+ """
+ Here we define the inputs and output of the algorithm, along
+ with some other properties.
+ """
+
+ # We add the input vector features source. It can have any kind of
+ # geometry.
+ self.addParameter(
+ QgsProcessingParameterRasterLayer(
+ self.INPUT,
+ self.tr('Input layer')
+ )
+ )
+
+ # We add a feature sink in which to store our processed features (this
+ # usually takes the form of a newly created vector layer when the
+ # algorithm is run in QGIS).
+ self.addParameter(
+ QgsProcessingParameterString(
+ self.DATASET,
+ self.tr('Dataset name')
+ )
+ )
+
+ self.addParameter(
+ QgsProcessingParameterNumber(
+ self.SHAPE,
+ self.tr('Number of pixel for the side of tiles'),
+ type = QgsProcessingParameterNumber.Integer,
+ defaultValue = 512,
+ minValue = 16
+ )
+ )
+
+ self.addParameter(
+ QgsProcessingParameterFolderDestination(
+ self.OUTPUT,
+ self.tr('Output folder')
+ )
+ )
+
+ def processAlgorithm(self, parameters, context, feedback):
+ """
+ Here is where the processing itself takes place.
+ """
+
+ raster_in = self.parameterAsRasterLayer(
+ parameters,
+ self.INPUT,
+ context
+ )
+
+ dest_path = self.parameterAsString(
+ parameters,
+ self.OUTPUT,
+ context
+ )
+
+ dataset = self.parameterAsString(
+ parameters,
+ self.DATASET,
+ context
+ )
+
+ shape = self.parameterAsInt(
+ parameters,
+ self.SHAPE,
+ context
+ )
+
+ path = ''
+ for i in [dest_path, dataset, 'input', 'testing', 'images']:
+ path = os.path.join(path, i)
+ if not os.path.exists(path):
+ os.mkdir(path)
+
+ for file in os.listdir(path):
+ os.remove(os.path.join(path,file))
+
+ shutil.copy( raster_in.source(), os.path.join( path, os.path.basename( raster_in.source())))
+
+ output_folder = os.path.join(dest_path, dataset, 'preprocessed', str(shape), 'testing', 'images')
+ shutil.rmtree(os.path.join(dest_path, dataset, 'preprocessed', str(shape)))
+
+ cmd = ['deepo', 'datagen', '-D', dataset, '-s', str(shape), '-P', dest_path, '-T', '1']
+ subprocess.run(cmd)
+
+ return {self.OUTPUT: output_folder}
diff --git a/processing_provider/inference.py b/processing_provider/inference.py
new file mode 100644
index 0000000..0c37a93
--- /dev/null
+++ b/processing_provider/inference.py
@@ -0,0 +1,264 @@
+# -*- coding: utf-8 -*-
+
+"""
+***************************************************************************
+* *
+* This program is free software; you can redistribute it and/or modify *
+* it under the terms of the GNU General Public License as published by *
+* the Free Software Foundation; either version 2 of the License, or *
+* (at your option) any later version. *
+* *
+***************************************************************************
+"""
+
+import os
+import sys
+import glob
+import gdal
+import shutil
+import numpy as np
+
+from qgis.PyQt.QtCore import QCoreApplication
+from qgis.core import (QgsProcessing,
+ QgsRasterLayer,
+ QgsFeatureSink,
+ QgsProcessingException,
+ QgsProcessingAlgorithm,
+ QgsProcessingParameterFeatureSource,
+ QgsProcessingParameterFeatureSink,
+ QgsProcessingParameterRasterLayer,
+ QgsProcessingParameterFile,
+ QgsProcessingParameterNumber,
+ QgsProcessingParameterFileDestination,
+ QgsProcessingParameterExtent)
+from qgis import processing
+
+from deeposlandia.inference import predict
+from deeposlandia.postprocess import get_trained_model, extract_images, \
+ extract_coordinates_from_filenames, \
+ build_full_labelled_image, get_labels, \
+ assign_label_colors, draw_grid
+
+class InferenceQDeepLandiaProcessingAlgorithm(QgsProcessingAlgorithm):
+ """
+ """
+
+ # Constants used to refer to parameters and outputs. They will be
+ # used when calling the algorithm from another algorithm, or when
+ # calling from the QGIS console.
+
+ INPUT = 'INPUT'
+ EXTENT = 'EXTENT'
+ MODEL = 'MODEL'
+ LABELS = 'LABELS'
+ OUTPUT = 'OUTPUT'
+
+ def __init__(self, model=None):
+ super().__init__()
+
+ def tr(self, string):
+ """
+ Returns a translatable string with the self.tr() function.
+ """
+ return QCoreApplication.translate('Processing', string)
+
+ def createInstance(self):
+ return InferenceQDeepLandiaProcessingAlgorithm()
+
+ def name(self):
+ """
+ Returns the algorithm name, used for identifying the algorithm. This
+ string should be fixed for the algorithm, and must not be localised.
+ The name should be unique within each provider. Names should contain
+ lowercase alphanumeric characters only and no spaces or other
+ formatting characters.
+ """
+ return 'InferenceQDeepLandia'
+
+ def displayName(self):
+ """
+ Returns the translated algorithm name, which should be used for any
+ user-visible display of the algorithm name.
+ """
+ return self.tr('Inference')
+
+ def group(self):
+ """
+ Returns the name of the group this algorithm belongs to. This string
+ should be localised.
+ """
+ return self.tr('QDeepLandia')
+
+ def groupId(self):
+ """
+ Returns the unique ID of the group this algorithm belongs to. This
+ string should be fixed for the algorithm, and must not be localised.
+ The group id should be unique within each provider. Group id should
+ contain lowercase alphanumeric characters only and no spaces or other
+ formatting characters.
+ """
+ return 'QDeepLandia'
+
+ def shortHelpString(self):
+ """
+ Returns a localised short helper string for the algorithm. This string
+ should provide a basic description about what the algorithm does and the
+ parameters and outputs associated with it..
+ """
+ return self.tr("Do inference according to the loaded model")
+
+ def initAlgorithm(self, config=None):
+ """
+ Here we define the inputs and output of the algorithm, along
+ with some other properties.
+ """
+
+ # We add the input vector features source. It can have any kind of
+ # geometry.
+ self.addParameter(
+ QgsProcessingParameterRasterLayer(
+ self.INPUT,
+ self.tr('Input layer')
+ )
+ )
+
+ self.addParameter(
+ QgsProcessingParameterExtent(
+ self.EXTENT,
+ self.tr('Input extent'),
+ defaultValue= None,
+ optional=True
+ )
+ )
+
+ self.addParameter(
+ QgsProcessingParameterFile(
+ self.MODEL,
+ self.tr('Input model'),
+ extension="h5"
+ )
+ )
+
+ self.addParameter(
+ QgsProcessingParameterNumber(
+ self.LABELS,
+ self.tr('Number of labels used for the inference'),
+ type = QgsProcessingParameterNumber.Integer,
+ defaultValue = 4,
+ minValue = 2,
+ optional = True
+ )
+ )
+
+ # We add a feature sink in which to store our processed features (this
+ # usually takes the form of a newly created vector layer when the
+ # algorithm is run in QGIS).
+ self.addParameter(
+ QgsProcessingParameterFileDestination(
+ self.OUTPUT,
+ self.tr('Output file')
+ )
+ )
+
+ def processAlgorithm(self, parameters, context, feedback):
+ """
+ Here is where the processing itself takes place.
+ """
+
+ raster_in = self.parameterAsRasterLayer(
+ parameters,
+ self.INPUT,
+ context
+ )
+
+ output_path = self.parameterAsString(
+ parameters,
+ self.OUTPUT,
+ context
+ )
+
+ model_path = self.parameterAsString(
+ parameters,
+ self.MODEL,
+ context
+ )
+ nb_labels = self.parameterAsInt(
+ parameters,
+ self.LABELS,
+ context
+ )
+
+ datapath = os.path.abspath(os.path.join(os.path.dirname(model_path), '..', '..', '..', '..'))
+ dataset = os.path.basename(os.path.abspath(os.path.join(os.path.dirname(model_path), '..', '..', '..')))
+ image_size = os.path.splitext(os.path.basename(model_path))[0].split('-')[-1]
+ try :
+ model = get_trained_model(datapath, dataset, int(image_size), int(nb_labels))
+ except:
+ sys.exit()
+
+ extent = self.parameterAsExtent(
+ parameters,
+ self.EXTENT,
+ context
+ )
+
+ param = { 'INPUT': raster_in.id(), 'OUTPUT': datapath, 'DATASET': dataset, 'SHAPE': image_size}
+ if extent.xMinimum() != 0 and extent.xMaximum() != 0:
+ if ((extent.xMaximum() - extent.xMinimum())/raster_in.rasterUnitsPerPixelX() >= int(image_size) and \
+ (extent.yMaximum() - extent.yMinimum())/raster_in.rasterUnitsPerPixelY() >= int(image_size)):
+ clipped = os.path.join(os.path.dirname(output_path), 'clipped.tif')
+ param = { 'INPUT': raster_in.id(), 'PROJWIN': extent, 'OUTPUT': clipped}
+ out = processing.run('gdal:cliprasterbyextent', param, feedback=feedback)
+ param = { 'INPUT': out['OUTPUT'], 'OUTPUT': datapath, 'DATASET': dataset, 'SHAPE': image_size}
+ raster_in = QgsRasterLayer(out['OUTPUT'],'clipped', 'gdal')
+
+ out = processing.run('QDeepLandia:DatagenQDeepLandia', param, feedback=feedback)
+
+ raster_list = glob.glob(os.path.join(out['OUTPUT'],'*.png'))
+ images = extract_images(raster_list)
+ coordinates = extract_coordinates_from_filenames(raster_list)
+ labels = get_labels(datapath, dataset, image_size)
+
+ data = build_full_labelled_image(
+ images,
+ coordinates,
+ model,
+ int(image_size),
+ int(raster_in.width()),
+ int(raster_in.height()),
+ 128
+ )
+
+ colored_data = assign_label_colors(data, labels)
+ colored_data = draw_grid(
+ colored_data, int(raster_in.width()), int(raster_in.height()), int(image_size)
+ )
+ predicted_label_folder = os.path.join(
+ datapath,
+ dataset,
+ "output",
+ "semseg",
+ "predicted_labels"
+ )
+ os.makedirs(predicted_label_folder, exist_ok=True)
+ predicted_label_file = os.path.join(
+ predicted_label_folder,
+ os.path.basename(os.path.splitext(raster_in.source())[0]) + "_" + str(image_size) + ".tif",
+ )
+ ds = gdal.Open(raster_in.source())
+ CreateGeoTiff(predicted_label_file, colored_data, ds.GetGeoTransform(), ds.GetProjection())
+ shutil.copy(predicted_label_file, output_path)
+ return {self.OUTPUT: output_path}
+
+def CreateGeoTiff(outRaster, data, geo_transform, projection):
+ driver = gdal.GetDriverByName('GTiff')
+ rows, cols, no_bands = data.shape
+ DataSet = driver.Create(outRaster, cols, rows, no_bands, gdal.GDT_Byte)
+ DataSet.SetGeoTransform(geo_transform)
+ DataSet.SetProjection(projection)
+
+ data = np.moveaxis(data, -1, 0)
+
+ for i, image in enumerate(data, 1):
+ DataSet.GetRasterBand(i).WriteArray(image)
+ DataSet = None
\ No newline at end of file
diff --git a/processing_provider/provider.py b/processing_provider/provider.py
new file mode 100644
index 0000000..ea4986d
--- /dev/null
+++ b/processing_provider/provider.py
@@ -0,0 +1,35 @@
+from qgis.core import QgsProcessingProvider
+
+from .inference import InferenceQDeepLandiaProcessingAlgorithm
+from .datagen import DatagenQDeepLandiaProcessingAlgorithm
+
+
+class QDeepLandiaProvider(QgsProcessingProvider):
+
+ def loadAlgorithms(self, *args, **kwargs):
+ self.addAlgorithm(InferenceQDeepLandiaProcessingAlgorithm())
+ self.addAlgorithm(DatagenQDeepLandiaProcessingAlgorithm())
+ # add additional algorithms here
+ # self.addAlgorithm(MyOtherAlgorithm())
+
+ def id(self, *args, **kwargs):
+ """The ID of your plugin, used for identifying the provider.
+
+ This string should be a unique, short, character only string,
+ eg "qgis" or "gdal". This string should not be localised.
+ """
+ return 'QDeepLandia'
+
+ def name(self, *args, **kwargs):
+ """The human friendly name of your plugin in Processing.
+
+ This string should be as short as possible (e.g. "Lastools", not
+ "Lastools version 1.0.1 64-bit") and localised.
+ """
+ return self.tr('QDeepLandia')
+
+ def icon(self):
+ """Should return a QIcon which is used for your provider inside
+ the Processing toolbox.
+ """
+ return QgsProcessingProvider.icon(self)
\ No newline at end of file
diff --git a/qdeeplandia.py b/qdeeplandia.py
index 9b06ad2..dcf75d5 100644
--- a/qdeeplandia.py
+++ b/qdeeplandia.py
@@ -17,44 +17,173 @@
import os
+from qgis.core import Qgis, QgsRasterDataProvider, QgsApplication, \
+ QgsProcessingFeedback, QgsMessageLog, QgsProcessingContext
+
+import processing
+
+from qgis.PyQt.QtCore import QSettings, QCoreApplication, pyqtSignal
from qgis.PyQt.QtGui import QIcon
-from qgis.PyQt.QtWidgets import QAction
+from qgis.PyQt.QtWidgets import QAction, QFileDialog, QWidget, \
+ QHBoxLayout, QVBoxLayout, QMessageBox, \
+ QToolBar, QLabel, QCheckBox
+
+os.environ['DEEPOSL_CONFIG'] = os.path.join(os.path.dirname(__file__), 'config.ini')
+from deeposlandia.postprocess import get_trained_model
+
+from .processing_provider.provider import QDeepLandiaProvider
+from .gui.NbLabelDialog import NbLabelDialog
+from .inferenceTask import InferenceTask
-class QDeeplandiaPlugin:
+def tr(message):
+ """Get the translation for a string using Qt translation API.
+ """
+ # noinspection PyTypeChecker,PyArgumentList,PyCallByClass
+ return QCoreApplication.translate('@default', message)
+
+class QDeeplandiaPlugin(QWidget):
+ """ Major class of QDeeplandia plugin """
+
+ isready = pyqtSignal()
+
def __init__(self, iface):
+ """Constructor
+
+ :param iface: qgis interface
+ :type iface:QgisInterface
+ """
+ super(QDeeplandiaPlugin, self).__init__()
self.iface = iface
+ self.mapCanvas = self.iface.mapCanvas()
+ self.model = None
+ self.deepOprovider = None
+ self.layer = self.updateLayer()
+ self.nb_labels = None
+ self.model_path = None
+
+ locale = QSettings().value('locale/userLocale') or 'en_USA'
+ locale = locale[0:2]
+ locale_path = os.path.join(
+ os.path.dirname(__file__),
+ 'i18n',
+ 'thyrsis_{}.qm'.format(locale))
+
+ if os.path.exists(locale_path):
+ self.translator = QTranslator()
+ self.translator.load(locale_path, 'qdeeplandia')
+ QCoreApplication.installTranslator(self.translator)
+ print("TRANSLATION LOADED", locale_path)
+
+ self.mapCanvas.currentLayerChanged.connect(self.updateLayer)
+ self.isready.connect(self.ready)
def initGui(self):
# Select a trained model on the file system
- load_model_msg = "Load a trained model"
+ self.initProcessing()
+
+ self.toolbar = QToolBar(tr("QDeepLandia_toolbar"))
+ self.toolbar.setObjectName("QDeepLandia_toolbar")
+ # self.toolbar.setMaximumWidth(180)
+ self.toolbar.addWidget(QLabel(tr("QDeeplandia")))
+ self.iface.addToolBar(self.toolbar)
+
+ # Load model process
+ load_model_msg = tr("Load a trained model")
load_icon = QIcon(os.path.join(os.path.dirname(__file__), "img/load.svg"))
self.model_loading = QAction(load_icon, load_model_msg, self.iface.mainWindow())
- self.model_loading.triggered.connect(self.load_trained_model)
- self.iface.addPluginToMenu("QDeeplandia", self.model_loading)
self.model_loading.triggered.connect(lambda: self.load_trained_model())
- self.iface.addToolBarIcon(self.model_loading)
+ self.iface.addPluginToMenu("QDeeplandia", self.model_loading)
+ self.toolbar.addAction(self.model_loading)
+
# Run-an-inference process
- run_inference_msg = "Run an inference"
+ run_inference_msg = tr("Run an inference")
run_icon = QIcon(os.path.join(os.path.dirname(__file__), "img/run.svg"))
self.inference = QAction(run_icon, run_inference_msg, self.iface.mainWindow())
- self.inference.triggered.connect(self.infer)
- self.iface.addPluginToMenu("QDeeplandia", self.inference)
self.inference.triggered.connect(lambda: self.infer())
- self.iface.addToolBarIcon(self.inference)
+ self.iface.addPluginToMenu("QDeeplandia", self.inference)
+ self.toolbar.addAction(self.inference)
+ self.inference.setEnabled(False)
+
+ # Use canvas parameters
+ self.canvasCheckbox = QCheckBox(tr('Use canvas extent'))
+ self.toolbar.addWidget(self.canvasCheckbox)
+
+ def initProcessing(self):
+ self.deepOprovider = QDeepLandiaProvider()
+ QgsApplication.processingRegistry().addProvider(self.deepOprovider)
def unload(self):
# Select a trained model on the file system
self.iface.removePluginMenu("QDeeplandia", self.model_loading)
- self.iface.removeToolBarIcon(self.model_loading)
+ self.toolbar.setParent(None)
self.model_loading.setParent(None)
# Run-an-inference process
self.iface.removePluginMenu("QDeeplandia", self.inference)
- self.iface.removeToolBarIcon(self.inference)
self.inference.setParent(None)
+ QgsApplication.processingRegistry().removeProvider(self.deepOprovider)
+
+ def tr(message):
+ """Get the translation for a string using Qt translation API.
+ """
+ # noinspection PyTypeChecker,PyArgumentList,PyCallByClass
+ return QCoreApplication.translate('@default', message)
def load_trained_model(self):
- pass
+ """Load a h5 model"""
+ self.model_path, __ = QFileDialog.getOpenFileName(None,
+ tr("Load best-model-*.h5 file"),
+ os.path.abspath("."),
+ tr("h5 file (*.h5)"))
+
+ if not self.model_path :
+ return
+
+ nbLabelDlg = NbLabelDialog(self)
+
+ if nbLabelDlg.exec():
+ self.nb_labels = nbLabelDlg.param()
+ else :
+ return
+
+ self.image_size = os.path.splitext(os.path.basename(self.model_path))[0].split('-')[-1]
+ try :
+ self.model = get_trained_model(self.model_path, int(self.image_size), int(self.nb_labels))
+ except ValueError as e:
+ self.iface.messageBar().pushMessage(tr("Critical"),
+ str(e), level=Qgis.Critical)
+
+ if self.model :
+ self.updateLayer()
def infer(self):
- pass
+ """Launch inference on the current layer"""
+ extent = None
+ if self.canvasCheckbox.checkState() :
+ extent = self.mapCanvas.extent()
+
+ def addOutput(layer):
+ self.inference.setEnabled(True)
+ if layer :
+ self.iface.addRasterLayer(layer)
+
+ task = InferenceTask('Inference', self.iface, self.layer, self.nb_labels, self.model_path, extent)
+ task.terminated.connect(addOutput)
+ self.inference.setEnabled(False)
+ QgsApplication.taskManager().addTask(task)
+
+ def updateLayer(self):
+ """Update the current layer"""
+ layer = self.mapCanvas.currentLayer()
+ if layer :
+ if isinstance(layer.dataProvider(), QgsRasterDataProvider):
+ self.layer = layer
+ else :
+ self.layer = None
+ else :
+ self.layer = None
+ self.isready.emit()
+
+ def ready(self) :
+ if self.layer and self.model :
+ self.inference.setEnabled(True)