Skip to content

Commit

Permalink
See release notes: evaluations, g1o, oracle, mlflow...
Browse files Browse the repository at this point in the history
  • Loading branch information
jlmeunier committed Aug 7, 2020
1 parent e838055 commit 323e611
Show file tree
Hide file tree
Showing 3 changed files with 257 additions and 38 deletions.
9 changes: 9 additions & 0 deletions RELEASE_NOTES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,15 @@
RELEASE NOTES - TranskribusDU
-----------------------------

--- Vulnéraire - 2020-08-07
- added support to evaluate the segmentation (see --eval_.. options)
- added support for oracle evaluation (--edge_oracle) indicating best achievable quality
- added support to record experiences in MLFLOW (see in options)
- fix for constructing graph when objects overlap each other (--g1o option)
- related to --g1o, better support for constructing object's bounding-box
- few bug and upgrade fix (e.g. in TestReport class, due to scipy evolution, ...)


--- Chrysanthème - 2019-11-21
- ICDAR19 papers are reproducible
- major code reorganisation
Expand Down
215 changes: 190 additions & 25 deletions TranskribusDU/tasks/DU_Task.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,21 @@
from common.trace import trace, traceln
from common.chrono import chronoOn, chronoOff, pretty_time_delta
from common.TestReport import TestReportConfusion
import util.Tracking as Tracking
import util.metrics as metrics
from xml_formats.PageXml import MultiPageXml

from graph.GraphModel import GraphModel, GraphModelException, GraphModelNoEdgeException
from graph.Graph import GraphException
from graph.Graph_JsonOCR import Graph_JsonOCR
from graph.Graph_DOM import Graph_DOM
import graph.FeatureDefinition
from tasks import _checkFindColDir

from .DU_Table.DU_Table_Evaluator import eval_cluster_of_files, computePRF

# to activate the use of an oracle for predicting continue or break
bORACLE_EDGE_BREAK_CONTINUE = False

class DU_Task:
"""
Expand Down Expand Up @@ -165,19 +172,29 @@ def __init__(self, sModelName, sModelDir
def getVersion(cls):
return str(cls.VERSION)

def standardDo(self, options):
def standardDo(self, options, experiment_name="DU"):
"""
do whatever is requested by an option from the parsed command line
return None
"""
global bORACLE_EDGE_BREAK_CONTINUE

Tracking.set_no_tracking()

if options.bEdgeOracle:
traceln("*** EDGE ORACLE: predict 'break' or 'continue' label from the groundtruth ***")
bORACLE_EDGE_BREAK_CONTINUE = True

if bool(options.iServer):
assert not bORACLE_EDGE_BREAK_CONTINUE
self.load()
# run in server mode!
self.serve_forever(options.iServer, options.bServerDebug, options=options)
return

if options.rm:
assert not bORACLE_EDGE_BREAK_CONTINUE
self.rm()
return

Expand All @@ -198,6 +215,7 @@ def standardDo(self, options):
#doer.sXmlFilenamePattern = doer.sLabeledXmlFilenamePattern

if options.iFoldInitNum or options.iFoldRunNum or options.bFoldFinish:
assert not bORACLE_EDGE_BREAK_CONTINUE
if options.iFoldInitNum:
"""
initialization of a cross-validation
Expand All @@ -219,31 +237,97 @@ def standardDo(self, options):


if lFold:
assert not bORACLE_EDGE_BREAK_CONTINUE
loTstRpt = self.nfold_Eval(lFold, 3, .25, None, options.bPkl)
sReportPickleFilename = os.path.join(self.sModelDir, self.sModelName + "__report.txt")
traceln("Results are in %s"%sReportPickleFilename)
GraphModel.gzip_cPickle_dump(sReportPickleFilename, loTstRpt)
elif lTrn or lTst or lRun:

if lTrn or lTst or (lRun and (options.bEvalRow or options.bEvalCol or options.bEvalCell or options.sEvalCluster or options.bEvalClusterLevel)):
options.bMLFlow = options.bMLFlow or options.sMLFlowExp # force it!
# ---------- Tracking stuff
if options.sMLFlowURI:
if options.sMLFlowURI == "-" or options.sMLFlowURI.startswith("file"):
# tracking in local files
Tracking.set_tracking()
else:
Tracking.set_tracking_uri(options.sMLFlowURI)
elif options.bMLFlow:
Tracking.set_tracking_uri()
else:
Tracking.set_no_tracking()

# MLFLow Experiment name
_s = options.sMLFlowExp if options.sMLFlowExp else experiment_name
Tracking.set_experiment(_s)
traceln("Tracking experiment = ", _s)
# MLFLow Run name
_s = options.sMLFlowRun if options.sMLFlowRun else self.sModelName
Tracking.start_run(_s)
traceln("Tracking run = ", _s)
if os.environ.get("SLURM_JOB_ID"): Tracking.log_param("SLURM_JOB_ID", os.environ.get("SLURM_JOB_ID"))

Tracking.log_artifact_string("General", json.dumps({
"main" : str(os.path.abspath(sys.argv[0]))
, "main.args" : str(sys.argv[1:])
, "main.graph_class" : self.getGraphClass().__name__
, "main.graph_mode" : self.getGraphClass().getGraphMode()
, "main.ModelDir" : os.path.abspath(self.sModelDir)
, "main.ModelName" : self.sModelName
, "main.model_class" : self.getModelClass().__name__
, "main.seed" : options.seed
, "main.ext" : options.sExt
, 'main.bWarm' : options.warm
}, indent=True)
)
Tracking.log_artifact_string("Options", str(options))
Tracking.log_artifact_string("Options.True", str({k:v for k,v in options.__dict__.items() if bool(v)}))

if lTrn or lTst:
_dCfg = self.getStandardLearnerConfig(options)
# Tracking.log_params(_dCfg)
Tracking.log_artifact_string("LearningParam"
, json.dumps(_dCfg, indent=True))
Tracking.log_artifact_string("Data", json.dumps({'lTrn':lTrn, 'lVld':lVld, 'lTst':lTst, 'ratio_train_val':ratio_train_val}
, indent=True))

if lTrn:
assert not bORACLE_EDGE_BREAK_CONTINUE
tstReport = self.train_save_test(lTrn, lTst, lVld, options.warm, options.bPkl
, ratio_train_val=ratio_train_val)
try: traceln("Baseline best estimator: %s"%self.bsln_mdl.best_params_) #for GridSearch
except: pass
traceln(self.getModel().getModelInfo())
Tracking.log_artifact_string("Model", self.getModel().getModelInfo())
if lTst:
traceln(tstReport)
Tracking.log_artifact_string("test_report", tstReport)
# Return global micro- Precision/Recall/F1, accuracy, support
_p,_r,_f,_a,_s = metrics.confusion_PRFAS(tstReport.getConfusionMatrix())
Tracking.log_metrics({'avgP':_p, 'avgR':_r, 'F1':_f, 'Accuracy':_a}, ndigits=3)
if options.bDetailedReport:
traceln(tstReport.getDetailledReport())
_sTstRpt = tstReport.getDetailledReport()
traceln(_sTstRpt)
Tracking.log_artifact_string("test_report_detailed", _sTstRpt)
elif lTst:
assert not bORACLE_EDGE_BREAK_CONTINUE
self.load()
tstReport = self.test(lTst)
traceln(tstReport)
Tracking.log_artifact_string("test_report", tstReport)
# Return global micro- Precision/Recall/F1, accuracy, support
_p,_r,_f,_a,_s = metrics.confusion_PRFAS(tstReport.getConfusionMatrix())
Tracking.log_metrics({'avgP':_p, 'avgR':_r, 'avgF1':_f, 'Accuracy':_a}, ndigits=3)
# details ...
if options.bDetailedReport:
traceln(tstReport.getDetailledReport())
for test in lTst:
sReportPickleFilename = os.path.join('..',test, self.sModelName + "__report.pkl")
traceln('Report dumped into %s'%sReportPickleFilename)
GraphModel.gzip_cPickle_dump(sReportPickleFilename, tstReport)
_sTstRpt = tstReport.getDetailledReport()
traceln(_sTstRpt)
Tracking.log_artifact_string("test_report_detailed", _sTstRpt)
# for test in lTst:
# sReportPickleFilename = os.path.join('..',test, self.sModelName + "__report.pkl")
# traceln('Report dumped into %s'%sReportPickleFilename)
# GraphModel.gzip_cPickle_dump(sReportPickleFilename, tstReport)

if lRun:
# if options.storeX or options.applyY:
Expand All @@ -253,26 +337,84 @@ def standardDo(self, options):
# else:
self.load()
lsOutputFilename = self.predict(lRun, bGraph=options.bGraph,bOutXML=options.bOutXML)
if options.bEvalRow or options.bEvalCol or options.bEvalCell or options.bEvalRegion or options.sEvalCluster:
if options.sEvalCluster:
sLevel = options.sEvalCluster # only used for display n this case
l = self.getGraphClass().getNodeTypeList()
assert len(l) == 1, "Cannot compute cluster quality with multiple node types"
xpSelector = l[0].getXpathExpr()[0] # node selector, text selector
nOk, nErr, nMiss, sRpt = eval_cluster_of_files(lsOutputFilename
, "cluster"
, bIgnoreHeader=False
, bIgnoreOutOfTable=True
, xpSelector=xpSelector
, sClusterGTAttr=options.sEvalCluster
)
else:
if options.bEvalRow: sLevel = "row"
elif options.bEvalCol: sLevel = "col"
elif options.bEvalCell: sLevel = "cell"
elif options.bEvalRegion: sLevel = "region"
else: raise ValueError()
nOk, nErr, nMiss, sRpt = eval_cluster_of_files(lsOutputFilename
, sLevel
, bIgnoreHeader=False
, bIgnoreOutOfTable=True
)
fP, fR, fF = computePRF(nOk, nErr, nMiss)
traceln(sRpt)
Tracking.log_artifact_string("Cluster_eval", sRpt)
Tracking.log_metrics({ ('P_%s'%sLevel) :fP
, ('R_%s'%sLevel) :fR
, ('F1_%s'%sLevel):fF}, ndigits=2)
elif options.bEvalClusterLevel:
l = self.getGraphClass().getNodeTypeList()
assert len(l) == 1, "Cannot compute cluster quality with multiple node types"
nt = l[0] #unique node type
xpSelector = nt.getXpathExpr()[0] # node selector, text selector
for lvl in range(self.getGraphClass().getHierarchyDepth()):
nOk, nErr, nMiss, sRpt = eval_cluster_of_files(lsOutputFilename
, "cluster_lvl%d"%lvl
, bIgnoreHeader=False
, bIgnoreOutOfTable=False
, xpSelector=xpSelector
, sClusterGTAttr=nt.getLabelAttribute()[lvl]
)
fP, fR, fF = computePRF(nOk, nErr, nMiss)
traceln(sRpt)
Tracking.log_artifact_string("Cluster_eval_lvl%d"%lvl, sRpt)
Tracking.log_metrics({ ('lvl%d_P' %lvl):fP
, ('lvl%d_R' %lvl):fR
, ('lvl%d_F1'%lvl):fF}, ndigits=2)

traceln("Done, see in:\n %s"%lsOutputFilename)
else:
traceln("No action specified in command line. Doing nothing... :)")

Tracking.end_run("FINISHED")
return

def __del__(self):
"""
trying to clean big objects
"""
del self._mdl
del self._lBaselineModel
del self.cFeatureDefinition
del self.cModelClass
try:
del self._mdl
del self._lBaselineModel
del self.cFeatureDefinition
del self.cModelClass
except:
pass
self._mdl = None
self._lBaselineModel = None
self.cFeatureDefinition = None
self.cModelClass = None

#--- SERVER MODE ---------------------------------------------------------
def serve_forever(self, iPort, bDebug=False, options={}):
self.sTime_start = datetime.datetime.now().isoformat()
self.sTime_load = self.sTime_start
self.save_self = self

import socket
sURI = "http://%s:%d" % (socket.gethostbyaddr(socket.gethostname())[0], iPort)
Expand All @@ -287,6 +429,7 @@ def serve_forever(self, iPort, bDebug=False, options={}):
from flask import request, abort
from flask import render_template_string #, render_template
from flask import redirect, url_for #, send_from_directory, send_file
from flask import Response


# Create Flask app load app.config
Expand Down Expand Up @@ -340,7 +483,8 @@ def predict():
if not(isinstance(doc, etree._ElementTree)):
traceln(" converting to PageXml...")
doc = Graph_DOM.exportToDom(lg)
return etree.tostring(doc.getroot(), encoding='UTF-8', xml_declaration=False)
sResp = etree.tostring(doc.getroot(), encoding='UTF-8', xml_declaration=False)
return Response(sResp, mimetype="application/xml")

except Exception as e:
traceln("----- predict exception -------------------------")
Expand All @@ -353,21 +497,26 @@ def reload():
"""
Force to reload the model
"""
self.load(bForce=True)
traceln("Reloading the model")
self.save_self.load(bForce=True)
self.sTime_load = datetime.datetime.now().isoformat()
return redirect(url_for('home_page'))

# RUN THE SERVER !!
# CAUTION: TensorFlow incompatible with debug=True (double load => GPU issue)
app.run(host='0.0.0.0', port=iPort, debug=bDebug)

@app.route('/stop')
def stop():
"""
Force to exit
"""
traceln("Exiting")
sys.exit(0)
traceln("Trying to stop the server...")
func = request.environ.get('werkzeug.server.shutdown')
if func is None:
traceln("Exiting! (but this may stop only one process..)")
sys.exit(0)
else:
traceln("Shutting down")
func()
return redirect(url_for('home_page'))

# RUN THE SERVER !!
# CAUTION: TensorFlow incompatible with debug=True (double load => GPU issue)
app.run(host='0.0.0.0', port=iPort, debug=bDebug)
Expand Down Expand Up @@ -564,7 +713,7 @@ def predict(self, lsColDir, docid=None, bGraph=False, bOutXML=True):
"""
Return the list of produced files
"""
if not self._mdl: raise Exception("The model must be loaded beforehand!")
if not self._mdl and not bORACLE_EDGE_BREAK_CONTINUE: raise Exception("The model must be loaded beforehand!")

#list files
if docid is None:
Expand Down Expand Up @@ -593,8 +742,12 @@ def predict(self, lsColDir, docid=None, bGraph=False, bOutXML=True):
if DU_GraphClass.isOutputFilename(sFilename):
traceln(" - ignoring '%s' because of its extension" % sFilename)
continue

doc, lg = self._predict_file(DU_GraphClass, lPageConstraint, sFilename, bGraph=bGraph)
try:
doc, lg = self._predict_file(DU_GraphClass, lPageConstraint, sFilename, bGraph=bGraph)
except GraphException as e:
doc = None
traceln(str(e))
chronoOff("predict_1") # not nice, I know....

if doc is None:
self.traceln("\t- no prediction to do for: %s"%sFilename)
Expand Down Expand Up @@ -633,7 +786,10 @@ def _predict_file(self, DU_GraphClass, lPageConstraint, sFilename, bGraph=False)
"""
chronoOn("predict_1")
doc = None
lg = DU_GraphClass.loadGraphs(self.cGraphClass, [sFilename], bDetach=False, bLabelled=False, iVerbose=1)
lg = DU_GraphClass.loadGraphs(self.cGraphClass, [sFilename]
, bDetach=False
, bLabelled=bORACLE_EDGE_BREAK_CONTINUE
, iVerbose=1)

#normally, we get one graph per file, but in case we load one graph per page, for instance, we have a list
for i, g in enumerate(lg):
Expand All @@ -655,7 +811,12 @@ def _predict_graph(self, g, lPageConstraint=None, bGraph=False):
return the graph
"""
try:
Y = self._mdl.predict(g, bProba=g.bConjugate)
if bORACLE_EDGE_BREAK_CONTINUE:
Y = np.zeros((len(g.lEdge), 2), dtype=np.float)
for i, e in enumerate(g.lEdge):
Y[i, 1-int(e.A.cls == e.B.cls)] = 1.0
else:
Y = self._mdl.predict(g, bProba=g.bConjugate)
g.setDocLabels(Y)
if bGraph and not Y is None:
if g.bConjugate:
Expand Down Expand Up @@ -930,7 +1091,11 @@ def _train_save_test(self, sModelName, bWarm, lFilename_trn, ts_trn, lFilename_t
try:
mdl.loadTransformers(ts_trn)
except GraphModelException:
fe = self.cFeatureDefinition(**self.config_extractor_kwargs)
try:
fe = self.cFeatureDefinition(**self.config_extractor_kwargs)
except Exception as e:
traceln("ERROR: could not instantiate feature definition class: ", str(self.cFeatureDefinition))
raise e
fe.fitTranformers(lGraph_trn)
fe.cleanTransformers()
mdl.setTranformers(fe.getTransformers())
Expand Down
Loading

0 comments on commit 323e611

Please sign in to comment.