Skip to content

Commit

Permalink
update to July24 (Edelweiss)
Browse files Browse the repository at this point in the history
  • Loading branch information
DRRV committed Nov 18, 2019
1 parent 5328af9 commit 3a3e179
Showing 1 changed file with 71 additions and 99 deletions.
170 changes: 71 additions & 99 deletions TranskribusDU/tasks/DU_StAZH_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,9 @@
First DU task for StAZH
Copyright Xerox(C) 2016 JL. Meunier
Copyright Naver (C) 2019 H. Déjean
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
Expand All @@ -33,116 +30,91 @@
import TranskribusDU_version

from common.trace import traceln
from tasks import _checkFindColDir, _exit

from crf.Graph_MultiPageXml import Graph_MultiPageXml
from crf.NodeType_PageXml import NodeType_PageXml
from DU_CRF_Task import DU_CRF_Task


# ===============================================================================================================
#DEFINING THE CLASS OF GRAPH WE USE
DU_GRAPH = Graph_MultiPageXml
nt = NodeType_PageXml("TR" #some short prefix because labels below are prefixed with it
, ['catch-word', 'header', 'heading', 'marginalia', 'page-number'] #EXACTLY as in GT data!!!!
, [] #no ignored label/ One of those above or nothing, otherwise Exception!!
, True #no label means OTHER
)
nt.setXpathExpr( (".//pc:TextRegion" #how to find the nodes
, "./pc:TextEquiv") #how to get their text
)
DU_GRAPH.addNodeType(nt)
# ===============================================================================================================


class DU_StAZH_a(DU_CRF_Task):
from tasks.DU_Task_Factory import DU_Task_Factory
from graph.Graph_Multi_SinglePageXml import Graph_MultiSinglePageXml
from graph.NodeType_PageXml import NodeType_PageXml_type
from graph.FeatureDefinition_PageXml_std import FeatureDefinition_PageXml_StandardOnes
from graph.NodeType_PageXml import NodeType_PageXml_type_woText, NodeType_PageXml_type
from graph.FeatureDefinition_PageXml_std_noText_v4 import FeatureDefinition_PageXml_StandardOnes_noText_v4


def getConfiguredGraphClass(doer):
"""
We will do a CRF model for a DU task
, working on a MultiPageXml document at TextRegion level
, with the below labels
In this class method, we must return a configured graph class
"""

#=== CONFIGURATION ====================================================================
def __init__(self, sModelName, sModelDir, sComment=None, C=None, tol=None, njobs=None, max_iter=None, inference_cache=None):

DU_CRF_Task.__init__(self
, sModelName, sModelDir
, DU_GRAPH
, dFeatureConfig = {
'n_tfidf_node' : 500
, 't_ngrams_node' : (2,4)
, 'b_tfidf_node_lc' : False
, 'n_tfidf_edge' : 250
, 't_ngrams_edge' : (2,4)
, 'b_tfidf_edge_lc' : False
}
# , dLearnerConfig = {
# 'C' : .1
# # 'C' : 1.0
# , 'njobs' : 4
# , 'inference_cache' : 50
# , 'tol' : .1
# # , 'tol' : 0.05
# , 'save_every' : 50 #save every 50 iterations,for warm start
# , 'max_iter' : 250
# }
# }
, dLearnerConfig = {
'C' : .1 if C is None else C
, 'njobs' : 5 if njobs is None else njobs
, 'inference_cache' : 50 if inference_cache is None else inference_cache
#, 'tol' : .1
, 'tol' : .05 if tol is None else tol
, 'save_every' : 50 #save every 50 iterations,for warm start
, 'max_iter' : 1000 if max_iter is None else max_iter
}
, sComment=sComment
)
#deprecated self.setNbClass(5+1)
self.addBaseline_LogisticRegression() #use a LR model as baseline
#=== END OF CONFIGURATION =============================================================
#DU_GRAPH = ConjugateSegmenterGraph_MultiSinglePageXml # consider each page as if independent from each other
DU_GRAPH = Graph_MultiSinglePageXml

ntClass = NodeType_PageXml_type_woText #NodeType_PageXml_type

#lIgnoredLabels = ['menu-section-heading','Item-number']

lLabels = ['catch-word', 'header', 'heading', 'marginalia', 'page-number']

nt = ntClass("TR" #some short prefix because labels below are prefixed with it
, lLabels # in conjugate, we accept all labels, and None becomes "none"
, []
, True # unused
, BBoxDeltaFun=lambda v: max(v * 0.066, min(5, v/3)) #we reduce overlap in this way
)
nt.setLabelAttribute("type")
#DU_GRAPH.sEdgeLabelAttribute="TR"
nt.setXpathExpr((".//pc:TextRegion"
, ".//pc:TextEquiv") #how to get their text
)
DU_GRAPH.addNodeType(nt)

return DU_GRAPH


if __name__ == "__main__":
# import better_exceptions
# better_exceptions.MAX_LENGTH = None

# standard command line options for CRF-, ECN- and GAT-based methods
usage, parser = DU_Task_Factory.getStandardOptionsParser(sys.argv[0])

version = "v.01"
usage, description, parser = DU_CRF_Task.getBasicTrnTstRunOptionParser(sys.argv[0], version)
traceln("VERSION: %s" % DU_Task_Factory.getVersion())

# ---
#parse the command line
(options, args) = parser.parse_args()
# ---

cFeatureDefinition = FeatureDefinition_PageXml_StandardOnes_noText_v4
# dFeatureConfig = {
# 'n_tfidf_node':400, 't_ngrams_node':(1,3), 'b_tfidf_node_lc':False
# , 'n_tfidf_edge':400, 't_ngrams_edge':(1,3), 'b_tfidf_edge_lc':False }
try:
sModelDir, sModelName = args
except Exception as e:
_exit(usage, 1, e)

doer = DU_StAZH_a(sModelName, sModelDir,
C = options.crf_C,
tol = options.crf_tol,
njobs = options.crf_njobs,
max_iter = options.max_iter,
inference_cache = options.crf_inference_cache)
traceln("Specify a model folder and a model name!")
DU_Task_Factory.exit(usage, 1, e)

doer = DU_Task_Factory.getDoer(sModelDir, sModelName
, options = options
, fun_getConfiguredGraphClass= getConfiguredGraphClass
, cFeatureDefinition = cFeatureDefinition
, dFeatureConfig = {}
)

if options.rm:
doer.rm()
sys.exit(0)
# setting the learner configuration, in a standard way
# (from command line options, or from a JSON configuration file)
dLearnerConfig = doer.getStandardLearnerConfig(options)

traceln("- classes: ", DU_GRAPH.getLabelNameList())

# of course, you can put yours here instead.
doer.setLearnerConfiguration(dLearnerConfig)

#doer.setConjugateMode()

#Add the "col" subdir if needed
lTrn, lTst, lRun = [_checkFindColDir(lsDir) for lsDir in [options.lTrn, options.lTst, options.lRun]]

if lTrn:
doer.train_save_test(lTrn, lTst, options.warm)
elif lTst:
doer.load()
tstReport = doer.test(lTst)
traceln(tstReport)
# act as specified on the command line (--trn, --fold-run, ...)
doer.standardDo(options)

if lRun:
doer.load()
lsOutputFilename = doer.predict(lRun)
traceln("Done, see in:\n %s"%lsOutputFilename)
del doer







0 comments on commit 3a3e179

Please sign in to comment.