diff --git a/.gitignore b/.gitignore
index b7099d7..f60126f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,3 +3,4 @@
 .cache/
 .settings/
 __pycache__/
+*.bak
diff --git a/LICENSE b/LICENSE
index 42a0ba9..bd95bf6 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,6 +1,7 @@
 BSD 3-Clause License
 
-Copyright (c) 2016, Transkribus
+Copyright (c) 2016-2019, NAVER LABS Europe
+
 All rights reserved.
 
 Redistribution and use in source and binary forms, with or without
diff --git a/README.md b/README.md
index 39dfec0..cbadf75 100644
--- a/README.md
+++ b/README.md
@@ -1,7 +1,9 @@
 # TranskribusDU
 Document Understanding tools
 
-### Requirements, installation & testing
+Updated: 2019-11-20
+
+### Requirements, installation
 
 #### Python
 
@@ -9,9 +11,14 @@ Document Understanding tools
 We recommend installing __anaconda3__ .
 You can then train using pystruct and/or tensorflow (both to be installed on top of anaconda).
 
- * conda install shapely rtree
+ * conda install shapely rtree lxml scipy
+ * pip install future scikit-learn pytest --upgrade
+
+To learn with pystruct (using a graph-CRF model):
+ * pip install cvxopt ad3 pystruct --upgrade
+
+To learn with Tensorflow (using an Edge Convolutional Network):
  * conda install -c anaconda tensorflow(-gpu)
- * pip install future lxml scipy scikit-learn pytest cvxopt ad3 pystruct --upgrade
 
 ### Usage
 * see use-cases
diff --git a/RELEASE_NOTES.txt b/RELEASE_NOTES.txt
index 9632ce4..280d8ca 100644
--- a/RELEASE_NOTES.txt
+++ b/RELEASE_NOTES.txt
@@ -2,11 +2,36 @@
 RELEASE NOTES - TranskribusDU
 -----------------------------
 
---- Edelweiss - 2019-07-22
-- ECN/GAT
-- conjugate
-- --g1 --g2
-- table understanding
+--- Chrysanthème - 2019-11-21
+- ICDAR19 papers are reproducible
+- major code reorganisation
+- Multipage XML bug fixes
+- standard projection profile method
+- convex hull for cluster Coords
+- ECN ensemble bug fix
+- various bug fixes
+- --server mode
+- segmentation task using agglomerative clustering
+- JSON input
+- pipe example
+- table reconstruction
+- generic features (when no page info)
+- edge features reworked
+- cluster evaluation metrics
+
+
+--- Iris - 2019-04-25
+- CRF, ECN, GAT supported
+- conjugate mode supported
+- --vld option to specify a validation set, or a ratio of validation graphs
+  taken from the training set. The best model on validation set is kept.
+- --graph option to store the edges in the output XML
+- --max_iter applies to all learning methods
+- --seed to seed the randomizer with a constant
+- dynamic load of the learners
+- major code re-organization
+- for example of use, see in tasks: DU_TABLE_BIO.py or DU_Table_Row_Edge.py
 
 --- Jonquille - 2017-04-28
 - multi-type classification supported
@@ -30,21 +55,10 @@
 ---------------------------------------------------------------------------------
 
- Copyright (C) 2016, 2017 H. Déjean, JL Meunier
-
- This program is free software: you can redistribute it and/or modify
- it under the terms of the GNU General Public License as published by
- the Free Software Foundation, either version 3 of the License, or
- (at your option) any later version.
+ Copyright (C) 2016-2019 H. Déjean, JL Meunier
 
- This program is distributed in the hope that it will be useful,
- but WITHOUT ANY WARRANTY; without even the implied warranty of
- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- GNU General Public License for more details.
 
- You should have received a copy of the GNU General Public License
- along with this program. If not, see <http://www.gnu.org/licenses/>.
 
 Developed for the EU project READ.
The READ project has received funding from the European Union's Horizon 2020
 research and innovation programme
- under grant agreement No 674943.
+ under grant agreement No 674943.
\ No newline at end of file
diff --git a/TranskribusDU/ObjectModel/XMLDSBASELINEClass.py b/TranskribusDU/ObjectModel/XMLDSBASELINEClass.py
index a2736f8..07a2367 100644
--- a/TranskribusDU/ObjectModel/XMLDSBASELINEClass.py
+++ b/TranskribusDU/ObjectModel/XMLDSBASELINEClass.py
@@ -45,12 +45,10 @@ def computePoints(self):
         if self.lPoints is None:
             self.lPoints = self.getAttribute('blpoints')
 #             print 'after split?',self.lPoints
-
-        self.lPoints = self.lPoints.replace(" ",",")
         if self.lPoints is not None:
             lX = list(map(lambda x:float(x),self.lPoints.split(',')))[0::2]
             lY = list(map(lambda x:float(x),self.lPoints.split(',')))[1::2]
-            self.lPoints = list(zip(lX,lY))
+            self.lPoints = zip(lX,lY)
 #             lY.sort()
 #             if len(lY)> 10: ## if baseline automatically generated: beg and end noisy
 #                 lY= lY[1:-2]
@@ -67,9 +65,14 @@ def computePoints(self):
             import numpy as np
             a,b = np.polyfit(lX, lY, 1)
             self.setAngle(a)
-            self.setBx(b)
 #             ymax = a * self.getX2() +b
 #             ymin = a*self.getX() + b
+#             import libxml2
+#             verticalSep = libxml2.newNode('PAGEBORDER')
+#             verticalSep.setProp('points', '%f,%f,%f,%f'%(self.getX(),ymin,self.getX2(),ymax))
+#             # print 'p',self.getParent()
+#             # print 'pp',self.getParent().getParent()
+#             self.getParent().getNode().addChild(verticalSep)
 
 """
     TO simulate 'DS' objects
diff --git a/TranskribusDU/ObjectModel/XMLDSCELLClass.py b/TranskribusDU/ObjectModel/XMLDSCELLClass.py
index 6e90014..117ba5d 100644
--- a/TranskribusDU/ObjectModel/XMLDSCELLClass.py
+++ b/TranskribusDU/ObjectModel/XMLDSCELLClass.py
@@ -8,18 +8,7 @@
     a class for table cell from a XMLDocument
 
     READ project
 
-    This program is free software: you can redistribute it and/or modify
-    it under the terms of the GNU General Public License as published by
-    the Free Software Foundation, either version 3 of the License, or
-    (at your option) any later version.
 
-    This program is distributed in the hope that it will be useful,
-    but WITHOUT ANY WARRANTY; without even the implied warranty of
-    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-    GNU General Public License for more details.
-
-    You should have received a copy of the GNU General Public License
-    along with this program. If not, see <http://www.gnu.org/licenses/>.
 
     Developed for the EU project READ.
The READ project has received funding @@ -96,17 +85,7 @@ def fromDom(self,domNode): self.setNode(domNode) # get properties for prop in domNode.keys(): - try: - self.addAttribute(prop,domNode.get(prop)) - if prop =='x': self._x= float(domNode.get(prop)) - elif prop =='y': self._y = float(domNode.get(prop)) - elif prop =='height': self._h = float(domNode.get(prop)) - elif prop =='width': self.setWidth(float(domNode.get(prop))) - except: - self._x=-1 - self._y=-1 - self._h=0 - self._w=0 + self.addAttribute(prop,domNode.get(prop)) self.setIndex(int(self.getAttribute('row')),int(self.getAttribute('col'))) diff --git a/TranskribusDU/ObjectModel/XMLDSGRAHPLINEClass.py b/TranskribusDU/ObjectModel/XMLDSGRAHPLINEClass.py index 865eb3c..4b3736b 100644 --- a/TranskribusDU/ObjectModel/XMLDSGRAHPLINEClass.py +++ b/TranskribusDU/ObjectModel/XMLDSGRAHPLINEClass.py @@ -40,10 +40,9 @@ def computePoints(self): self.lPoints = self.getAttribute('points') # print 'after split?',self.lPoints if self.lPoints is not None: - lX=[float(x) for p in self.lPoints.split(' ') for x in p.split(',')[0::2]] -# lX = list(map(lambda x:float(x),self.lPoints.split(',')))[0::2] - lY = [float(x) for p in self.lPoints.split(' ') for x in p.split(',')[1::2]] - self.lPoints = list(zip(lX,lY)) + lX = list(map(lambda x:float(x),self.lPoints.split(',')))[0::2] + lY = list(map(lambda x:float(x),self.lPoints.split(',')))[1::2] + self.lPoints = zip(lX,lY) try: self.avgY = 1.0 * sum(lY)/len(lY) except ZeroDivisionError: @@ -59,6 +58,10 @@ def computePoints(self): # self.setAngle(a) # ymax = a * self.getX2() +b # ymin = a*self.getX() + b +# import libxml2 +# verticalSep = libxml2.newNode('PAGEBORDER') +# verticalSep.setProp('points', '%f,%f,%f,%f'%(self.getX(),ymin,self.getX2(),ymax)) +# self.getParent().getNode().addChild(verticalSep) """ TO simulate 'DS' objects @@ -76,7 +79,6 @@ def getWidth(self): return abs(float(self.getAttribute('width'))) def setPoints(self,lp): self.lPoints = lp - def getPoints(self): return self.lPoints def fromDom(self,domNode): """ diff --git a/TranskribusDU/ObjectModel/XMLDSLINEClass.py b/TranskribusDU/ObjectModel/XMLDSLINEClass.py index bef4c52..f1f828e 100644 --- a/TranskribusDU/ObjectModel/XMLDSLINEClass.py +++ b/TranskribusDU/ObjectModel/XMLDSLINEClass.py @@ -36,16 +36,13 @@ def fromDom(self,domNode): self.setNode(domNode) # get properties -# for prop in domNode.keys(): -# self.addAttribute(prop,domNode.get(prop)) - for prop in domNode.keys(): self.addAttribute(prop,domNode.get(prop)) - if prop =='x': self._x= float(domNode.get(prop)) - elif prop =='y': self._y = float(domNode.get(prop)) - elif prop =='height': self._h = float(domNode.get(prop)) - elif prop =='width': self.setWidth(float(domNode.get(prop))) - + +# ctxt = domNode.doc.xpathNewContext() +# ctxt.setContextNode(domNode) +# ldomElts = ctxt.xpathEval('./%s'%(ds_xml.sTEXT)) +# ctxt.xpathFreeContext() ldomElts = domNode.findall('./%s'%(ds_xml.sTEXT)) for elt in ldomElts: myObject= XMLDSTEXTClass(elt) diff --git a/TranskribusDU/ObjectModel/XMLDSObjectClass.py b/TranskribusDU/ObjectModel/XMLDSObjectClass.py index dc03077..61ec8d0 100644 --- a/TranskribusDU/ObjectModel/XMLDSObjectClass.py +++ b/TranskribusDU/ObjectModel/XMLDSObjectClass.py @@ -16,7 +16,7 @@ from .XMLObjectClass import XMLObjectClass from config import ds_xml_def as ds_xml -from shapely.geometry import polygon,Polygon,LineString + from lxml import etree class XMLDSObjectClass(XMLObjectClass): @@ -34,18 +34,7 @@ def __init__(self): self._id= None self.Xnearest=[[],[]] # top bottom - 
self.Ynearest=[[],[]] # left right - - - ## need to have x,y,w,h directly: otherwise: too slow - self._x = None - self._y = None - self._h = None - self._w = None - - self._poly = None - - + self.Ynearest=[[],[]] # left right def getPage(self): return self._page def setPage(self,p): self._page= p @@ -56,38 +45,23 @@ def getID(self): return self._id # def addElement(self,e): self._lElements.append(e) - def getX(self): return self._x #float(self.getAttribute('x')) - def getY(self): return self._y # return float(self.getAttribute('y')) - def getX2(self): return self.getX() + self.getWidth() - def getY2(self): return self.getY() +self.getHeight() - def getHeight(self): return self._h #return float(self.getAttribute('height')) - def getWidth(self): return self._w #return float(self.getAttribute('width')) + def getX(self): return float(self.getAttribute('x')) + def getY(self): return float(self.getAttribute('y')) + def getX2(self): return float(self.getAttribute('x'))+self.getWidth() + def getY2(self): return float(self.getAttribute('y'))+self.getHeight() + def getHeight(self): return float(self.getAttribute('height')) + def getWidth(self): return float(self.getAttribute('width')) - def setX(self,x): self.addAttribute('x',x); self._x = float(x) - def setY(self,y): self.addAttribute('y',y); self._y = float(y) - def setWidth(self,w): self.addAttribute('width',w);self._w = float(w) - def setHeight(self,h):self.addAttribute('height',h);self._h = float(h) + def setX(self,x): self.addAttribute('x',x) + def setY(self,y): self.addAttribute('y',y) + def setWidth(self,w): self.addAttribute('width',w) + def setHeight(self,h):self.addAttribute('height',h) def setDimensions(self,x,y,h,w): - self.setX(x) - self.setY(y) - self.setHeight(h) - self.setWidth( w) - - def toPolygon(self): - """ - return a shapely polygon using points!!! 
- points="375.12,98.88,924.0,101.52,924.0,113.52,375.12,110.88" - """ - if self._poly is not None: - return self._poly - - x = [float(x) for x in self.getAttribute("points").replace(" ",",").split(',')] - if len(x) <3*2: - return LineString(list(zip(*[iter(x)]*2))) - self._poly = polygon.orient(Polygon(list(zip(*[iter(x)]*2)))) - if not self._poly.is_valid:self._poly= self._poly.convex_hull - return self._poly + self.addAttribute('x', x) + self.addAttribute('y', y) + self.addAttribute('height', h) + self.addAttribute('width', w) def addObject(self,o,bDom=False): ## move dom node as well @@ -97,23 +71,14 @@ def addObject(self,o,bDom=False): o.setParent(self) if bDom: if o.getNode() is not None and self.getNode() is not None: - o.getNode().getparent().remove(o.getNode()) - self.getNode().append(o.getNode()) + o.getNode().unlinkNode() + self.getNode().addChild(o.getNode()) - def removeObject(self,o,bDom=False): - """ - remove o from self.getObjects() - unlink if bDom - """ - self.getObjects().remove(o) - if bDom: - if o.getNode() is not None and self.getNode() is not None: - o.getNode().getparent().remove(o.getNode()) + def resizeMe(self,objectType): - assert len(self.getAllNamedObjects(objectType)) != 0 minbx = 9e9 @@ -127,28 +92,18 @@ def resizeMe(self,objectType): if elt.getX() + elt.getWidth() > maxbx: maxbx = elt.getX() + elt.getWidth() if elt.getY() + elt.getHeight() > maxby: maxby = elt.getY() + elt.getHeight() assert minby != 9e9 - self.setX( minbx) - self.setY( minby) - self.setWidth(maxbx-minbx) - self.setHeight(maxby-minby) + self.addAttribute('x', minbx) + self.addAttribute('y', minby) + self.addAttribute('width', maxbx-minbx) + self.addAttribute('height', maxby-minby) self.addAttribute('x2', maxbx) self.addAttribute('y2', maxby) + self._BB = [minbx,minby,maxby-minby,maxbx-minbx] - def setXYHW(self,x,y,h,w): - self.setX(x) - self.setY(y) - self.setHeight(h) - self.setWidth(w) - def copyXYHW(self,o): - self._x = o.getX() - self._y =o.getY() - self._h = o.getHeight() - self._w = o.getWidth() - def fromDom(self,domNode): ## if domNode in mappingTable: @@ -164,13 +119,10 @@ def fromDom(self,domNode): # get properties for prop in domNode.keys(): self.addAttribute(prop,domNode.get(prop)) - if prop =='x': self._x= float(domNode.get(prop)) - elif prop =='y': self._y = float(domNode.get(prop)) - elif prop =='height': self._h = float(domNode.get(prop)) - elif prop =='width': self.setWidth(float(domNode.get(prop))) - self.addAttribute('x2', self.getX()+self.getWidth()) - self.addAttribute('y2',self.getY()+self.getHeight() ) + self.addAttribute('x2', float(self.getAttribute('x'))+self.getWidth()) + self.addAttribute('y2',float(self.getAttribute('y'))+self.getHeight() ) + self._id = self.getAttribute('id') if self.getID() is None: @@ -217,62 +169,20 @@ def fromDom(self,domNode): self.addObject(myObject) myObject.setPage(self.getPage()) myObject.fromDom(child) + - - - def bestRegionsAssignment(self,lRegions,bOnlyBaseline=False): - """ - find the best (max overlap for self) region for self - bOnlyBaseline: reduce the height so that baseline position is more important - """ - from rtree import index - - assert self.toPolygon().convex_hull.is_valid - - txtidx = index.Index() - lP = [] - [lP.append(e.toPolygon()) for e in lRegions if e.toPolygon().is_valid] - for i,elt in enumerate(lRegions): - txtidx.insert(i, lP[i].bounds) - lSet = txtidx.intersection(self.toPolygon().bounds) - lOverlap = [] - for ei in lSet: - if lP[ei].is_valid: - intersec= self.toPolygon().intersection(lP[ei]).area - if 
intersec >0:
-                    lOverlap.append((ei,lP[ei],intersec))
-        if lOverlap != []:
-            lOverlap.sort(key=lambda xyz:xyz[-1])
-#             print ("??",self,lRegions[lOverlap[-1][0]])
-            return lRegions[lOverlap[-1][0]]
-        return None
-
-    def bestRegionsAssignmentOld(self,lRegions,bOnlyBaseline=False):
+    def bestRegionsAssignment(self,lRegions):
         """
             find the best (max overlap for self) region for self
-            bOnlyBaseline: reduce the height so that baseline position is more important
         """
-        if bOnlyBaseline:
-            #backup height
-            Hbackup = self.getHeight()
-            Ybackup= self.getY()
-            self.setHeight(1)
-            self.setY(Hbackup+self.getY())
         lOverlap=[]
         for region in lRegions:
-#             lOverlap.append(self.signedRatioOverlap(region))
-            lOverlap.append(self.signedRatioOverlapY(region))
-#             print(self.getX(),self.getWidth(),region, self.signedRatioOverlapX(region))
-
-        if bOnlyBaseline:
-            #restore height
-            self.setHeight(Hbackup)
-            self.setY(Ybackup)
+            lOverlap.append(self.signedRatioOverlap(region))
 
-        if lOverlap ==[] : return None
+        if max(lOverlap) == 0: return None
         return lRegions[lOverlap.index(max(lOverlap))]
 
     def clipMe(self,clipRegion,lSubObject=[]):
@@ -299,10 +209,10 @@ def clipMe(self,clipRegion,lSubObject=[]):
         newW = min(self.getX2(),clipRegion.getX2()) - newX
         newH = min(self.getY2(),clipRegion.getY2()) - newY
 
-        myNewObject.setX(newX)
-        myNewObject.setY(newY)
-        myNewObject.setHeight(newH)
-        myNewObject.setWidth(newW)
+        myNewObject.addAttribute('x',newX)
+        myNewObject.addAttribute('y',newY)
+        myNewObject.addAttribute('height',newH)
+        myNewObject.addAttribute('width',newW)
 
 #         print self.getID(),self.getName(),self.getContent()
 #         print '\tnew dimensions',myNewObject.getX(),myNewObject.getY(),myNewObject.getWidth(),myNewObject.getHeight()
@@ -324,21 +234,7 @@ def clipMe(self,clipRegion,lSubObject=[]):
             return None
 
-    def signedRatioOverlapY(self,zone):
-        """
-        return the overlap ratio between self and zone
-        """
-        [a1,a2] = self.getY(),self.getY() + self.getHeight()
-        [b1,b2] = zone.getY(),zone.getY() + zone.getHeight()
-        if min(a2, b2) >= max(a1, b1): return min(a2, b2) - max(a1, b1)
-        else: return -1
-
-    def signedRatioOverlapX(self,zone):
-
-        [a1,a2] = self.getX(),self.getX()+ self.getWidth()
-        [b1,b2] = zone.getX(),zone.getX()+ zone.getWidth()
-        if min(a2, b2) >= max(a1, b1): return min(a2, b2) - max(a1, b1)
-        else: return -1
+    def signedRatioOverlap(self,zone):
 
         """
@@ -426,30 +322,6 @@ def overlapY(self,zone):
         [b1,b2] = zone.getY(),zone.getY() + zone.getHeight()
         return min(a2, b2) >= max(a1, b1)
 
-
-    def getSetOfX1X2Attributes(self,TH,foo,myObject):
-        """
-        input: feature threshold (eq)
-        """
-        from spm.feature import featureObject,TwoDFeature
-
-        if self._lBasicFeatures is None:
-            self._lBasicFeatures = []
-        # needed to keep canonical values!
- elif self.getSetofFeatures() != []: - return self.getSetofFeatures() - - for elt in self.getAllNamedObjects(myObject): - ftype= featureObject.COMPLEX - feature = TwoDFeature() - feature.setName("x1x2") - feature.setTH(TH) - feature.addNode(elt) - feature.setObjectName(self) - feature.setValue((elt.getX(),elt.getX2())) - feature.setType(ftype) - self.addFeature(feature) - def getSetOfListedAttributes(self,TH,lAttributes,myObject): """ diff --git a/TranskribusDU/ObjectModel/XMLDSPageClass.py b/TranskribusDU/ObjectModel/XMLDSPageClass.py index 9f3b73a..488921f 100644 --- a/TranskribusDU/ObjectModel/XMLDSPageClass.py +++ b/TranskribusDU/ObjectModel/XMLDSPageClass.py @@ -22,7 +22,6 @@ from .XMLDSGRAHPLINEClass import XMLDSGRAPHLINEClass from .XMLDSTABLEClass import XMLDSTABLEClass from .XMLDSCOLUMN import XMLDSCOLUMNClass -from .XMLDSRelationClass import XMLDSRelationClass class XMLDSPageClass(XMLDSObjectClass): """ @@ -52,10 +51,7 @@ def __init__(self,domNode = None): self._X1X2 = [] self.lf_XCut = [] - - # horizontal and vertical chunks built with element 'key' - self._dHC = {} - self._dVC = {} + ## list of vertical page zones Templates # by default: one column Template @@ -102,21 +98,17 @@ def fromDom(self,domNode,lEltNames): # get properties for prop in domNode.keys(): self.addAttribute(prop,domNode.get(prop)) - if prop =='x': self._x= float(domNode.get(prop)) - elif prop =='y': self._y = float(domNode.get(prop)) - elif prop =='height': self._h = float(domNode.get(prop)) - elif prop =='width': self.setWidth(float(domNode.get(prop))) + +# ctxt = domNode.doc.xpathNewContext() +# ctxt.setContextNode(domNode) +# ldomElts = ctxt.xpathEval('./*') +# ctxt.xpathFreeContext() ldomElts = domNode.findall('./*') for elt in ldomElts: ### GT elt if elt.tag =='MARGIN': elt = list(elt)[0] #elt=elt.children - if elt.tag == 'EDGE': - myRel= XMLDSRelationClass() - self.addRelation(myRel) - myRel.fromDom(elt) if elt.tag in lEltNames: - #relationObject if elt.tag == ds_xml_def.sCOL_Elt: myObject= XMLDSCOLUMNClass(elt) self.addObject(myObject) @@ -158,38 +150,7 @@ def fromDom(self,domNode,lEltNames): else: pass - def reifyRelations(self): - """ - what was read: ids: now replace ids by objects - - """ - self.relmat={} - self.relmatH={} - self.relmatV={} - for rel in self.getRelations(): - try: self.relmat[rel.getAttribute('label')] - except KeyError:self.relmat[rel.getAttribute('label')]={} - try: self.relmatH[rel.getAttribute('label')] - except KeyError:self.relmatH[rel.getAttribute('label')]={} - srcid,tgtid = rel.getSourceId(), rel.getTargetId() - src = [ x for i,x in enumerate(self.getAllNamedObjects(XMLDSTEXTClass)) if x.getAttribute('id') == srcid] - tgt = [ x for i, x in enumerate(self.getAllNamedObjects(XMLDSTEXTClass)) if x.getAttribute('id') == tgtid] - rel.setSource(src[0]) - try:src[0].lrel.append(tgt[0]) - except: src[0].lrel=[tgt[0]] - try:tgt[0].lrel.append(src[0]) - except: tgt[0].lrel=[src[0]] - rel.setTarget(tgt[0]) - if True:# and rel.getAttribute('type')=='continue': - self.relmat[rel.getAttribute('label')][(src[0].getAttribute('id'),tgt[0].getAttribute('id'))] = float(rel.getAttribute('w')) - self.relmat[rel.getAttribute('label')][(tgt[0].getAttribute('id'),src[0].getAttribute('id'))] = float(rel.getAttribute('w')) - if rel.getAttribute('type')=='HorizontalEdge': - self.relmatH[rel.getAttribute('label')][(src[0].getAttribute('id'),tgt[0].getAttribute('id'))] = float(rel.getAttribute('w')) - self.relmatH[rel.getAttribute('label')][(tgt[0].getAttribute('id'),src[0].getAttribute('id'))] = 
float(rel.getAttribute('w')) - - - #TEMPLATE # def setVerticalTemplates(self,lvm): # self._verticalZoneTemplates = lvm @@ -223,29 +184,6 @@ def getdVSeparator(self,template): except KeyError: return [] - - def addVerticalChunk(self,elt, chunk): - """ - add vertical hunk built with elements 'elt' - """ - try: self._dVC[elt.name].append(chunk) - except KeyError : self._dVC[elt.name] = [ chunk ] - - def addHorizontalChunk(self,elt, chunk): - """ - add Horizontal chunks built with elements 'elt' - """ - try: self._dHC[elt.name].append(chunk) - except KeyError : self._dHC[elt.name] = [chunk] - - - def getHorizontalChunk(self,elt, ): - """ - get Horizontal chunks built with elements 'elt' - """ - try: return self._dHC[elt.name] - except KeyError : return None - #OBJECTS (from registered cut regions) def addVerticalObject(self,Template,o): @@ -270,8 +208,6 @@ def getHGLFeatures(self): return self._HGLFeatures def setVGLFeatures(self,f): self._VGLFeatures.append(f) def getVGLFeatures(self): return self._VGLFeatures - - def getVX1Info(self): return self._VX1Info def setVX1Info(self,lInfo): """ @@ -280,8 +216,16 @@ def setVX1Info(self,lInfo): """ self._VX1Info = lInfo + def getVX2Info(self): return self._VX2Info + def setVX2Info(self,lInfo): + """ + list of X1 features for the H structure of the page + corresponds to information to segment the page vertically + """ + self._VX2Info = lInfo + def getX1X2(self): return self._X1X2 - def setX1X2(self,x): self._X1X2 = x + def setX1X2(self,x): self._X1X2.append(x) def getVWInfo(self): return self._VWInfo def setVWInfo(self,lInfo): @@ -409,7 +353,6 @@ def createVerticalZones(self,Template,tagLevel=XMLDSTEXTClass): # create regions prevCut=0 - for xcut in self.getdVSeparator(Template): region=XMLDSObjectClass() region.setName('VRegion') @@ -421,23 +364,22 @@ def createVerticalZones(self,Template,tagLevel=XMLDSTEXTClass): prevCut=xcut.getValue() self.addVerticalObject(Template,region) - if self.getdVSeparator(Template): - #last column - region=XMLDSObjectClass() - region.setName('VRegion') - region.addAttribute('x', prevCut) - region.addAttribute('y', 0) - region.addAttribute('height', self.getAttribute('height')) - region.addAttribute('width', str(self.getWidth() - prevCut)) - prevCut=xcut.getValue() - self.addVerticalObject(Template,region) + #last column + region=XMLDSObjectClass() + region.setName('VRegion') + region.addAttribute('x', prevCut) + region.addAttribute('y', 0) + region.addAttribute('height', self.getAttribute('height')) + region.addAttribute('width', str(self.getWidth() - prevCut)) + prevCut=xcut.getValue() + self.addVerticalObject(Template,region) - ## populate regions - for subObject in self.getAllNamedObjects(tagLevel): - region= subObject.bestRegionsAssignment(self.getVerticalObjects(Template)) - if region: - region.addObject(subObject) + ## populate regions + for subObject in self.getAllNamedObjects(tagLevel): + region= subObject.bestRegionsAssignment(self.getVerticalObjects(Template)) + if region: + region.addObject(subObject) def getSetOfMutliValuedFeatures(self,TH,lMyFeatures,myObject): diff --git a/TranskribusDU/ObjectModel/XMLDSTABLEClass.py b/TranskribusDU/ObjectModel/XMLDSTABLEClass.py index 6c5d859..392bb7a 100644 --- a/TranskribusDU/ObjectModel/XMLDSTABLEClass.py +++ b/TranskribusDU/ObjectModel/XMLDSTABLEClass.py @@ -9,18 +9,7 @@ READ project - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 
3 of the License, or
-    (at your option) any later version.
 
-    This program is distributed in the hope that it will be useful,
-    but WITHOUT ANY WARRANTY; without even the implied warranty of
-    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-    GNU General Public License for more details.
-
-    You should have received a copy of the GNU General Public License
-    along with this program. If not, see <http://www.gnu.org/licenses/>.
 
     Developed for the EU project READ.
     The READ project has received funding
@@ -35,13 +24,10 @@ from .XMLDSCELLClass import XMLDSTABLECELLClass
 from .XMLDSTableColumnClass import XMLDSTABLECOLUMNClass
 from .XMLDSTableRowClass import XMLDSTABLEROWClass
-from .XMLDSTEXTClass import XMLDSTEXTClass
 from config import ds_xml_def as ds_xml
-from copy import deepcopy
+
 import numpy as np
-from shapely.geometry import MultiPolygon
-from shapely.geometry.collection import GeometryCollection
 
 class XMLDSTABLEClass(XMLDSObjectClass):
     """
@@ -54,7 +40,7 @@ class XMLDSTABLEClass(XMLDSObjectClass):
     def __init__(self,domNode = None):
         XMLDSObjectClass.__init__(self)
         XMLDSObjectClass.id += 1
-        self._name = ds_xml.sTABLE
+        self._domNode = domNode
 
         self._lspannedCells=[]
@@ -93,16 +79,14 @@ def addCell(self,cell):
         """
             add a cell
             update row and col data structure? no
-            but insert cell at the right position
+            but insert cell at the right position
         """
         self._lcells.append(cell)
         self.addObject(cell)
 
     def delCell(self,cell):
-        try:
-            self._lcells.remove(cell)
-#             self._lObjects.remove(cell)
-        except:pass #print ('cell already deleted?',cell)
+        try:self._lcells.remove(cell)
+        except:pass
 
 
     def getNbRows(self):
@@ -133,6 +117,8 @@ def displayPerRow(self):
             for cell in row.getCells():
                 print(cell, cell.getFields(),end='')
             print()
+
+
 
     def addColumn(self,col):
         """
@@ -151,28 +137,9 @@
         return col
 
-    def eraseColumns(self):
-        """
-            delete all columns
-        """
-        self._lcolumns = []
-        self._nbCols= None
-
-    def eraseRows(self):
-        """
-            delete all rows
-        """
-        self._lrows = []
-        self._nbRows= None
-
-    def tagMe(self,sLabel=None):
-        super(XMLDSObjectClass,self).tagMe(sLabel)
-        for r in self.getRows():r.tagMe()
-        for c in self.getColumns():c.tagMe()
-
     def createRowsWithCuts(self,lYCuts,bTakeAll=False):
         """
-            input: horizontal lcuts
+            input: a table and cells
             output: list of rows populated with appropriate cells (main overlap)
         """
         if lYCuts == []:
@@ -180,31 +147,32 @@
         #reinit rows?
yes self._lrows = [] - self._nbRows = None -# lCells =self.getCells() + lCells =self.getCells() prevCut = self.getY() # - try: - lYCuts = list(map(lambda x:x.getValue(),lYCuts)) - except: - pass + try:lYCuts = map(lambda x:x.getValue(),lYCuts) + except:pass irowIndex = 0 for irow,cut in enumerate(lYCuts): -# cut = cut-10 # 10 for ABP row= XMLDSTABLEROWClass(irowIndex) row.setParent(self) - row.setY(prevCut) + row.addAttribute('y',prevCut) # row too narow from table border # if cut.getValue()-prevCut > 0: if bTakeAll or cut - prevCut > 0: # row.addAttribute('height',cut.getValue()-prevCut) - row.setHeight(cut - prevCut) - row.setX(self.getX()) - row.setWidth(self.getWidth()) - row.addAttribute('points',"%s,%s,%s,%s,%s,%s,%s,%s"%(self.getX(), self.getY(),self.getX2(), self.getY(), self.getX2(), self.getY2(), self.getX(), self.getY2())) + row.addAttribute('height',cut - prevCut) + row.addAttribute('x',self.getX()) + row.addAttribute('width',self.getWidth()) + row.tagMe() + for c in lCells: + if c.overlap(row): + row.addObject(c) + c.setIndex(irowIndex,c.getIndex()[1]) +# print irow,c.getIndex() + row.addCell(c) self.addRow(row) -# row.tagMe() irowIndex+=1 else: del(row) @@ -216,61 +184,24 @@ def createRowsWithCuts(self,lYCuts,bTakeAll=False): row= XMLDSTABLEROWClass(irowIndex) row.setParent(self) - row.setY(prevCut) - row.setHeight(self.getY2()-prevCut) - row.setX(self.getX()) - row.setWidth(self.getWidth()) + row.addAttribute('y',prevCut) + row.addAttribute('height',self.getY2()-prevCut) + row.addAttribute('x',self.getX()) + row.addAttribute('width',self.getWidth()) row.addAttribute('index',row.getIndex()) - row.addAttribute('points',"%s,%s,%s,%s,%s,%s,%s,%s"%(self.getX(), self.getY(),self.getX2(), self.getY(), self.getX2(), self.getY2(), self.getX(), self.getY2())) - self.addRow(row) -# row.tagMe() - - ## recreate correctly cells : nb cells: #rows x #col - - - - - def testPopulate0(self): - """ - test shapely library - Take the cells and populate with textlines - """ - lpcell = [cell.toPolygon() for cell in self.getCells()] - for text in self.getAllNamedObjects(XMLDSTEXTClass): - for pcell in lpcell: - pcell.intersection(text.toPolygon()) - - def testPopulate(self): - """ - test shapely library - Take the cells and populate with textlines - """ -# return self.testPopulate0() - from rtree import index + row.tagMe() -# print (len(self.getCells())) -# # create cell index - lIndCells = index.Index() - for pos, cell in enumerate(self.getCells()): - lIndCells.insert(pos, cell.toPolygon().bounds) - - for text in self.getAllNamedObjects(XMLDSTEXTClass): - ll = list(lIndCells.intersection(text.toPolygon().bounds)) - bestcell = None - bestarea = 0 - if ll != []: - print (text.getAttribute("id"),[self.getCells()[i] for i in ll]) - for i,x in enumerate(ll): - a = text.toPolygon().intersection(self.getCells()[x].toPolygon()).area - if a > bestarea: - bestarea = a - bestcell = x - print (text.getAttribute("id"),self.getCells()[bestcell],bestarea,text.toPolygon().area) - - + for c in lCells: + if c.overlap(row): + row.addObject(c) + c.setIndex(irow,c.getIndex()[1]) + row.addCell(c) + self.addRow(row) - def reintegrateCellsInColRow(self,lObj=[]): + ## recreate correctly cells : nb cells: #rows x #col + + def reintegrateCellsInColRow(self): """ after createRowsWithCuts, need for refitting cells in the rows (and merge them) @@ -281,133 +212,41 @@ def reintegrateCellsInColRow(self,lObj=[]): 3- assign text to best cell - - """ - -# import numpy as np - from shapely.geometry import Polygon -# - lCells = [] - 
NbrowsToDel=0 - lRowWithPB = [] - for icol,col in enumerate(self.getColumns()): - polycol = col.toPolygon().buffer(0) - if not polycol.is_valid: polycol=polycol.convex_hull - rowCells = [] - lColWithPb = [] - - for irow, row in enumerate(self.getRows()): - polyrow= row.toPolygon().buffer(0) - if not polyrow.is_valid :polyrow = polyrow.convex_hull - if polycol.is_valid and polyrow.is_valid and polycol.intersection(polyrow).area> 0.1 :#(polyrow.area*0.25): - cell=XMLDSTABLECELLClass() - inter = polycol.intersection(polyrow) - if not inter.is_valid: inter =inter.convex_hull - x,y,x2,y2 = inter.bounds - cell.setXYHW(x,y, y2-y,x2-x) - ## due to rox/col defined with several lines - if isinstance(inter,MultiPolygon) or isinstance(inter,GeometryCollection): - linter= list(inter.geoms) - linter.sort(key=lambda x:x.area,reverse=True) - inter = linter[0] -# elfi if isinstance(inter,MultiPolygon): - if isinstance(inter,Polygon): - rowCells.append(cell) - ll = list(inter.exterior.coords) - cell.addAttribute('points', " ".join(list("%s,%s"%(x,y) for x,y in ll))) - cell.setIndex(irow, icol) - cell.setPage(self.getPage()) - row.addCell(cell) - col.addCell(cell) - cell.setParent(self) -# print (irow,icol,cell,cell.getAttribute('points')) - else: -# print([x.area for x in inter]) - print (irow,icol,type(inter),list(inter.geoms)) -# sss - else: - print("EMPTY?",polycol.intersection(polyrow).area,irow,icol,polycol.is_valid,polyrow.is_valid,polycol.intersection(polyrow)) -# lColWithPb.add(icol) - lRowWithPB.append(irow) - lColWithPb.append(icol) - #empty cell zone!! -# cell=XMLDSTABLECELLClass() -# rowCells.append(cell) -# # ll = list(inter.exterior.coords) -# #cell.addAttribute('points', " ".join(list("%s,%s"%(x,y) for x,y in ll))) -# cell.setIndex(irow, icol) -# cell.setPage(self.getPage()) -# row.addCell(cell) -# col.addCell(cell) -# cell.setParent(self) -# print (lColWithPb,lRowWithPB) - if False and len(lColWithPb) >0 and len(lRowWithPB)>len(lColWithPb): -# print (len(lColWithPb),len(lRowWithPB)) - NbrowsToDel+=1 - else: - lCells.extend(rowCells) - for text in lObj: - cell = text.bestRegionsAssignment(lCells,bOnlyBaseline=False) -# print(text.getContent(),cell) - if cell: - cell.addObject(text,bDom=True) - - - - #update with real cells - self._lcells = lCells -# print (len(lCells)) - # DOM tagging -# for cell in self.getCells(): -# cell.tagMe2() -# for o in cell.getObjects(): -# try:o.tagMe() -# except AttributeError:pass - -# # update rows!!!! -# print (self.getRows()[1].getX()) -# self.buildRowFromCells() -# print (self.getRows()) -# ss - def reintegrateCellsInColRowold(self): - """ - after createRowsWithCuts, need for refitting cells in the rows (and merge them) - - 1- create new regular cells from rows and columns - 2- for each cell: - for each text: compute overlap - store best for text - 3- assign text to best cell - + ## populate regions + for subObject in self.getAllNamedObjects(tagLevel): + region= subObject.bestRegionsAssignment(self.getVerticalObjects(Template)) + if region: + region.addObject(subObject) - """ # for cell in self.getCellsbyColumns(): # print self.getNbRows(), self.getNbColumns() lCells = [] for icol,col in enumerate(self.getColumns()): - lcolTexts = [] - [ lcolTexts.extend(colcell.getObjects()) for colcell in col.getObjects()] - # texts tagged OTHER as well? - rowCells = [] + lTexts = [] + [ lTexts.extend(colcell.getObjects()) for colcell in col.getObjects()] + # texts tagged OTHER as well? 
for irow, row in enumerate(self.getRows()): +# print icol,irow,row, row.getObjects() +# print 'cell:', col.getX(),row.getY(),col.getX2(),row.getY2() cell=XMLDSTABLECELLClass() - rowCells.append(cell) - cell.setXYHW(col.getX(),row.getY(), row.getY2() - row.getY(),col.getX2() - col.getX()) + lCells.append(cell) + cell.addAttribute('x',col.getX()) + cell.addAttribute('y',row.getY()) + cell.addAttribute('height',row.getY2() - row.getY()) + cell.addAttribute('width',col.getX2() - col.getX()) cell.setIndex(irow, icol) cell.setPage(self.getPage()) cell.setParent(self) - row.addCell(cell) - col.addCell(cell) - ## spanning information missing!!!! - - for text in lcolTexts: - cell = text.bestRegionsAssignmentOld(rowCells,bOnlyBaseline=False) + # find the best assignment of each text + for text in lTexts: + cell = text.bestRegionsAssignment(lCells) if cell: cell.addObject(text,bDom=True) - - lCells.extend(rowCells) +# else: +# print text,text.getX(),text.getY(),cell + #delete fake cells for cell in self.getCells(): # cell.getNode().unlinkNode() @@ -426,8 +265,8 @@ def reintegrateCellsInColRowold(self): try:o.tagMe() except AttributeError:pass - # update rows!! - #self.buildRowFromCells() + # update rows!!!! + self.buildRowFromCells() def buildRowFromCells(self): @@ -437,10 +276,12 @@ def buildRowFromCells(self): Rowspan: create row """ self._lrows=[] + self.getCells().sort(key=(lambda x:x.getIndex()[0])) for cell in self.getCells(): irow,_= cell.getIndex() # rowSpan = int(cell.getAttribute('rowSpan')) + try: row = self.getRows()[irow] except IndexError: row = XMLDSTABLEROWClass(irow) @@ -452,7 +293,7 @@ def buildRowFromCells(self): for row in self.getRows(): row.resizeMe(XMLDSTABLECELLClass) -# print (row.getIndex(),row.getParent()) +# print row.tagMe(row.tagname) def buildColumnFromCells(self): """ @@ -461,7 +302,6 @@ def buildColumnFromCells(self): self.getCells().sort(key=(lambda x:x.getIndex()[1])) self._lcolumns= [] for cell in self.getCells(): -# print (cell, cell.getRowSpan(), cell.getColSpan(), cell.getObjects()) _,jcol= cell.getIndex() try: col = self.getColumns()[jcol] except IndexError: @@ -471,11 +311,10 @@ def buildColumnFromCells(self): # print col.getPage(), self.getPage() if col is not None:col.addCell(cell) +# print self.getColumns() for col in self.getColumns(): - try: - col.resizeMe(XMLDSTABLECELLClass) - except:pass + col.resizeMe(XMLDSTABLECELLClass) # node = col.tagMe(col.tagname) @@ -489,56 +328,46 @@ def buildColumnRowFromCells(self): # print 'nb cells:' , len(self.getCells()) # first despan RowSpan cells self.getCells().sort(key=(lambda x:x.getIndex()[0])) - lNewCells = [] for cell in self.getCells(): # create new non-spanned cell if needed # print cell, cell.getRowSpan(), cell.getColSpan() - iRowSpan = cell.getIndex()[0] - while iRowSpan < cell.getIndex()[0] + cell.getRowSpan(): + iRowSpan = 1 + while iRowSpan < cell.getRowSpan(): # print 'row:', cell, cell.getRowSpan(),iRowSpan newCell = XMLDSTABLECELLClass(cell.getNode()) newCell.setName(XMLDSTABLECELLClass.name) newCell.setPage(self.getPage()) newCell.setParent(self) newCell.setObjectsList(cell.getObjects()) - newCell._lAttributes = deepcopy(cell.getAttributes()) - newCell.copyXYHW(cell) - newCell.addAttribute('rowSpan',newCell.getRowSpan()) -# newCell.setIndex(newCell.getIndex()[0]+iRowSpan, newCell.getIndex()[1]) - newCell.setIndex(iRowSpan, cell.getIndex()[1]) + newCell._lAttributes = cell.getAttributes().copy() + newCell.addAttribute('rowSpan',1) + newCell.setIndex(cell.getIndex()[0]+iRowSpan, cell.getIndex()[1]) 
newCell.setSpannedCell(cell) -# cell.setSpannedCell(cell) + cell.setSpannedCell(cell) newCell.bDeSpannedRow = True - lNewCells.append(newCell) + self.addCell(newCell) iRowSpan +=1 - # col span #sort them by col? -# self.getCells().sort(key=(lambda x:x.getIndex()[1])) - lNewCells.sort(key=(lambda x:x.getIndex()[1])) - lNewCells2 = [] - for cell in lNewCells: #self.getCells(): + self.getCells().sort(key=(lambda x:x.getIndex()[1])) + for cell in self.getCells(): # create new non-spanned cell if needed - iColSpan = cell.getIndex()[1] - while iColSpan < cell.getIndex()[1] + cell.getColSpan(): + iColSpan = 1 + while iColSpan < cell.getColSpan(): +# print 'col:', cell, cell.getColSpan(), iColSpan newCell = XMLDSTABLECELLClass(cell.getNode()) newCell.setName(XMLDSTABLECELLClass.name) newCell.setParent(self) - newCell._lAttributes = deepcopy(cell.getAttributes()) - newCell.copyXYHW(cell) + newCell._lAttributes = cell.getAttributes().copy() newCell.setObjectsList(cell.getObjects()) - newCell.addAttribute('colSpan',newCell.getColSpan()) -# newCell.setIndex(newCell.getIndex()[0], newCell.getIndex()[1]+iColSpan) - newCell.setIndex(cell.getIndex()[0], iColSpan) + newCell.addAttribute('colSpan',1) + newCell.setIndex(cell.getIndex()[0], cell.getIndex()[1]+iColSpan) newCell.setSpannedCell(cell) -# cell.setSpannedCell(cell) + cell.setSpannedCell(cell) newCell.bDeSpannedCol = True -# cell.bDeSpannedCol = True - lNewCells2.append(newCell) + self.addCell(newCell) # print '\tnex col cell:',newCell, iColSpan+1, cell.getColSpan() iColSpan +=1 -# self.getCells().extend(lNewCells) - self._lcells = lNewCells # print '-- nb cells:', len(self.getCells()) def assignElementsToCells(self): @@ -621,11 +450,6 @@ def fromDom(self,domNode): # get properties for prop in domNode.keys(): self.addAttribute(prop,domNode.get(prop)) - if prop =='x': self._x= float(domNode.get(prop)) - elif prop =='y': self._y = float(domNode.get(prop)) - elif prop =='height': self._h = float(domNode.get(prop)) - elif prop =='width': self.setWidth(float(domNode.get(prop))) - # ctxt = domNode.doc.xpathNewContext() # ctxt.setContextNode(domNode) @@ -639,14 +463,12 @@ def fromDom(self,domNode): myObject.setPage(self.getPage()) myObject.fromDom(elt) -# self._lspannedCells = self._lcells[:] + self._lspannedCells = self._lcells[:] self.buildColumnRowFromCells() self.buildColumnFromCells() self.buildRowFromCells() self.getCellsbyRow() - - # self.displayPerRow() # print self.getNbRows(), self.getNbColumns() diff --git a/TranskribusDU/ObjectModel/XMLDSTEXTClass.py b/TranskribusDU/ObjectModel/XMLDSTEXTClass.py index 211c5db..702e5eb 100644 --- a/TranskribusDU/ObjectModel/XMLDSTEXTClass.py +++ b/TranskribusDU/ObjectModel/XMLDSTEXTClass.py @@ -36,10 +36,10 @@ def __init__(self,domNode = None): # def getX(self): return float(self.getAttribute('x')) # def getY(self): return float(self.getAttribute('y')) -# def getX2(self): -# return float(self.getAttribute('x'))+self.getWidth() -# def getY2(self): -# return float(self.getAttribute('y'))+self.getHeight() + def getX2(self): + return float(self.getAttribute('x'))+self.getWidth() + def getY2(self): + return float(self.getAttribute('y'))+self.getHeight() # def getHeight(self): return float(self.getAttribute('height')) # def getWidth(self): return float(self.getAttribute('width')) @@ -53,21 +53,13 @@ def fromDom(self,domNode): # self.setName(domNode.atg) self.setNode(domNode) # get properties -# for prop in domNode.keys(): -# self.addAttribute(prop,domNode.get(prop)) + for prop in domNode.keys(): + 
self.addAttribute(prop,domNode.get(prop)) try: self._id = self.getAttribute('id') except:pass - - for prop in domNode.keys(): - self.addAttribute(prop,domNode.get(prop)) - if prop =='x': self._x= float(domNode.get(prop)) - elif prop =='y': self._y = float(domNode.get(prop)) - elif prop =='height': self._h = float(domNode.get(prop)) - elif prop =='width': self.setWidth(float(domNode.get(prop))) - - self.addAttribute('x2', self.getX()+self.getWidth()) - self.addAttribute('y2',self.getY()+self.getHeight() ) + self.addAttribute('x2', float(self.getAttribute('x'))+self.getWidth()) + self.addAttribute('y2',float(self.getAttribute('y'))+self.getHeight() ) if self.hasAttribute('blpoints'): from ObjectModel.XMLDSBASELINEClass import XMLDSBASELINEClass @@ -85,7 +77,7 @@ def fromDom(self,domNode): pass else: try:txt=txt.decode('utf-8') - except AttributeError as e: + except AttributeError: pass if self.getContent() is not None: self.addContent(txt) @@ -119,7 +111,7 @@ def computeBaseline(self): try: lX.append(token.getX()) lX.append(token.getX2()) - lY.append(token.getY()) + lY.append(token.getY2()) lY.append(token.getY2()) except TypeError: pass @@ -305,8 +297,9 @@ def getSetOfListedAttributes(self,TH,lAttributes,myObject): feature.addNode(self) feature.setObjectName(self) # avg of baseline? - avg1= baseline.getY() +(baseline.getY2() -baseline.getY())/2 + avg1= baseline.getY()+(baseline.getY2() -baseline.getY())/2 avg2= nbl.getY() +(nbl.getY2()-nbl.getY())/2 + feature.setValue(round(abs(avg2-avg1))) feature.setType(ftype) self.addFeature(feature) diff --git a/TranskribusDU/ObjectModel/XMLDSTableColumnClass.py b/TranskribusDU/ObjectModel/XMLDSTableColumnClass.py index 0072406..22b457b 100644 --- a/TranskribusDU/ObjectModel/XMLDSTableColumnClass.py +++ b/TranskribusDU/ObjectModel/XMLDSTableColumnClass.py @@ -9,18 +9,7 @@ READ project - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . Developed for the EU project READ. 
The READ project has received funding
@@ -32,19 +21,20 @@ from __future__ import unicode_literals
 
 from .XMLDSObjectClass import XMLDSObjectClass
+from config import ds_xml_def as ds_xml
 
 class XMLDSTABLECOLUMNClass(XMLDSObjectClass):
     """
         Column class
     """
-    name = 'COL'
+    name = ds_xml.sCOL_Elt
     def __init__(self,index=None,domNode = None):
         XMLDSObjectClass.__init__(self)
         XMLDSObjectClass.id += 1
         self._domNode = domNode
         self._index= index
         self._lcells=[]
-#         self.tagname= 'COL'
+        self.tagName= 'COL'
         self.setName(XMLDSTABLECOLUMNClass.name)
 
     def __repr__(self):
@@ -56,9 +46,8 @@ def getIndex(self): return self._index
     def setIndex(self,i): self._index = i
 
     def delCell(self,cell):
-        self._lcells.remove(cell)
-#         try:self._lcells.remove(cell)
-#         except:pass
+        try:self._lcells.remove(cell)
+        except:pass
     def getCells(self): return self._lcells
     def addCell(self,c):
         if c not in self.getCells():
diff --git a/TranskribusDU/ObjectModel/XMLDSTableRowClass.py b/TranskribusDU/ObjectModel/XMLDSTableRowClass.py
index ac4807d..cec2430 100644
--- a/TranskribusDU/ObjectModel/XMLDSTableRowClass.py
+++ b/TranskribusDU/ObjectModel/XMLDSTableRowClass.py
@@ -20,7 +20,7 @@ class XMLDSTABLEROWClass(XMLDSObjectClass):
         LINE class
     """
     name = ds_xml.sROW
-    def __init__(self,index=None,domNode = None):
+    def __init__(self,index,domNode = None):
         XMLDSObjectClass.__init__(self)
         XMLDSObjectClass.id += 1
         self._domNode = domNode
@@ -47,69 +47,6 @@ def addCell(self,c):
 #             c.getNode().unlinkNode()
 #             self.getNode().addChild(c.getNode())
 
-
-    def computeSkewing(self):
-        """
-            input: self
-            output: skewing angle
-            compute text skewing in the row
-        """
-        def getX(lSegment):
-            lX = list()
-            for x1,y1,x2,y2 in lSegment:
-                lX.append(x1)
-                lX.append(x2)
-            return lX
-
-        def getY(lSegment):
-            lY = list()
-            for x1,y1,x2,y2 in lSegment:
-                lY.append(y1)
-                lY.append(y2)
-            return lY
-
-        import numpy as np
-        from util.Polygon import Polygon
-        if self.getCells():
-            dRowSep_lSgmt=[]
-            # alternative: compute the real top ones? for wedding more robust!!
- for cell in self.getCells(): -# lTopText = filter(lambda x:x.getAttribute('DU_row') == 'B', [text for text in cell.getObjects()]) - try:lTopText = [cell.getObjects()[0]] - except IndexError:lTopText = [] - for text in lTopText: - sPoints = text.getAttribute('points') - spoints = ' '.join("%s,%s"%((x,y)) for x,y in zip(*[iter(sPoints.split(','))]*2)) - it_sXsY = (sPair.split(',') for sPair in spoints.split(' ')) - plgn = Polygon((float(sx), float(sy)) for sx, sy in it_sXsY) - - lT, lR, lB, lL = plgn.partitionSegmentTopRightBottomLeft() - dRowSep_lSgmt.extend(lT) - if dRowSep_lSgmt != []: - X = getX(dRowSep_lSgmt) - Y = getY(dRowSep_lSgmt) - lfNorm = [np.linalg.norm([[x1,y1], [x2,y2]]) for x1,y1,x2,y2 in dRowSep_lSgmt] - #duplicate each element - W = [fN for fN in lfNorm for _ in (0,1)] - - # a * x + b - a, b = np.polynomial.polynomial.polyfit(X, Y, 1, w=W) - xmin, xmax = min(X), max(X) - y1 = a + b * xmin - y2 = a + b * xmax - ro = XMLDSTABLEROWClass(self.getIndex()) - #[(x1, ymin), (x2, ymax) - ro.setX(xmin) - ro.setY(y1) - ro.setWidth(xmax-xmin) - ro.setHeight(y2-y1) - ro.setPage(self.getPage()) - ro.setParent(self.getParent()) - ro.addAttribute('points',','.join([str(xmin),str(y1),str(xmax),str(y2)])) - ro.tagMe() - - - ########## TAGGING ############## def addField(self,tag): diff --git a/TranskribusDU/ObjectModel/XMLObjectClass.py b/TranskribusDU/ObjectModel/XMLObjectClass.py index 82ef928..e201554 100644 --- a/TranskribusDU/ObjectModel/XMLObjectClass.py +++ b/TranskribusDU/ObjectModel/XMLObjectClass.py @@ -1,9 +1,9 @@ -# -*- coding: utf-8 -*- +# -*- coding: latin-1 -*- """ XML object class - Herv� D�jean + Hervé Déjean cpy Xerox 2009 a class for object from a XMLDocument @@ -72,9 +72,7 @@ def tagMe(self,sLabel=None): """ create a dom elt and add it to the doc """ - if sLabel is None: - newNode = etree.Element(self.getName()) - else: newNode = etree.Element(sLabel) + newNode = etree.Element(self.tagName) newNode.set('x',str(self.getX())) newNode.set('y',str(self.getY())) @@ -103,6 +101,7 @@ def tagMe(self,sLabel=None): # for o in self.getObjects(): # o.setParent(self) # o.tagMe() + return newNode def fromDom(self,domNode): diff --git a/TranskribusDU/ObjectModel/objectClass.py b/TranskribusDU/ObjectModel/objectClass.py index 9703a1d..13576d4 100644 --- a/TranskribusDU/ObjectModel/objectClass.py +++ b/TranskribusDU/ObjectModel/objectClass.py @@ -27,9 +27,6 @@ def __init__(self): # sub object self._lObjects = [] - #relations - self._lRelations =[] - # characteristics self._lAttributes = {} @@ -68,10 +65,6 @@ def addObject(self,o): self.getObjects().append(o) o.setParent(self) - def getRelations(self): return self._lRelations - def addRelation(self,r): - self._lRelations.append(r) - def addContent(self,c): c=c.replace(u'\n',u' ') if self.getContent() is not None: @@ -157,9 +150,6 @@ def display(self,level=0): ########### IE part ########### ### move to objectClass? 
- - def resetField(self):self._lFields=[] - def addField(self,field,value=None): """ add field (record field) to this cell: this cell is supposed to contain such a field diff --git a/TranskribusDU/ObjectModel/recordClass.py b/TranskribusDU/ObjectModel/recordClass.py index d364618..77308a1 100644 --- a/TranskribusDU/ObjectModel/recordClass.py +++ b/TranskribusDU/ObjectModel/recordClass.py @@ -7,18 +7,7 @@ READ project - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . Developed for the EU project READ. The READ project has received funding @@ -198,7 +187,7 @@ class fieldClass(object): def __init__(self,name=None): self._name = name # allow for muti-value, range,... - self._value = [] + self._value = None # backref to record self._record = None @@ -231,7 +220,7 @@ def setName(self,s): self._name =s def getValue(self): return self._value def setValue(self,v): self._value = v - def addValue(self,v): self._value.extend(v) + def setRecord(self,r): self._record = r def getRecord(self):return self._record @@ -265,7 +254,7 @@ def extractLabel(self,lres): def getBestValue(self): # old (u'List', (2, 0, 0), u'Ritt', 987) # now [(u'Theresia', 0.9978103), (u'Sebald',0.71877468)] - if self.getValue() != []: + if self.getValue() is not None: # score = list! take max self.getValue().sort(key = lambda x:max(x[1]),reverse=True) return self.getValue()[0][0] @@ -394,44 +383,6 @@ def runMe(self,documentObject): return res -class KerasTagger2(taggerClass): - """ - see taggerTrainKeras - -> use directly DeepTagger? - """ - def __init__(self,name): - taggerClass.__init__(self, name) - self.myTagger = DeepTagger() - self.myTagger.bPredict = True -# self.myTagger.sModelName = None -# self.myTagger.dirName = 'IE -# self.myTagger.loadModels() - - def loadResources(self,sModelName,dirName): - # location from sModeName, dirName - self.myTagger.sModelName = sModelName - self.myTagger.bAttentionLayer = sModelName[-3:] == 'att' - self.myTagger.dirName = dirName - self.myTagger.loadModels() - - def runMe(self,documentObject): - ''' - delete '.' 
because of location in GT - ''' -# res = self.myTagger.predict([documentObject.getContent()]) -# return res - - if documentObject.getContent() is None: - return [] - if self.myTagger.bMultiType: - res = self.myTagger.predict_multiptype([documentObject.getContent()]) - else: -# res = self.myTagger.predict([documentObject.getContent().replace('.','')]) - res = self.myTagger.predict([documentObject.getContent()]) - - return res - - class CRFTagger(taggerClass): """ diff --git a/TranskribusDU/ObjectModel/sequenceAPI.py b/TranskribusDU/ObjectModel/sequenceAPI.py index 5511a7d..69db517 100644 --- a/TranskribusDU/ObjectModel/sequenceAPI.py +++ b/TranskribusDU/ObjectModel/sequenceAPI.py @@ -86,8 +86,8 @@ def getCanonicalFeatures(self): return self._canonicalFeatures def addCanonicalFeatures(self,f ): if self._canonicalFeatures is None: self._canonicalFeatures=[] -# if f not in self._canonicalFeatures: - self._canonicalFeatures.append(f) + if f not in self._canonicalFeatures: + self._canonicalFeatures.append(f) diff --git a/TranskribusDU/ObjectModel/tableTemplateClass.py b/TranskribusDU/ObjectModel/tableTemplateClass.py index 74ed678..05beb8d 100644 --- a/TranskribusDU/ObjectModel/tableTemplateClass.py +++ b/TranskribusDU/ObjectModel/tableTemplateClass.py @@ -118,10 +118,8 @@ def labelTable(self,table): for sslice,_, lFields in self._lLabellingInstruction: for field in lFields: if field is not None: - try: - for cell in np.nditer(table.getNPArray()[sslice],['refs_ok'],op_dtypes=np.dtype(object)): - cell[()].addField(field.cloneMe()) - except: pass + for cell in np.nditer(table.getNPArray()[sslice],['refs_ok'],op_dtypes=np.dtype(object)): + cell[()].addField(field.cloneMe()) def registration(self,o): raise "SOFTWARE ERROR: your component must define a testRun method" @@ -201,4 +199,4 @@ def tagDom(self,dom): - + \ No newline at end of file diff --git a/TranskribusDU/ObjectModel/treeTemplateClass.py b/TranskribusDU/ObjectModel/treeTemplateClass.py index bbd5b7b..b2bf522 100644 --- a/TranskribusDU/ObjectModel/treeTemplateClass.py +++ b/TranskribusDU/ObjectModel/treeTemplateClass.py @@ -13,11 +13,8 @@ from __future__ import print_function from __future__ import unicode_literals -import numpy as np -from scipy.optimize import linear_sum_assignment - from .templateClass import templateClass -from spm.frechet import frechetDist +import numpy as np class treeTemplateClass(templateClass): """ @@ -122,99 +119,6 @@ def findTemplatePartFromPattern(self,pattern): return None - def findBestMatch3(self,lRegCuts,lCuts): - """ - lcs! - """ - # Dynamic programming implementation of LCS problem - - # Returns length of LCS for X[0..m-1], Y[0..n-1] - def lcs(X, Y, m, n): - L = [[0 for x in range(n+1)] for x in range(m+1)] - - # Following steps build L[m+1][n+1] in bottom up fashion. 
Note - # that L[i][j] contains length of LCS of X[0..i-1] and Y[0..j-1] - for i in range(m+1): - for j in range(n+1): - if i == 0 or j == 0: - L[i][j] = 0 - elif X[i-1] == Y[j-1]: - L[i][j] = L[i-1][j-1] + 1 - else: - L[i][j] = max(L[i-1][j], L[i][j-1]) - - # Following code is used to print LCS - index = L[m][n] - # Create a character array to store the lcs string - lcs = [""] * (index+1) - lcs[index] = "" - lmapping=[] - # Start from the right-most-bottom-most corner and - # one by one store characters in lcs[] - i = m - j = n - while i > 0 and j > 0: - - # If current character in X[] and Y are same, then - # current character is part of LCS - if X[i-1] == Y[j-1]: - lcs[index-1] = X[i-1] - lmapping.append((i-1,j-1)) - i-=1 - j-=1 - index-=1 - - # If not same, then find the larger of two and - # go in the direction of larger value - elif L[i-1][j] > L[i][j-1]: - i-=1 - else: - j-=1 - -# print ("LCS of " , X , " and " , Y , " is " ,lcs,[(lRegCuts[x],lCuts[y]) for x,y in lmapping]) -# xx =[(lRegCuts[x],lCuts[y]) for x,y in lmapping] - lmapping.reverse() - return lmapping - -# print (lRegCuts,lCuts,lcs(lRegCuts,lCuts,len(lRegCuts),len(lCuts))) - lmap = lcs(lRegCuts,lCuts,len(lRegCuts),len(lCuts)) - if lmap ==[]: return [],[],None - reg,cut = list(zip(*lmap)) -# print (reg,cut) - return reg, cut, None - - def findBestMatch2(self,lRegCuts,lCuts): - """ - best match using hungarian - add a threshold! - """ - cost_matrix=np.zeros((len(lRegCuts),len(lCuts)),dtype=float) - -# print (lRegCuts,lCuts) -# print ([(x,x.getWeight()) for x in lRegCuts]) -# print ([(x,x.getWeight()) for x in lCuts]) - for a,refx in enumerate(lRegCuts): - for b,x in enumerate(lCuts): - dist = refx.getDistance(x) # / abs(x.getWeight() + refx.getWeight()) -# except ZeroDivisionError: -# print (x,x.getWeight()) -# dist =1000 - cost_matrix[a,b]=dist - - r1,r2 = linear_sum_assignment(cost_matrix) - - ltobeDel=[] - for a,i in enumerate(r2): - #if cost is too high: cut the assignment? -# print (a,i,r1,r2,lRegCuts[a],lCuts[i], cost_matrix[a,i]) - if cost_matrix[a,i] > 100: - ltobeDel.append(a) - r2 = np.delete(r2,ltobeDel) - r1 = np.delete(r1,ltobeDel) -# print("wwww",r1,r2) - # score Fréchet distance etween two mapped sequences -# self.findBestMatch3(lRegCuts, lCuts) - return r1,r2,None def findBestMatch(self,lRegCuts,lCuts): """ @@ -285,37 +189,17 @@ def buildObs(lRegCuts,lCuts): states,score = d.Decode(np.arange(len(lCuts))) # print "dec",score, states # print map(lambda x:(x,x.getCanonical().getWeight()),lCuts) -# print (states, type(states[0])) +# print states # for i,si in enumerate(states): # print lCuts[si],si # print obs[si,:] # return the best alignment with template return states, score - - - def computeScore(self,p,q): - """ - input: two lists of pairwise features - does not work !!! 
-
-
-    def computeScore(self,p,q):
-        """
-            input: two lists of pairwise features
-            does not work !!!
-
-
-            -> must take into account self and other (not fearures)
-        """
-
-        return  1/(1+sum(x.getDistance(y) for x,y in zip(p,q)))
-
-
-#         return p[0].getDistance(q[0])
-        d =frechetDist(list(map(lambda x:(x.getValue(),0),p)),list(map(lambda x:(x.getValue(),0),q)))
-#         print ("***",d,list(map(lambda x:(x.getValue(),0),p)),list(map(lambda x:(x.getValue(),0),q)))
-        if d == 0:
-            return 1.0
-        return 1/(frechetDist(list(map(lambda x:(x.getValue(),0),p)),list(map(lambda x:(x.getValue(),0),q))))
-
-    def computeScoreold(self,patLen,lReg,lMissed,lCuts):
+    def computeScore(self,patLen,lReg,lMissed,lCuts):
         """
-            it seems better not to use canonical: thus score bet ter reflects the page
+            it seems better not to use canonical: thus score better reflects the page
             also for REF 130 129 is better than 150
         """
@@ -362,40 +246,44 @@ def selectBestUniqueMatch(self,lFinres):

     def registration(self,anobject):
         """
             'register': match the model to an object
             can only be a terminal template
         """
         lobjectFeatures = anobject.lFeatureForParsing
 #         lobjectFeatures = anobject._fullFeaturesx
 #         print "?",anobject, lobjectFeatures, self
         # empty object
-        if lobjectFeatures == [] or lobjectFeatures is None:
+        if lobjectFeatures == []:
             return None,None,-1
 #         print self.getPattern(), lobjectFeatures
-        try: self.getPattern().sort(key=lambda x:x.getValue())
-        except: pass  ## P3 < to be defined for featureObject
-#         print ('\t',self.getPattern(), anobject, lobjectFeatures)
-        foundReg, bestReg, _ = self.findBestMatch3(self.getPattern(), lobjectFeatures)
-
-#         bestReg, _ = self.findBestMatch(self.getPattern(), lobjectFeatures)
+        self.getPattern().sort(key=lambda x:x.getValue())
+#         print self.getPattern(), anobject, lobjectFeatures
+        bestReg, curScore = self.findBestMatch(self.getPattern(), lobjectFeatures)
 #         print bestReg, curScore
-        if len(bestReg) > 0:
-            lFinres = list ( zip([(lobjectFeatures[i]) for i in bestReg], ([self.getPattern()[i] for i in foundReg])) )
-            score1 = self.computeScore([(self.getPattern()[i]) for i in foundReg], [(lobjectFeatures[i]) for i in bestReg])
-#             score1 = self.computeScore(self.getPattern(), lobjectFeatures)
-#             print ('\t\t',score1)
-            # how much of the element is covered ? use weight for this
-            w1 = sum([x.getWeight() for x in [(lobjectFeatures[i]) for i in bestReg]])
-            w2 = sum([x.getWeight() for x in lobjectFeatures])
-
-#             score1 = score1 * ( len(foundReg) + len(bestReg) ) / (len(self.getPattern()) + len(lobjectFeatures))
-            score1 = score1 * (w1/w2)
-
-#             score1 = (2 * len(foundReg) ) / (len(self.getPattern()) + len(lobjectFeatures))
-
-            return lFinres, None, score1
+        ltmp = self.getPattern()[:]
+        ltmp.append('EMPTY')
+        lMissingIndex = list(filter(lambda x: x not in bestReg, range(0,len(self.getPattern())+1)))
+        lMissing = np.array(ltmp)[lMissingIndex].tolist()
+        lMissing = list(filter(lambda x: x!= 'EMPTY',lMissing))
+        result = np.array(ltmp)[bestReg].tolist()
+        lFinres= list(filter(lambda xy: xy[0]!= 'EMPTY',zip(result, lobjectFeatures)))
+#         print map(lambda x:(x,x.getWeight()),self.getPattern())
+        if lFinres != []:
+            lFinres = self.selectBestUniqueMatch(lFinres)
+#             print lFinres
+            score1 = self.computeScore(len(self.getPattern()), lFinres, lMissing,lobjectFeatures)
+            # for estimating missing?
+#             self.selectBestAnchor(lFinres)
+            return lFinres,lMissing,score1
         else:
             return None,None,-1
+
+
+
+
+
+
+
diff --git a/TranskribusDU/ObjectModel/verticalZonesTemplateClass.py b/TranskribusDU/ObjectModel/verticalZonesTemplateClass.py
index cf1862e..2aa5e2e 100644
--- a/TranskribusDU/ObjectModel/verticalZonesTemplateClass.py
+++ b/TranskribusDU/ObjectModel/verticalZonesTemplateClass.py
@@ -14,11 +14,8 @@
 from __future__ import print_function
 from __future__ import unicode_literals

-import numpy as np
-from scipy.optimize import linear_sum_assignment
-
-from spm.frechet import frechetDist
 from .templateClass import templateClass
+import numpy as np

 class verticalZonestemplateClass(templateClass):
     """
@@ -87,71 +84,147 @@ def isRegularGrid(self):
         return False

-    def findBestMatch2(self,lRegCuts,lCuts):
+    def findBestMatch(self,calibration,lRegCuts,lCuts):
         """
-            best match using hungarian
-            add a threshold!
+            find the best solution assuming reg=x
+            dynamic programming (viterbi path)
+
+            score needs to be normalized (0,1)
         """
-        cost_matrix=np.zeros((len(lRegCuts),len(lCuts)),dtype=float)
+        def buildObs(calibration,lRegCuts,lCuts):
+            N=len(lRegCuts)+1
+            obs = np.zeros((N,len(lCuts)), dtype=np.float16)+ 0.0
+            for i,refx in enumerate(lRegCuts):
+                for j,x in enumerate(lCuts):
+#                     print refx, x, (x.getValue()-calibration), abs((x.getValue()-calibration)-refx.getValue())
+                    if abs((x.getValue()-calibration)-refx.getValue()) < 20:
+#                         print "\t",refx, x, (x.getValue()-calibration)
+                        obs[i,j]= x.getCanonical().getWeight() * ( 20 - ( abs(x.getValue()-calibration-refx.getValue()))) / 20.0
+                    elif abs((x.getValue()-calibration)-refx.getValue()) < 40:
+                        obs[i,j]= x.getCanonical().getWeight() * (( 40 - ( abs(x.getValue()-calibration-refx.getValue()))) / 40.0)
+                    else:
+                        # go to empty state
+                        obs[-1,j] = 1.0
+                    if np.isinf(obs[i,j]):
+#                         print i,j,score
+                        obs[i,j]=64000
+                    if np.isnan(obs[i,j]):
+#                         print i,j,score
+                        obs[i,j]=10e-3
+#             print lRegCuts, lCuts, normalized(obs)
+            return obs / np.amax(obs)
+
+        import spm.viterbi as viterbi
+
+        # add 'missing' state
+        N =len(lRegCuts)+1
+        transProb = np.zeros((N,N), dtype = np.float16)
+        for i in range(N-1):
+#             for j in range(i,N):
+            transProb[i,i+1]=1.0  #/(N-i)
+        transProb[:,-1,]=1.0  #/(N)
+        transProb[-1,:]=1.0  #/(N)
+        initialProb = np.ones(N)
+        initialProb = np.reshape(initialProb,(N,1))
+
+        obs = buildObs(calibration,lRegCuts,lCuts)
+        d = viterbi.Decoder(initialProb, transProb, obs)
+        states,score = d.Decode(np.arange(len(lCuts)))
+#         print map(lambda x:(x,x.getCanonical().getWeight()),lCuts)
+#         print states
+#         for i,si in enumerate(states):
+#             print lCuts[si],si
+#             print obs[si,:]

-        for a,refx in enumerate(lRegCuts):
-            for b,x in enumerate(lCuts):
-                dist = refx.getDistance(x)
-                cost_matrix[a,b]=dist
+        # return the best alignment with template
+        return states, score
+
+
+    def selectBestAnchor(self,lCuts):
+        """
+            select the best anchor and use width for defining the other?
+        """
+        fShort = 9e9
+        bestElt = None
+        for i,(x,y) in enumerate(lCuts):
+            if abs(x.getValue() - y.getValue()) < fShort:
+                bestElt=(x,y)
+                fShort = abs(x.getValue() - y.getValue())
+
+        print ('BEST', bestElt)
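[Editor's note] The new findBestMatch above aligns observed page cuts to template cuts by Viterbi decoding over states = template positions plus one 'missing' state; spm.viterbi.Decoder is the project's own implementation. A generic, self-contained sketch of the recurrence this relies on, under the usual initial/transition/observation conventions (not the project's API):

import numpy as np

def viterbi_decode(initial, trans, obs):
    """initial: (N,), trans: (N,N), obs: (N,T) observation scores."""
    N, T = obs.shape
    score = np.zeros((N, T))
    back = np.zeros((N, T), dtype=int)
    score[:, 0] = initial * obs[:, 0]
    for t in range(1, T):
        # candidate score of moving from previous state i to current state j
        cand = score[:, t - 1][:, None] * trans * obs[:, t][None, :]
        back[:, t] = np.argmax(cand, axis=0)
        score[:, t] = np.max(cand, axis=0)
    # backtrack the best path from the last time step
    states = [int(np.argmax(score[:, -1]))]
    for t in range(T - 1, 0, -1):
        states.append(int(back[states[-1], t]))
    states.reverse()
    return states, float(np.max(score[:, -1]))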
-        r1,r2 = linear_sum_assignment(cost_matrix)
-
-        ltobeDel=[]
-        for a,i in enumerate(r2):
-            #if cost is too high: cut the assignment?
-            print (a,i,r1,r2,lRegCuts[a],lCuts[i], cost_matrix[a,i])
-            if cost_matrix[a,i] > 100:
-                ltobeDel.append(a)
-        r2 = np.delete(r2,ltobeDel)
-        r1 = np.delete(r1,ltobeDel)
-#         print ('\t',r1,r2,ltobeDel,lRegCuts,lCuts)
-        # score Fréchet distance etween two mapped sequences
-        return r1,r2,None
-
-
-    def computeScore(self,p,q):
-        d =frechetDist(list(map(lambda x:(x.getValue(),0),p)),list(map(lambda x:(x.getValue(),0),q)))
-#         print (d,list(map(lambda x:(x.getValue(),0),p)),list(map(lambda x:(x.getValue(),0),q)))
-        if d == 0:
-            return 1
-        return 1/(frechetDist(list(map(lambda x:(x.getValue(),0),p)),list(map(lambda x:(x.getValue(),0),q))))
-
-    def registration(self,anobject):
+
+    def selectBestCandidat(self,lCuts):
         """
-            'register': match the model to an object
-            can only a terminal template
+            if several x are selected for a 'state': take the nearest one
+            possible improvement: consider width, weight
         """
-        lobjectFeatures = anobject.lFeatureForParsing
-#         lobjectFeatures = anobject._fullFeaturesx
-        print (anobject, lobjectFeatures, self)
-        # empty object
-        if lobjectFeatures == []:
+        lFinal=[]
+        dBest = {}
+        for x,y in lCuts:
+            try:
+                if abs(x.getValue() - dBest[x].getValue()) > abs(x.getValue() - y.getValue()):
+                    dBest[x]=y
+            except KeyError:
+                dBest[x]=y
+        for x,y in lCuts:
+            lFinal.append((x,dBest[x]))
+        return lFinal
+
+    def computeScore(self,lReg,lCuts):
+        fFound = 1.0 * sum(x.getCanonical().getWeight() for _r,x in lReg)
+        fTotal = 1.0 * sum(x.getCanonical().getWeight() for x in lCuts)
+#         print '========'
+#         print map(lambda x:(x,x.getCanonical().getWeight()),lCuts)
+#
+#         print fFound , map(lambda (r,x):x.getCanonical().getWeight(),lReg)
+#         print fTotal, map(lambda x:x.getCanonical().getWeight(),lCuts)
+        return fFound/fTotal
+
+    def registration(self,pageObject):
+        """
+            using lCuts (and width) for positioning the page
+            return the registered values
+        """
+        if pageObject.lf_XCut == []:
             return None,None,-1
-#         print self.getPattern(), lobjectFeatures
-        try: self.getPattern().sort(key=lambda x:x.getValue())
-        except: pass  ## P3 < to be defined for featureObject
-#         print self.getPattern(), anobject, lobjectFeatures
-        foundReg,bestReg, _ = self.findBestMatch2(self.getPattern(), lobjectFeatures)
-
-#         bestReg, _ = self.findBestMatch(self.getPattern(), lobjectFeatures)
-#         print bestReg, curScore
-        if bestReg != []:
-            lFinres = list(zip([(lobjectFeatures[i]) for i in bestReg], ([self.getPattern()[i] for i in foundReg])))
-#             print (lFinres)
-#             score1 = self.computeScore(len(self.getPattern()), lFinres, [],lobjectFeatures)
-#             print (bestReg, self.getPattern(),[(self.getPattern()[i]) for i in bestReg])
-#             score1 = self.computeScore([(self.getPattern()[i]) for i in foundReg], lobjectFeatures)
-            score1 = self.computeScore(self.getPattern(), lobjectFeatures)
+        # define lwidth for the page
+        pageObject.lf_XCut.sort(key=lambda x:x.getValue())
+#         print pageObject, pageObject.lf_XCut
+#         print self.getXCuts()
+
+        ## define a set of interesting calibration
+#         lCalibration= [0,-50,50]
+        lCalibration= [0]
+
+
+        bestScore=0
+        bestReg=None
+        for calibration in lCalibration:
+            reg, curScore = self.findBestMatch(calibration,self.getXCuts(),pageObject.lf_XCut)
+#             print calibration, reg, curScore
+            if curScore > bestScore:
+                bestReg=reg;bestScore=curScore

-            return lFinres,None,score1
+        if bestReg:
+            ltmp = self.getXCuts()[:]
+            ltmp.append('EMPTY')
+            lMissingIndex = list(filter(lambda x: x not in bestReg, range(0,len(self.getXCuts())+1)))
+            lMissing = np.array(ltmp)[lMissingIndex].tolist()
+            lMissing = list(filter(lambda x: x!= 'EMPTY',lMissing))
+            result = np.array(ltmp)[bestReg].tolist()
+            lFinres = list(filter(lambda xy: xy[0]!= 'EMPTY',zip(result,pageObject.lf_XCut)))
+            if lFinres != []:
+                lFinres = self.selectBestCandidat(lFinres)
+            # for estimating missing?
+            # self.selectBestAnchor(lFinres)
+            return lFinres,lMissing,self.computeScore(lFinres, pageObject.lf_XCut)
         else:
-            return None,None,-1
+            return None,None,-1
+
+
+
diff --git a/TranskribusDU/ObjectModel/xmlDSDocumentClass.py b/TranskribusDU/ObjectModel/xmlDSDocumentClass.py
index 107ea24..ebc7078 100644
--- a/TranskribusDU/ObjectModel/xmlDSDocumentClass.py
+++ b/TranskribusDU/ObjectModel/xmlDSDocumentClass.py
@@ -16,7 +16,7 @@
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(sys.argv[0]))))

 from .xmlDocumentClass import XMLDocument
-# from .XMLDSObjectClass import XMLDSObjectClass
+from .XMLDSObjectClass import XMLDSObjectClass
 from .XMLDSPageClass import XMLDSPageClass
 from config import ds_xml_def as ds_xml

@@ -35,7 +35,7 @@ def __init__(self,domDoc=None):
         self.currentlPages = []
         self.nbTotalPages = 0
         XMLDocument.__init__(self)
-        self.lTags= ['PARAGRAPH','COLUMN','TABLE','REGION','BLOCK','BOX',ds_xml.sLINE_Elt,ds_xml.sTEXT,'BASELINE','GRAPHELT','SeparatorRegion']
+

     def addPage(self,p):
         self.lPages.append(p)
@@ -58,7 +58,7 @@ def loadPages(self,domdocRoot,myTAG,lPages):
                 myPage.setNumber(int(page.get('number')))
                 self.addPage(myPage)
 #                 self.getRootObject()._lObjects = self.getPages()
-                myPage.fromDom(page,self.lTags)
+                myPage.fromDom(page,['COLUMN','TABLE','REGION','BLOCK',ds_xml.sLINE_Elt,ds_xml.sTEXT,'BASELINE','GRAPHELT','SeparatorRegion'])

     def loadFromDom(self,docDom = None,pageTag='PAGE',listPages = []):
diff --git a/TranskribusDU/TranskribusDU_version.py b/TranskribusDU/TranskribusDU_version.py
index 99ea55a..fbc8220 100644
--- a/TranskribusDU/TranskribusDU_version.py
+++ b/TranskribusDU/TranskribusDU_version.py
@@ -1,6 +1,8 @@
 '''
 Created on 30 Nov 2016

+Updated 2019-11-20 by JL Meunier
+
 @author: meunier
 '''
-version="0.19"
+version="0.20"
diff --git a/TranskribusDU/common/Component.py b/TranskribusDU/common/Component.py
index 34f09a5..19b4854 100644
--- a/TranskribusDU/common/Component.py
+++ b/TranskribusDU/common/Component.py
@@ -828,44 +828,43 @@ def testCompare_InfoExtraction(self, refdoc, rundoc, sXpathExpr, funCompare=None

         i = 0
         ltisRefsRunbErrbMiss = list()
-        try:
-            while True:
-                i += 1
-                bErr, bMiss = False, False
-                ref, run = itreflLen.next().text, itrunlLen.next().text
-                if funNormalize:
-                    srunnorm = funNormalize(run) #using also our normalization in addition to the standard one
-                    srefnorm = funNormalize(ref)
+
+        for ref,run in zip(itreflLen,itrunlLen):
+            ref=ref.text
+            run=run.text
+            i += 1
+            bErr, bMiss = False, False
+            #ref, run = itreflLen.next().text, itrunlLen.next().text
+            if funNormalize:
+                srunnorm = funNormalize(run) #using also our normalization in addition to the standard one
+                srefnorm = funNormalize(ref)
+            else:
+                srunnorm, srefnorm = run, ref
+            #traceln((i, ref, run))
+            if run: #it found something
+                if funCompare:
+                    bOk = funCompare(srunnorm, srefnorm)
                 else:
-                    srunnorm, srefnorm = run, ref
-                #traceln((i, ref, run))
-                if run: #it found something
-                    if funCompare:
-                        bOk = funCompare(srunnorm, srefnorm)
-                    else:
-                        bOk = (srunnorm == srefnorm)
-                    if bOk:
-                        nok += 1
-                        if bVisual: traceln("   *** : page %d: '%s' got '%s'"%(i, ref, run))
-                    else:
-                        nerr += 1 #something wrong
-                        bErr = True
-                        if ref:
-                            nmiss += 1 #but this is also a miss
-                            bMiss = True
-                        if bVisual: traceln("       : page %d: '%s' expected but got *** '%s'"%(i, ref, run))
-                else: #it found nothing
-                    if ref:
-                        nmiss += 1 #it missed something
+                    bOk = (srunnorm == srefnorm)
+                if bOk:
+                    nok += 1
+                    if bVisual: traceln("   *** : page %d: '%s' got '%s'"%(i, ref, run))
+                else:
+                    nerr += 1 #something wrong
+                    bErr = True
+                    if ref:
+                        nmiss += 1 #but this is also a miss
                         bMiss = True
-                        if bVisual: traceln("       : page %d: '%s' expected ***"%(i, ref))
-                    else:
-                        nok += 1 #just fine!!
-                ltisRefsRunbErrbMiss.append( (i, ref, run, bErr, bMiss) )
+                    if bVisual: traceln("       : page %d: '%s' expected but got *** '%s'"%(i, ref, run))
+            else: #it found nothing
+                if ref:
+                    nmiss += 1 #it missed something
+                    bMiss = True
+                    if bVisual: traceln("       : page %d: '%s' expected ***"%(i, ref))
+                else:
+                    nok += 1 #just fine!!
+            ltisRefsRunbErrbMiss.append( (i, ref, run, bErr, bMiss) )

-        except StopIteration:
-            pass
-#         refctxt.xpathFreeContext(), runctxt.xpathFreeContext()

         assert len(reflpnum) == len(runlpnum), "***** ERROR: inconsistent ref (%d) and run (%d) lengths. *****"%(len(reflpnum), len(runlpnum))
diff --git a/TranskribusDU/common/LabelBinarizer2.py b/TranskribusDU/common/LabelBinarizer2.py
index 219749c..4f4f587 100644
--- a/TranskribusDU/common/LabelBinarizer2.py
+++ b/TranskribusDU/common/LabelBinarizer2.py
@@ -9,18 +9,7 @@

     Copyright NAVER(C) 2019 Jean-Luc Meunier

-    This program is free software: you can redistribute it and/or modify
-    it under the terms of the GNU General Public License as published by
-    the Free Software Foundation, either version 3 of the License, or
-    (at your option) any later version.
-
-    This program is distributed in the hope that it will be useful,
-    but WITHOUT ANY WARRANTY; without even the implied warranty of
-    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-    GNU General Public License for more details.
-
-    You should have received a copy of the GNU General Public License
-    along with this program.  If not, see .
+

     Developed for the EU project READ. The READ project has received funding
diff --git a/TranskribusDU/common/TestReport.py b/TranskribusDU/common/TestReport.py
index 8ec17f0..85c1ea0 100644
--- a/TranskribusDU/common/TestReport.py
+++ b/TranskribusDU/common/TestReport.py
@@ -5,18 +5,7 @@

     Copyright Xerox(C) 2016 JL. Meunier

-    This program is free software: you can redistribute it and/or modify
-    it under the terms of the GNU General Public License as published by
-    the Free Software Foundation, either version 3 of the License, or
-    (at your option) any later version.
-
-    This program is distributed in the hope that it will be useful,
-    but WITHOUT ANY WARRANTY; without even the implied warranty of
-    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-    GNU General Public License for more details.
-
-    You should have received a copy of the GNU General Public License
-    along with this program.  If not, see .
+

     Developed for the EU project READ.
     The READ project has received funding
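[Editor's note] The rewritten testCompare_InfoExtraction loop above keeps three counters per ref/run pair. Its bookkeeping, isolated as a pure function for easier checking (normalisation and tracing omitted; names are illustrative):

def score_pairs(pairs):
    """pairs: iterable of (ref, run) strings, either possibly empty."""
    n_ok = n_err = n_miss = 0
    for ref, run in pairs:
        if run:                    # the system found something
            if ref == run:
                n_ok += 1
            else:
                n_err += 1         # something wrong
                if ref:
                    n_miss += 1    # and the expected value was missed too
        else:                      # the system found nothing
            if ref:
                n_miss += 1
            else:
                n_ok += 1          # nothing expected, nothing found: fine
    return n_ok, n_err, n_miss

# e.g. score_pairs([("1820", "1820"), ("1821", ""), ("", "1822")]) -> (1, 1, 1)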
diff --git a/TranskribusDU/common/XmlConfig.py b/TranskribusDU/common/XmlConfig.py
index 850ab73..9d21a3e 100644
--- a/TranskribusDU/common/XmlConfig.py
+++ b/TranskribusDU/common/XmlConfig.py
@@ -6,6 +6,10 @@
 Copyright Xerox XRCE 2006
 """

+
+
+
+
 import config.ds_xml_def as ds_xml

 from lxml import etree
diff --git a/TranskribusDU/common/chrono.py b/TranskribusDU/common/chrono.py
index 621d9f2..5113b4c 100644
--- a/TranskribusDU/common/chrono.py
+++ b/TranskribusDU/common/chrono.py
@@ -57,7 +57,20 @@ def chronoOff(expected_name=None):
             name, expected_name, [_n for _,_n in ltChrono])
     return c.off()

-
+def pretty_time_delta(seconds):
+    seconds = int(seconds)
+    days, seconds = divmod(seconds, 86400)
+    hours, seconds = divmod(seconds, 3600)
+    minutes, seconds = divmod(seconds, 60)
+    if days > 0:
+        return '%dd %dh %dmin %ds' % (days, hours, minutes, seconds)
+    elif hours > 0:
+        return '%dh %dmin %ds' % (hours, minutes, seconds)
+    elif minutes > 0:
+        return '%dmin %ds' % (minutes, seconds)
+    else:
+        return '%ds' % (seconds,)
+
 #---------- SELF-TEST --------------
 if __name__ == "__main__":
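[Editor's note] Quick sanity check of the new pretty_time_delta helper (values chosen for illustration):

#   >>> from common.chrono import pretty_time_delta
#   >>> pretty_time_delta(93784)      # 1 day + 2h + 3min + 4s
#   '1d 2h 3min 4s'
#   >>> pretty_time_delta(59.9)       # fractions of a second are truncated
#   '59s'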
diff --git a/TranskribusDU/common/trace.py b/TranskribusDU/common/trace.py
index 21b0cfa..daadfcd 100644
--- a/TranskribusDU/common/trace.py
+++ b/TranskribusDU/common/trace.py
@@ -4,6 +4,8 @@
 # JL Meunier - May 2004
 # Copyright XRCE, 2004
 #
+
+
 import sys

 global traceFD
diff --git a/TranskribusDU/contentProcessing/taggerChrono.py b/TranskribusDU/contentProcessing/taggerChrono.py
index 69203ea..1265dce 100644
--- a/TranskribusDU/contentProcessing/taggerChrono.py
+++ b/TranskribusDU/contentProcessing/taggerChrono.py
@@ -1,7 +1,6 @@
 # -*- coding: utf-8 -*-
 """
-
     taggerChrono.py

     task: recognition of a chronological sequence of dates
@@ -11,20 +10,6 @@
     copyright Naver labs Europe 2018
     READ project

-    This program is free software: you can redistribute it and/or modify
-    it under the terms of the GNU General Public License as published by
-    the Free Software Foundation, either version 3 of the License, or
-    (at your option) any later version.
-
-    This program is distributed in the hope that it will be useful,
-    but WITHOUT ANY WARRANTY; without even the implied warranty of
-    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-    GNU General Public License for more details.
-
-    You should have received a copy of the GNU General Public License
-    along with this program.  If not, see .
-
-
     Developed for the EU project READ. The READ project has received funding
     from the European Union's Horizon 2020 research and innovation programme
     under grant agreement No 674943.
diff --git a/TranskribusDU/contentProcessing/taggerIEmerge.py b/TranskribusDU/contentProcessing/taggerIEmerge.py
index 5efff49..03485ff 100644
--- a/TranskribusDU/contentProcessing/taggerIEmerge.py
+++ b/TranskribusDU/contentProcessing/taggerIEmerge.py
@@ -1,7 +1,6 @@
 # -*- coding: utf-8 -*-
 """
-
     taggerIEmerge.py

     task: recognition of multi outputs classes
@@ -11,20 +10,6 @@
     copyright Naver labs Europe 2018
     READ project

-    This program is free software: you can redistribute it and/or modify
-    it under the terms of the GNU General Public License as published by
-    the Free Software Foundation, either version 3 of the License, or
-    (at your option) any later version.
-
-    This program is distributed in the hope that it will be useful,
-    but WITHOUT ANY WARRANTY; without even the implied warranty of
-    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-    GNU General Public License for more details.
-
-    You should have received a copy of the GNU General Public License
-    along with this program.  If not, see .
-
-
     Developed for the EU project READ. The READ project has received funding
     from the European Union's Horizon 2020 research and innovation programme
     under grant agreement No 674943.
diff --git a/TranskribusDU/contentProcessing/taggerTrainKeras.py b/TranskribusDU/contentProcessing/taggerTrainKeras.py
index 1268569..7068db9 100644
--- a/TranskribusDU/contentProcessing/taggerTrainKeras.py
+++ b/TranskribusDU/contentProcessing/taggerTrainKeras.py
@@ -11,18 +11,7 @@
     copyright Naverlabs 2017
     READ project

-    This program is free software: you can redistribute it and/or modify
-    it under the terms of the GNU General Public License as published by
-    the Free Software Foundation, either version 3 of the License, or
-    (at your option) any later version.
-
-    This program is distributed in the hope that it will be useful,
-    but WITHOUT ANY WARRANTY; without even the implied warranty of
-    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-    GNU General Public License for more details.
-
-    You should have received a copy of the GNU General Public License
-    along with this program.  If not, see .
+

     Developed for the EU project READ. The READ project has received funding
@@ -726,9 +715,6 @@ def prepareOutput(self,lToken, lTags):
             lRes.append((toffset,tok,label,list(lScore)))
         return lRes

-
-
-
     def predict_multiptype(self,lsent):
         """
@@ -774,47 +760,6 @@
         return lRes

-    def predict_multiptype_list(self,lsent):
-        """
-            predict over a set of sentences (unicode)
-        """
-
-        lRes= []
-        allwords=[]
-        for mysent in lsent :
-#             print self.tag_vector
-            if len(mysent.split())> self.max_sentence_len:
-#                 print ('max sent length: %s'%self.max_sentence_len)
-                continue
-            allwords.extend(self.node_transformer.transform(mysent.split()))
-            wordsvec = []
-            for w in allwords:
-                wordsvec.append(w)
-            lX = list()
-            nil_X = np.zeros(self.max_features)
-            pad_length = self.max_sentence_len - len(wordsvec)
-            lX.append( wordsvec +((pad_length)*[nil_X]) )
-            lX=np.array(lX)
-            y_pred1,y_pred2 = self.model.predict(lX)
-            for i,_ in enumerate(lX):
-#                 pred_seq = y_pred[i]
-                l_multi_type_results = []
-                for pred_seq in [y_pred1[i],y_pred2[i]]:
-                    pred_tags = []
-                    pad_length = self.max_sentence_len - len(allwords)
-                    for class_prs in pred_seq:
-                        class_vec = np.zeros(self.nbClasses, dtype=np.int32)
-                        class_vec[ np.argmax(class_prs) ] = 1
-#                         print class_prs[class_prs >0.1]
-                        if tuple(class_vec.tolist()) in self.tag_vector:
-                            #print self.tag_vector[tuple(class_vec.tolist())],class_prs[np.argmax(class_prs)]
-                            pred_tags.append((self.tag_vector[tuple(class_vec.tolist())],class_prs[np.argmax(class_prs)]))
-                    l_multi_type_results.append(pred_tags[:len(allwords)])
-#             print(mysent.split(),l_multi_type_results)
-            lRes.append(self.prepareOutput_multitype(mysent.split(),l_multi_type_results))
-
-        return lRes
-
     def predict(self,lsent):
         """
             predict over a set of sentences (unicode)
diff --git a/TranskribusDU/contentProcessing/taggerTrainKeras2.py b/TranskribusDU/contentProcessing/taggerTrainKeras2.py
index 958edfa..b8d8084 100644
--- a/TranskribusDU/contentProcessing/taggerTrainKeras2.py
+++ b/TranskribusDU/contentProcessing/taggerTrainKeras2.py
@@ -11,18 +11,7 @@
     copyright Naverlabs 2017
     READ project

-    This program is free software: you can redistribute it and/or modify
-    it under the terms of the GNU General Public License as published by
-    the Free Software Foundation, either version 3 of the License, or
-    (at your option) any later version.
-
-    This program is distributed in the hope that it will be useful,
-    but WITHOUT ANY WARRANTY; without even the implied warranty of
-    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-    GNU General Public License for more details.
-
-    You should have received a copy of the GNU General Public License
-    along with this program.  If not, see .
+

     Developed for the EU project READ. The READ project has received funding
diff --git a/TranskribusDU/contentProcessing/taggerTrainKeras3.py b/TranskribusDU/contentProcessing/taggerTrainKeras3.py
index 83688a9..b152457 100644
--- a/TranskribusDU/contentProcessing/taggerTrainKeras3.py
+++ b/TranskribusDU/contentProcessing/taggerTrainKeras3.py
@@ -11,18 +11,7 @@
     copyright Naverlabs 2017
     READ project

-    This program is free software: you can redistribute it and/or modify
-    it under the terms of the GNU General Public License as published by
-    the Free Software Foundation, either version 3 of the License, or
-    (at your option) any later version.
-
-    This program is distributed in the hope that it will be useful,
-    but WITHOUT ANY WARRANTY; without even the implied warranty of
-    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-    GNU General Public License for more details.
-
-    You should have received a copy of the GNU General Public License
-    along with this program.  If not, see .
+

     Developed for the EU project READ. The READ project has received funding
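[Editor's note] The removed predict_multiptype_list, like the kept predict methods in these Keras taggers, pads each encoded sentence with zero vectors up to max_sentence_len before calling model.predict. The padding step in isolation (dimensions are illustrative):

import numpy as np

max_sentence_len = 30
max_features     = 50
words_vec = [np.ones(max_features)] * 12       # 12 encoded words

pad_length = max_sentence_len - len(words_vec)
nil_X = np.zeros(max_features)                 # the padding vector, as in the code above
lX = np.array([words_vec + pad_length * [nil_X]])

assert lX.shape == (1, max_sentence_len, max_features)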
diff --git a/TranskribusDU/crf/Model_SSVM_AD3.py b/TranskribusDU/crf/Model_SSVM_AD3.py
index 0c25a55..db3acb9 100644
--- a/TranskribusDU/crf/Model_SSVM_AD3.py
+++ b/TranskribusDU/crf/Model_SSVM_AD3.py
@@ -8,18 +8,7 @@
     Copyright Xerox(C) 2016 JL. Meunier

-    This program is free software: you can redistribute it and/or modify
-    it under the terms of the GNU General Public License as published by
-    the Free Software Foundation, either version 3 of the License, or
-    (at your option) any later version.
-
-    This program is distributed in the hope that it will be useful,
-    but WITHOUT ANY WARRANTY; without even the implied warranty of
-    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-    GNU General Public License for more details.
-
-    You should have received a copy of the GNU General Public License
-    along with this program.  If not, see .
+

     Developed for the EU project READ. The READ project has received funding
@@ -47,7 +36,7 @@
 from common.chrono import chronoOn, chronoOff
 from common.TestReport import TestReport

-from graph.GraphModel import GraphModel
+from graph.GraphModel import GraphModel, GraphModelNoEdgeException
 from graph.Graph import Graph
 from crf.OneSlackSSVM import OneSlackSSVM

@@ -185,6 +174,13 @@ def train(self, lGraph_trn, lGraph_vld, bWarmStart=True, expiration_timestamp=No
         lX , lY = self.get_lX_lY(lGraph_trn)
         lX_vld , lY_vld = self.get_lX_lY(lGraph_vld)
         bMakeSlim = not bWarmStart  # for warm-start mode, we do not make the model slimer!"
+        self._computeModelCaracteristics(lX)
+        traceln("\t\t %s" % self._getNbFeatureAsText())
+        if False:
+            np.set_printoptions(threshold=sys.maxsize)
+            print(lX[0][0])
+            traceln("\t\t %s" % self._getNbFeatureAsText())
+            sys.exit(1)

         traceln("\t- retrieving or creating model...")
         self.ssvm = None
@@ -473,7 +469,7 @@ def testFiles(self, lsFilename, loadFun, bBaseLine=False):
         for sFilename in lsFilename:
             lg = loadFun(sFilename)  #returns a singleton list
             for g in lg:
-                if self.bConjugate: g.computeEdgeLabels()
+                if g.bConjugate: g.computeEdgeLabels()
                 [X], [Y] = self.get_lX_lY([g])

                 if lLabelName == None:
@@ -496,7 +492,6 @@
                 lX     .append(X)
                 lY     .append(Y)
                 lY_pred.append(Y_pred)
-                #g.detachFromDOM()
                 del g  #this can be very large
                 gc.collect()
         traceln("[%.1fs] done\n"%chronoOff("testFiles"))
@@ -529,6 +524,7 @@ def predict(self, graph, bProba=False):
             return a numpy array, which is a 1-dim array of size the number of nodes of the graph.
         """
         [X] = self.get_lX([graph])
+        if X[1].shape[0] == 0: raise GraphModelNoEdgeException  # no edge in this graph!
         bConstraint = graph.getPageConstraint()
         traceln("\t\t #nodes=%d  #edges=%d "%Graph.getNodeEdgeTotalNumber([graph]))
         self._computeModelCaracteristics([X])  #we discover here dynamically the number of features of nodes and edges
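[Editor's note] In predict above, each X appears to pack (node_features, edges, edge_features), so X[1].shape[0] is the edge count. A minimal illustration of the new guard, with ValueError standing in for GraphModelNoEdgeException (shapes are illustrative):

import numpy as np

node_features = np.zeros((3, 10))            # 3 nodes, 10 features each
edges         = np.zeros((0, 2), dtype=int)  # no edge at all
edge_features = np.zeros((0, 4))
X = (node_features, edges, edge_features)

if X[1].shape[0] == 0:
    raise ValueError("no edge in this graph")  # hypothetical stand-in exception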
diff --git a/TranskribusDU/crf/Model_SSVM_AD3_Multitype.py b/TranskribusDU/crf/Model_SSVM_AD3_Multitype.py
index 4e9d547..40ba13d 100644
--- a/TranskribusDU/crf/Model_SSVM_AD3_Multitype.py
+++ b/TranskribusDU/crf/Model_SSVM_AD3_Multitype.py
@@ -8,18 +8,7 @@
     Copyright Xerox(C) 2016 JL. Meunier

-    This program is free software: you can redistribute it and/or modify
-    it under the terms of the GNU General Public License as published by
-    the Free Software Foundation, either version 3 of the License, or
-    (at your option) any later version.
-
-    This program is distributed in the hope that it will be useful,
-    but WITHOUT ANY WARRANTY; without even the implied warranty of
-    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-    GNU General Public License for more details.
-
-    You should have received a copy of the GNU General Public License
-    along with this program.  If not, see .

     Developed for the EU project READ. The READ project has received funding
diff --git a/TranskribusDU/crf/Model_SSVM_AD3_main.py b/TranskribusDU/crf/Model_SSVM_AD3_main.py
index dda3875..1b68bd2 100644
--- a/TranskribusDU/crf/Model_SSVM_AD3_main.py
+++ b/TranskribusDU/crf/Model_SSVM_AD3_main.py
@@ -8,18 +8,7 @@
     Copyright Xerox(C) 2016 JL. Meunier

-    This program is free software: you can redistribute it and/or modify
-    it under the terms of the GNU General Public License as published by
-    the Free Software Foundation, either version 3 of the License, or
-    (at your option) any later version.
-
-    This program is distributed in the hope that it will be useful,
-    but WITHOUT ANY WARRANTY; without even the implied warranty of
-    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-    GNU General Public License for more details.
-
-    You should have received a copy of the GNU General Public License
-    along with this program.  If not, see .
+

     Developed for the EU project READ. The READ project has received funding
diff --git a/TranskribusDU/crf/OneSlackSSVM.py b/TranskribusDU/crf/OneSlackSSVM.py
index 6bab03f..67b674c 100644
--- a/TranskribusDU/crf/OneSlackSSVM.py
+++ b/TranskribusDU/crf/OneSlackSSVM.py
@@ -6,18 +6,7 @@
     Copyright Xerox(C) 2016 JL. Meunier

-    This program is free software: you can redistribute it and/or modify
-    it under the terms of the GNU General Public License as published by
-    the Free Software Foundation, either version 3 of the License, or
-    (at your option) any later version.
-
-    This program is distributed in the hope that it will be useful,
-    but WITHOUT ANY WARRANTY; without even the implied warranty of
-    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-    GNU General Public License for more details.
-
-    You should have received a copy of the GNU General Public License
-    along with this program.  If not, see .
+

     Developed for the EU project READ. The READ project has received funding
diff --git a/TranskribusDU/crf/__init__.py b/TranskribusDU/crf/__init__.py
index 7fc41cb..fee883f 100644
--- a/TranskribusDU/crf/__init__.py
+++ b/TranskribusDU/crf/__init__.py
@@ -7,18 +7,7 @@
     Copyright Xerox(C) 2016 H. Déjean, JL. Meunier

-    This program is free software: you can redistribute it and/or modify
-    it under the terms of the GNU General Public License as published by
-    the Free Software Foundation, either version 3 of the License, or
-    (at your option) any later version.
-
-    This program is distributed in the hope that it will be useful,
-    but WITHOUT ANY WARRANTY; without even the implied warranty of
-    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-    GNU General Public License for more details.
-
-    You should have received a copy of the GNU General Public License
-    along with this program.  If not, see .

     Developed for the EU project READ. The READ project has received funding
diff --git a/TranskribusDU/dataGenerator/dataAugmentation.py b/TranskribusDU/dataGenerator/dataAugmentation.py
index 5fa96ca..1aed639 100644
--- a/TranskribusDU/dataGenerator/dataAugmentation.py
+++ b/TranskribusDU/dataGenerator/dataAugmentation.py
@@ -1,7 +1,5 @@
 # -*- coding: utf-8 -*-
 """
-
-
     Data Augmentation: generate Layout annotated data

@@ -9,20 +7,6 @@
     copyright Naver Labs 2019
     READ project

-    This program is free software: you can redistribute it and/or modify
-    it under the terms of the GNU General Public License as published by
-    the Free Software Foundation, either version 3 of the License, or
-    (at your option) any later version.
-
-    This program is distributed in the hope that it will be useful,
-    but WITHOUT ANY WARRANTY; without even the implied warranty of
-    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-    GNU General Public License for more details.
-
-    You should have received a copy of the GNU General Public License
-    along with this program.  If not, see .
-
-
     Developed for the EU project READ. The READ project has received funding
     from the European Union's Horizon 2020 research and innovation programme
     under grant agreement No 674943.
diff --git a/TranskribusDU/dataGenerator/generator.py b/TranskribusDU/dataGenerator/generator.py
index 78a5291..db80268 100644
--- a/TranskribusDU/dataGenerator/generator.py
+++ b/TranskribusDU/dataGenerator/generator.py
@@ -11,39 +11,26 @@
     copyright Xerox 2017
     READ project

-    This program is free software: you can redistribute it and/or modify
-    it under the terms of the GNU General Public License as published by
-    the Free Software Foundation, either version 3 of the License, or
-    (at your option) any later version.
-
-    This program is distributed in the hope that it will be useful,
-    but WITHOUT ANY WARRANTY; without even the implied warranty of
-    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-    GNU General Public License for more details.
-
-    You should have received a copy of the GNU General Public License
-    along with this program.  If not, see .

     Developed for the EU project READ. The READ project has received funding
     from the European Union's Horizon 2020 research and innovation programme
     under grant agreement No 674943.

 """
-from __future__ import absolute_import
-from __future__ import print_function
-from __future__ import unicode_literals
+
+
+

 import random
 import numpy as np
-import json

 class Generator(object):
     ID=0
+#     genID=0

-    lClassesToBeLearnt=[]
-    def __init__(self,config=None):
-
-        self.config=config
+    def __init__(self):
+
         # structure of the object: list of Generators with alternatives (all possible structures)
         self._structure = None
@@ -64,10 +51,6 @@ def __init__(self):
         self.ID = Generator.ID
         Generator.ID+=1
-
-        self.noiseType = 0
-        self.noiseLevel = 0
-
         # nice label for ML (getName is too long)
         # default name
         self._label = self.getName()
@@ -86,31 +69,14 @@
     def __str__(self): return self.getName()
     def __repr__(self): return self.getName()

-    def setConfig(self,c): self.config = c
-    def getConfig(self): return self.config
-
-    def getMyConfig(self,label):
-        try: return self.getConfig()[label]
-        except KeyError: return None
-
+
     def getLabel(self): return self._label
     def setLabel(self,l): self._label= l

-    def setClassesToBeLearnt(self,l):
-        self.lClassesToBeLearnt = l
-        if self._structure is not None:
-            for x in [x[0] for se in self._structure for x in se[:-1]]:
-                x.setClassesToBeLearnt(l)
-
-
-    def setNoiseType(self,t): self.noiseType = t
-    def getNoiseType(self): return self.noiseType
-    def setNoiseLevel(self,t): self.noiseLevel = t
-    def getNoiseLevel(self): return self.noiseLevel

 #     def getName(self): return "%s_%d"%(self.__class__.__name__ ,self.ID)
     def getName(self): return self.__class__.__name__
     def getParent(self): return self._parent
     def setParent(self,p):self._parent = p
     # when generated by listGenerator
     def setNumber(self,n): self._number = n
@@ -160,27 +126,14 @@ def serialize(self):
         """
         raise Exception('must be instantiated')

-#     def exportAnnotatedData(self,foo=[]):
-#         """
-#         generate annotated data for self
-#         build a full version of self._generation: integration of the subparts (subobjects)
-#
-#         """
-#         raise Exception( 'must be instantiated',self)
-    def exportAnnotatedData(self,foo):
-        """
-        build a full version of generation: integration of the subparts (subtree)
-
-        what are the GT annotation for document?
-
+    def exportAnnotatedData(self,foo=[]):
         """
-        ## export (generated value, label) for terminal
+        generate annotated data for self
+        build a full version of self._generation: integration of the subparts (subobjects)

-        self._GT=[]
-        for obj in self._generation:
-            self._GT.append((obj.exportAnnotatedData([]),obj))
-
-        return self._GT
+        """
+        raise Exception( 'must be instantiated')

     def instantiate(self):
         """
@@ -196,6 +149,7 @@ def instantiate(self):
             if type(struct) in [ tuple,list] :
                 for obj, _,proba in struct:
                     if obj is not None:
+                        obj.setParent(self)
                         generateProb = 1.0 * random.uniform(1,100)
                         if generateProb < proba:
                             self._instance.append(obj.instantiate())
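[Editor's note] A Generator's _structure lists alternatives, each alternative being (sub-generator, number, probability) tuples closed by the weight of the alternative; instantiate() keeps a sub-generator only when a uniform draw in [1,100] falls under its probability. The sampling rule in isolation (illustrative helper, not part of the class):

import random

def sample_structure(alternative):
    """alternative: [(gen, n, proba), ..., weight] as used in _structure."""
    kept = []
    for gen, _n, proba in alternative[:-1]:
        # same test as instantiate() above
        if 1.0 * random.uniform(1, 100) < proba:
            kept.append(gen)
    return kept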
@@ -241,22 +195,6 @@ def noiseSplit(self):
         """
         raise Exception('must be instantiated')

-    def saveconfig(self,config,filename):
-        """
-            json dump of the config
-        """
-        try:
-            f = open(filename,"wb",encoding='utf-8')
-            json.dump(f,config,indent=True)
-        except IOError:print('not possible to open %s.'%(filename))
-
-    def loadconfig(self,filename):
-        try:
-            f = open(filename,"rb",encoding='utf-8')
-            return json.load(f)
-        except IOError:print('not possible to open %s.'%(filename))
-
-
 if __name__ == "__main__":

     g= Generator()
diff --git a/TranskribusDU/dataGenerator/layoutGenerator.py b/TranskribusDU/dataGenerator/layoutGenerator.py
index a8d8f32..7919489 100644
--- a/TranskribusDU/dataGenerator/layoutGenerator.py
+++ b/TranskribusDU/dataGenerator/layoutGenerator.py
@@ -11,39 +11,28 @@
     copyright Naver labs Europe 2017
     READ project

-    This program is free software: you can redistribute it and/or modify
-    it under the terms of the GNU General Public License as published by
-    the Free Software Foundation, either version 3 of the License, or
-    (at your option) any later version.
-
-    This program is distributed in the hope that it will be useful,
-    but WITHOUT ANY WARRANTY; without even the implied warranty of
-    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-    GNU General Public License for more details.
-
-    You should have received a copy of the GNU General Public License
-    along with this program.  If not, see .

     Developed for the EU project READ. The READ project has received funding
     from the European Union's Horizon 2020 research and innovation programme
     under grant agreement No 674943.
""" -from __future__ import absolute_import -from __future__ import print_function -from __future__ import unicode_literals + + + from lxml import etree try:basestring except NameError:basestring = str -from dataGenerator.generator import Generator -from dataGenerator.numericalGenerator import numericalGenerator +from .generator import Generator +from .numericalGenerator import numericalGenerator class layoutZoneGenerator(Generator): - def __init__(self,config,x=None,y=None,h=None,w=None): - Generator.__init__(self,config) + def __init__(self,x=None,y=None,h=None,w=None): + Generator.__init__(self) if x is None: self._x = numericalGenerator(None,None) self._x.setLabel("x") @@ -84,10 +73,10 @@ def setHeight(self,v): self._h.setUple(v) def setWidth(self,v): self._w.setUple(v) def setPositionalGenerators(self,x,y,h,w): - if x is not None:self.setX(x) - if y is not None:self.setY(y) - if h is not None:self.setHeight(h) - if w is not None:self.setWidth(w) + self.setX(x) + self.setY(y) + self.setHeight(h) + self.setWidth(w) def setPage(self,p): self._page = p def getPage(self):return self._page @@ -170,7 +159,7 @@ def serialize(self): if __name__ == '__main__': TH=30 - myZone = layoutZoneGenerator({},numericalGenerator(5,TH),numericalGenerator(30,TH),numericalGenerator(20,TH),numericalGenerator(100,TH)) + myZone= layoutZoneGenerator(numericalGenerator(5,TH),numericalGenerator(30,TH),numericalGenerator(20,TH),numericalGenerator(100,TH)) myZone.instantiate() myZone.generate() print(myZone._generation) diff --git a/TranskribusDU/dataGenerator/layoutObjectGenerator.py b/TranskribusDU/dataGenerator/layoutObjectGenerator.py index b033f96..5c85bf1 100644 --- a/TranskribusDU/dataGenerator/layoutObjectGenerator.py +++ b/TranskribusDU/dataGenerator/layoutObjectGenerator.py @@ -2,25 +2,14 @@ """ - Samples of layout generators + layoutGenerator.py generate Layout annotated data copyright Naver Labs 2017 READ project - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . Developed for the EU project READ. The READ project has received funding @@ -29,9 +18,9 @@ @author: H. Déjean """ -from __future__ import absolute_import -from __future__ import print_function -from __future__ import unicode_literals + + + try:basestring @@ -44,11 +33,10 @@ from dataGenerator.numericalGenerator import numericalGenerator from dataGenerator.numericalGenerator import integerGenerator -from dataGenerator.numericalGenerator import positiveIntegerGenerator from dataGenerator.generator import Generator from dataGenerator.layoutGenerator import layoutZoneGenerator from dataGenerator.listGenerator import listGenerator -from dataGenerator.typoGenerator import horizontalTypoGenerator,verticalTypoGenerator +# from booleanGenerator import booleanGenerator class doublePageGenerator(layoutZoneGenerator): """ @@ -66,26 +54,14 @@ class doublePageGenerator(layoutZoneGenerator): for page break: need to work with content generator? 
""" - def __init__(self,config): - """ - "page":{ - "scanning": None, - "pageH": (780, 50), - "pageW": (1000, 50), - "nbPages": (nbpages,0), - "lmargin": tlMarginGen, - "rmargin": trMarginGen, - 'pnum' :True, - "pnumZone": 0, - "grid" : tGrid - """ - layoutZoneGenerator.__init__(self,config) - self.leftPage = pageGenerator(config) - self.leftPage.setLeftOrRight(1) - self.leftPage.setParent(self) - self.rightPage = pageGenerator(config) - self.rightPage.setLeftOrRight(2) - self.rightPage.setParent(self) + def __init__(self,h,w,m,r,config=None): + layoutZoneGenerator.__init__(self) + + ml,mr= m + self.leftPage = pageGenerator(h,w,ml,r,config) + self.leftPage.leftOrRight = 1 #left + self.rightPage = pageGenerator(h,w,mr,r,config) + self.rightPage.leftOrRight = 2 #right self._structure = [ ((self.leftPage,1,100),(self.rightPage,1,100),100) @@ -97,13 +73,9 @@ class pageGenerator(layoutZoneGenerator): need to add background zone """ ID=1 - def __init__(self,config): - layoutZoneGenerator.__init__(self,config) + def __init__(self,h,w,m,r,dConfig): + layoutZoneGenerator.__init__(self) self._label='PAGE' - h=config['page']['pageH'] - w=config['page']['pageW'] - r= config['page']["grid"] - hm,hsd= h self.pageHeight = integerGenerator(hm,hsd) self.pageHeight.setLabel('height') @@ -111,6 +83,7 @@ def __init__(self,config): self.pageWidth = integerGenerator(wm,wsd) self.pageWidth.setLabel('width') + self.myConfig=dConfig ##background ##also need X0 and y0 @@ -123,7 +96,7 @@ def __init__(self,config): self.nbcolumns = integerGenerator(cm, cs) self.nbcolumns.setLabel('nbCol') self.gutter = integerGenerator(gm,gs) - self.ColumnsListGen = listGenerator(config,columnGenerator, self.nbcolumns) + self.ColumnsListGen = listGenerator(columnGenerator, self.nbcolumns ,None) self.ColumnsListGen.setLabel("GRIDCOL") # required at line level! @@ -133,12 +106,11 @@ def __init__(self,config): self.leftOrRight = None # WHITE SPACES self.pageNumber = None # should come from documentGen.listGen(page)? - if self.getConfig()['page']['pnum']: - self.pageNumber = pageNumberGenerator(config) - - self._margin = marginGenerator(config) - - + if self.myConfig['pnum']: + self.pageNumber = pageNumberGenerator() + self._margin = marginGenerator(*m) +# self._ruling = gridGenerator(*r) + # need to build margin zones! 
(for text in margin) # should be replaced by a layoutZoneGenerator self._typeArea_x1 = None @@ -169,8 +141,6 @@ def __init__(self,config): mystruct ] - - def setLeftOrRight(self,n): self.leftOrRight = n def getLeftMargin(self): return self._marginRegions[2] def getRightMargin(self):return self._marginRegions[3] @@ -206,7 +176,7 @@ def computeAllValues(self,H,W,t,b,l,r): def addPageNumber(self,p): """ """ - zoneIndx = self.getConfig()["page"]['pnumZone'] + zoneIndx = self.myConfig['pnumZone'] region = self._marginRegions[zoneIndx] # in the middle of the zone @@ -216,8 +186,7 @@ def generate(self): """ bypass layoutZoneGen: specific to page """ - self.setConfig(self.getParent().getConfig()) - +# self.setNumber(self.getParent().getNumber()) self.setNumber(1) self._generation = [] for obj in self._instance[:2]: @@ -262,12 +231,8 @@ def generate(self): colGen.setPositionalGenerators((colx,5),(coly,5),(colH,5),(colW,5)) colGen.setGrid(self) colGen.setPage(self) - if self.getConfig()['colStruct'][0] == listGenerator: - content=listGenerator(self.getConfig(), self.getConfig()['colStruct'][1],integerGenerator(*self.getConfig()['colStruct'][2])) - else: - content=self.getConfig()['colStruct'][0](self.getConfig()) -# try:content=self.getConfig()['colStruct'][0](self.getConfig()) -# except KeyError as e: content=None + try:content=self.myConfig['colStruct'][0](*self.myConfig['colStruct'][-1]) + except KeyError as e: content=None if content is not None: colGen.updateStructure((content,1,100)) colGen.instantiate() @@ -320,9 +285,22 @@ class columnGenerator(layoutZoneGenerator): see CSS Box Model: margin,border, padding """ - def __init__(self,config,x=None,y=None,h=None,w=None): - layoutZoneGenerator.__init__(self,config,x=x,y=y,h=h,w=w) + def __init__(self,x=None,y=None,h=None,w=None): + layoutZoneGenerator.__init__(self,x=x,y=y,h=h,w=w) self.setLabel("COLUMN") + +# # lines or linegrid ?? +# self.nbLines = integerGenerator(40, 5) +# self.nbLines.setLabel('nbLines') +# self.LineListGen = listGenerator(LineGenerator, self.nbLines ,None) +# self.LineListGen.setLabel("colLine") +# self._mygrid = None +# self.leading= 12 + + + # table (+ caption) + #self.fullPageTable = tableGenerator(nbCols,nbRows) + # other elements? 
image+ caption self._structure = [ # [(self.getX(),1,100),(self.getY(),1,100),(self.getHeight(),1,100),(self.getWidth(),1,100),(self.LineListGen,1,100),100] @@ -356,15 +334,13 @@ def generate(self): self._generation.append(colContent) elif isinstance(colContent,listGenerator): - self.leading = integerGenerator(*self.getConfig()['line']['leading']) - self.leading.generate() - self.leading.setLabel('leading') + for i,lineGen in enumerate(colContent._instance): # too many lines - if (i * self.leading._generation) + self.getY()._generation > (self.getY()._generation + self.getHeight()._generation): + if (i * self.leading) + self.getY()._generation > (self.getY()._generation + self.getHeight()._generation): continue linex =self.getX()._generation - liney = (i * self.leading._generation) + self.getY()._generation + liney = (i * self.leading) + self.getY()._generation lineH = 10 lineW = self.getWidth()._generation lineGen.setParent(self) @@ -380,15 +356,15 @@ class pageNumberGenerator(layoutZoneGenerator): """ a pagenumgen """ - def __init__(self,config,x=None,y=None,h=None,w=None): - layoutZoneGenerator.__init__(self,config,x=x,y=y,h=h,w=w) + def __init__(self,x=None,y=None,h=None,w=None): + layoutZoneGenerator.__init__(self,x=x,y=y,h=h,w=w) self._label='LINE' def XMLDSFormatAnnotatedData(self, linfo, obj): self.domNode = etree.Element(obj.getLabel()) self.domNode.set('pagenumber','yes') - self.domNode.set('DU_row','O') + self.domNode.set('type','RO') for info,tag in linfo: if isinstance(tag,Generator): node=tag.XMLDSFormatAnnotatedData(info,tag) @@ -405,13 +381,8 @@ class marginGenerator(Generator): restricted to 1?2-column grid max for the moment? """ - def __init__(self,config): - Generator.__init__(self,config) - - top = config['page']["margin"][0][0] - bottom = config['page']["margin"][0][1] - left = config['page']["margin"][0][2] - right = config['page']["margin"][0][3] + def __init__(self,top,bottom,left, right): + Generator.__init__(self) m,sd = top self._top= integerGenerator(m,sd) self._top.setLabel('top') @@ -428,13 +399,13 @@ def __init__(self,config): self._label='margin' - self.leftMarginGen=layoutZoneGenerator(config) + self.leftMarginGen=layoutZoneGenerator() self.leftMarginGen.setLabel('leftMargin') - self.rightMarginGen=layoutZoneGenerator(config) + self.rightMarginGen=layoutZoneGenerator() self.rightMarginGen.setLabel('rightMargin') - self.topMarginGen=layoutZoneGenerator(config) + self.topMarginGen=layoutZoneGenerator() self.topMarginGen.setLabel('topMargin') - self.bottomMarginGen=layoutZoneGenerator(config) + self.bottomMarginGen=layoutZoneGenerator() self.bottomMarginGen.setLabel('bottomMargin') @@ -486,25 +457,12 @@ def XMLDSFormatAnnotatedData(self,linfo,obj): -class catchword(layoutZoneGenerator): - """ - catchword: always bottom right? - """ - def __init__(self,config,x=None,y=None,h=None,w=None): - layoutZoneGenerator.__init__(self,config,x=x,y=y,h=h,w=w) - self.setLabel("CATCHWORD") - - self._structure = [ - ((self.getX(),1,100),(self.getY(),1,100),(self.getHeight(),1,100),(self.getWidth(),1,100),100) - ] - - class marginaliaGenerator(layoutZoneGenerator): """ marginalia Gen: assume relation with 'body' part """ - def __init__(self,config,x=None,y=None,h=None,w=None): - layoutZoneGenerator.__init__(self,config,x=x,y=y,h=h,w=w) + def __init__(self,x=None,y=None,h=None,w=None): + layoutZoneGenerator.__init__(self,x=x,y=y,h=h,w=w) self.setLabel("MARGINALIA") #pointer to the parent structures!! line? page,? 
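[Editor's note] columnGenerator.generate above places line i at y + i*leading and skips any line that would overflow the column. The placement rule on its own (illustrative values; the real code also sets width, parent and page):

def line_positions(y, height, leading, nb_lines):
    """Yield the y coordinate of each line that still fits in the column."""
    for i in range(nb_lines):
        liney = y + i * leading
        if liney > y + height:   # too many lines: skip, as generate() does
            continue
        yield liney

# e.g. list(line_positions(y=100, height=50, leading=12, nb_lines=10))
#      -> [100, 112, 124, 136, 148]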
@@ -526,43 +484,30 @@ class LineGenerator(layoutZoneGenerator):

         if parent =...
     """
-    def __init__(self,config,x=None,y=None,h=None,w=None):
-        layoutZoneGenerator.__init__(self,config,x=x,y=y,h=h,w=w)
+    def __init__(self,x=None,y=None,h=None,w=None):
+        layoutZoneGenerator.__init__(self,x=x,y=y,h=h,w=w)
         self.setLabel("LINE")
-        self._noteGen = None
-        self._noteGenProb = None
-
-        self.BIES = 'O'
-        if "marginalia" in self.getConfig()["line"]:
-            self._noteGen = self.getConfig()["line"]["marginalia"][0](self.getConfig())
-            self._noteGenProba= self.getConfig()["line"]["marginalia"][1]
-
+        self.BIES = 'RO'
+        self._noteGen = marginaliaGenerator()
         self._justifixationGen = None  #justificationGenerator()  # center, left, right, just, random
-        self.bSkew = None  # (angle,std)
+        self.bSkew=None  # (angle,std)

         self._structure = [
+#             ((self.getX(),1,100),(self.getY(),1,100),(self.getHeight(),1,100),(self.getWidth(),1,100),(self._noteGen,1,010),100)
             ((self.getX(),1,100),(self.getY(),1,100),(self.getHeight(),1,100),(self.getWidth(),1,100),100)
-            ]
-        if self._noteGen is not None:
-            self._structure = [
-                ((self.getX(),1,100),(self.getY(),1,100),(self.getHeight(),1,100),(self.getWidth(),1,100),(self._noteGen,1,self._noteGenProba),100)
-            ]
-
+
+            ]

     def setPage(self,p):self._page=p
     def getPage(self): return self._page

     def computeBIES(self,pos,nbLines):
-        """
-            new annotation DU_row
-        """
-
-        if nbLines == 1      : self.BIES='S'
-        elif pos == 0        : self.BIES='B'
-        elif pos == nbLines-1: self.BIES='E'
-        else                 : self.BIES='I'
+        if nbLines == 1      : self.BIES='RS'
+        elif pos == 0        : self.BIES='RB'
+        elif pos == nbLines-1: self.BIES='RE'
+        else                 : self.BIES='RI'

     def generate(self):
         """
@@ -598,16 +543,13 @@ def generate(self):
             # compute position according to the justification : need parent,
             self._noteGen.setPositionalGenerators((marginaliax,5),(marginaliay,5),(marginaliaH,5),(marginaliaW,5))
             self._noteGen.generate()
-            self._generation.append(self._noteGen)
-
+            self._generation.append(self._noteGen)
         return self

     def XMLDSFormatAnnotatedData(self,linfo,obj):
         self.domNode = etree.Element(obj.getLabel())

         # for listed elements
-#         self.domNode.set('type',str(self.BIES))
-        self.domNode.set('DU_row',str(self.BIES))
-        # need DU_col, DU_header
+        self.domNode.set('type',str(self.BIES))

         for info,tag in linfo:
             if isinstance(tag,Generator):
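[Editor's note] computeBIES above tags each line of a block with an 'R'-prefixed BIES label. The mapping as a pure function, with a worked example:

def bies_label(pos, nb_lines):
    """Return the R-BIES label of line `pos` among `nb_lines` lines."""
    if nb_lines == 1:
        return 'RS'          # a single line is its own segment
    if pos == 0:
        return 'RB'          # first line begins the segment
    if pos == nb_lines - 1:
        return 'RE'          # last line ends it
    return 'RI'              # anything else is inside

# e.g. [bies_label(i, 4) for i in range(4)] -> ['RB', 'RI', 'RI', 'RE']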
+ linex = self.getX()._generation + (xstart) liney = nexty - lineW=integerGenerator(self.getWidth()._generation*0.75,self.getWidth()._generation*0.1) - lineW.generate() - - if self.vjustification == verticalTypoGenerator.TYPO_LEFT: - linex = self.getX()._generation + (xstart) - if self.vjustification == verticalTypoGenerator.TYPO_RIGHT: - linex = self.getX()._generation + self.getWidth()._generation - lineW._generation - elif self.vjustification == verticalTypoGenerator.TYPO_VCENTER: - linex = self.getX()._generation + self.getWidth()._generation * 0.5 - lineW._generation *0.5 - lineGen.setPositionalGenerators((linex,1),(liney,1),(lineH._generation,0.5),(lineW._generation,0)) + lineW = self.getWidth()._generation + lineGen.setPositionalGenerators((linex,5),(liney,5),(lineH,5),(lineW * 0.5,lineW * 0.1)) # lineGen.setPositionalGenerators((linex,0),(liney,0),(lineH,0),(lineW * 0.5,lineW * 0.1)) lineGen.setPage(self.getPage()) lineGen.setParent(self) lLines.append(lineGen) lineGen.generate() rowPaddingGen.generate() - nexty= lineGen.getY()._generation +self.leading._generation + lineGen.getHeight()._generation+ rowPaddingGen._generation + nexty= lineGen.getY()._generation + lineGen.getHeight()._generation+ rowPaddingGen._generation lineGen.setLabel('LINE') self._generation.append(lineGen) @@ -747,29 +672,18 @@ class tableGenerator(layoutZoneGenerator): or rows/column height/width (or constraint = allthesamevalue) """ - def __init__(self,config): - layoutZoneGenerator.__init__(self,config) - + def __init__(self,nbCols,nbRows): + layoutZoneGenerator.__init__(self) self.setLabel('TABLE') - - nbRows=config['table']['nbRows'] - self.rowHeightVariation = config['table']['rowHeightVariation'] - self.rowHStd=self.rowHeightVariation[1] - self.columnWidthVariation = config['table']['columnWidthVariation'] - - if 'widths' in self.getConfig()['table']['column']: - self.nbCols = integerGenerator(len(self.getConfig()['table']['column']['widths']),0) - else: - nbCols=config['table']['nbCols'] - self.nbCols = integerGenerator(nbCols[0],nbCols[1]) + self.nbCols = integerGenerator(nbCols[0],nbCols[1]) self.nbCols.setLabel('nbCols') self.nbRows = integerGenerator(nbRows[0],nbRows[1]) self.nbRows.setLabel('nbRows') - self._bSameRowHeight=config['table']['row']['sameRowHeight'] - self._lRowsGen = listGenerator(config,layoutZoneGenerator, self.nbRows) + self._bSameRowHeight=True + self._lRowsGen = listGenerator(layoutZoneGenerator, self.nbRows ,None) self._lRowsGen.setLabel("row") - self._lColumnsGen = listGenerator(config['table']['column'],layoutZoneGenerator, self.nbCols ) + self._lColumnsGen = listGenerator(layoutZoneGenerator, self.nbCols ,None) self._lColumnsGen.setLabel("col") self._structure = [ @@ -797,28 +711,23 @@ def generate(self): self._columnWidthG = numericalGenerator(self._columnWidthM, self._columnWidthM*0.2) self._rowHeightM = int(round(self.getHeight()._generation / nbRows)) - self._rowHeightG = positiveIntegerGenerator(self._rowHeightM,self.rowHStd) + self._rowHeightG = numericalGenerator(self._rowHeightM,self._rowHeightM*0.25) -# self._rowHeightM = int(round(self.getHeight()._generation / nbRows)) -# self._rowHeightG = numericalGenerator(self._rowHeightM,self._rowHeightM*0.5) self.lCols=[] self.lRows=[] nextx= self.getX()._generation - for i,colGen in enumerate(self._lColumnsGen._instance): if nextx > self.getX()._generation + self.getWidth()._generation: continue + self._columnWidthG.generate() + colx = nextx #self.getX()._generation + ( i * self._columnWidth) coly = self.getY()._generation - 
colH = self.getHeight()._generation
-            if 'widths' in self.getConfig()['table']['column']:
-                colW = self.getConfig()['table']['column']['widths'][i] * self.getWidth()._generation
-            else:
-                self._columnWidthG.generate()
-                colW = self._columnWidthG._generation
+            colH = self.getHeight()._generation
+            colW = self._columnWidthG._generation
             colGen.setNumber(i)
-            colGen.setPositionalGenerators((colx,0),(coly,0),(colH,0),(colW,0))
+            colGen.setPositionalGenerators((colx,5),(coly,5),(colH,5),(colW,5))
 #             colGen.setGrid(self)
             colGen.setLabel("COL")
             colGen.setPage(self.getPage())
@@ -827,17 +736,8 @@ def generate(self):
             self._generation.append(colGen)
             self.lCols.append(colGen)

-        ## ROW
-        # max nblines
-        if 'nbLines' in self.getConfig()['table']['column']:
-            nbMaxLines = max(x[0] for x in self.getConfig()['table']['column']['nbLines'])
-            lineH=self.getConfig()['line']['lineHeight']
-            lineHG=positiveIntegerGenerator(*lineH)
-            lineHG.generate()
-            nblineG=positiveIntegerGenerator(nbMaxLines,0)
-            nblineG.generate()
-            self._rowHeightG = positiveIntegerGenerator(nblineG._generation*lineHG._generation,self.rowHStd)
-        else: nbMaxLines=None
+        ## here
+        rowH = None

         nexty = self.getY()._generation
         for i,rowGen in enumerate(self._lRowsGen._instance):
@@ -852,21 +752,19 @@ def generate(self):
             else:
                 self._rowHeightG.generate()
                 rowH = self._rowHeightG._generation
-#             print (rowH)
             rowy = nexty
-            # here test that that there is enough space for the row!!
+            # here, test that there is enough space for the row!!
 #             print self._rowHeightM, self._rowHeightG._generation
             rowW = self.getWidth()._generation
             rowGen.setLabel("ROW")
             rowGen.setNumber(i)
             rowGen.setPage(self.getPage())
-            rowGen.setPositionalGenerators((rowx,0),(rowy,0),(rowH,0),(rowW,0))
+            rowGen.setPositionalGenerators((rowx,1),(rowy,1),(rowH,1),(rowW,1))
             rowGen.generate()
             nexty = rowGen.getY()._generation + rowGen.getHeight()._generation
 #             print i, rowy, self.getHeight()._generation
             self.lRows.append(rowGen)
-            self._generation.append(rowGen)
-#             print("%d %s %f"%(i,self._bSameRowHeight,rowGen.getHeight()._generation))
+            self._generation.append(rowGen)

         ## table specific stuff
         ## table headers, stub,....
@@ -880,34 +778,18 @@ def generate(self):
         ## creation of the cells; then content in the cells
         self.lCellGen=[]
         for icol,col in enumerate(self.lCols):
-            if 'nbLines' in self.getConfig()['table']['column']:
-                nblines=self.getConfig()['table']['column']['nbLines'][icol]
-                nbLineG = positiveIntegerGenerator(*nblines)
-            else: nbLineG=None
             for irow, row in enumerate(self.lRows):
-                cell=cellGenerator(self.getConfig())
+                cell=cellGenerator()
                 cell.setLabel("CELL")
                 cell.setPositionalGenerators((col.getX()._generation,0),(row.getY()._generation,0),(row.getHeight()._generation,0),(col.getWidth()._generation,0))
-                # colunm header? {'column':{'header':{'colnumber':1,'justification':'centered'}}
-
-                if irow < self.getConfig()['table']['column']['header']['colnumber'] :
-                    cell.getConfig()['vjustification']= self.getConfig()['table']['column']['header']['vjustification']
-#                     print(icol,cell.getConfig()['justification'])
-                else:cell.getConfig()['vjustification'] = self.getConfig()['line']['vjustification']
-                cell.getConfig()['hjustification'] = self.getConfig()['line']['hjustification']
-                # row header?
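# --- Editor's sketch: the row/column sizing above relies on the (mean, std)
# convention used throughout dataGenerator: numericalGenerator(mean, sd) draws a
# value around the mean. Gaussian sampling is an assumption here, consistent
# with the random.gauss calls visible elsewhere in this diff:
import random

def sample_dim(mean, std):
    return random.gauss(mean, std)

random.seed(0)
table_height, nb_rows = 600.0, 30
row_height_mean = table_height / nb_rows
# mirrors self._rowHeightG = numericalGenerator(mean, mean * 0.25) above
print(sample_dim(row_height_mean, row_height_mean * 0.25))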
self.lCellGen.append(cell) - cell.setNbLinesGenerator(nbLineG) - cell.setIndex(irow,icol) cell.instantiate() cell.setPage(self.getPage()) cell.generate() + cell.setIndex(irow,icol) self._generation.append(cell) - - - - + class documentGenerator(Generator): """ a document generator @@ -947,17 +829,12 @@ class documentGenerator(Generator): levels between document and page/double-page: usefull? """ - def __init__(self,dConfig): - -# tpageH = dConfig["page"]['pageH'] -# tpageW = dConfig["page"]['pageW'] - tnbpages = dConfig["page"]['nbPages'] -# tMargin = (dConfig["page"]['lmargin'],dConfig["page"]['rmargin']) -# tRuling = dConfig["page"]['grid'] + def __init__(self,dConfig,tpageH,tpageW,tnbpages,tMargin=None,tRuling=None): Generator.__init__(self) self._name = 'DOC' + self.myConfig = dConfig # missing elements: self._isCropped = False # cropped pages self._hasBackground = False # background is visible @@ -976,11 +853,11 @@ def __init__(self,dConfig): self._nbpages = integerGenerator(tnbpages[0],tnbpages[1]) self._nbpages.setLabel('nbpages') -# self._margin = tMargin -# self._ruling = tRuling + self._margin = tMargin + self._ruling = tRuling - self.pageListGen = listGenerator(dConfig,pageGenerator,self._nbpages) + self.pageListGen = listGenerator(pageGenerator,self._nbpages,tpageH,tpageW,self._margin,self._ruling) self.pageListGen.setLabel('pages') self._structure = [ #firstSofcover (first and second) @@ -1013,7 +890,7 @@ def XMLDSFormatAnnotatedData(self,gtdata): # self.docDom.setRootElement(root) metadata= etree.Element("METADATA") root.append(metadata) - metadata.text = str(self.getConfig()) + metadata.text = str(self.myConfig) for info,page in gtdata: pageNode = page.XMLDSFormatAnnotatedData(info,page) root.append(pageNode) @@ -1027,7 +904,7 @@ def generate(self): for i,pageGen in enumerate(self.pageListGen._instance): #if double page: start with 1 = right? 
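# --- Editor's sketch: the cell-creation loop above positions one CELL per
# (column, row) pair from the already-generated column x/width and row y/height.
# A self-contained restatement with plain tuples (cellGenerator itself belongs
# to the framework and is not redefined here):
def build_cell_boxes(cols, rows):
    """cols: list of (x, w); rows: list of (y, h); returns one box per cell."""
    return [
        {"row": irow, "col": icol, "x": x, "y": y, "w": w, "h": h}
        for icol, (x, w) in enumerate(cols)
        for irow, (y, h) in enumerate(rows)
    ]

# 2 columns x 2 rows -> 4 cells
assert len(build_cell_boxes([(0, 100), (100, 100)], [(0, 20), (20, 20)])) == 4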
- pageGen.setConfig(self.getConfig()) + pageGen.myConfig= self.myConfig pageGen.generate() self._generation.append(pageGen) @@ -1052,24 +929,21 @@ def exportAnnotatedData(self,foo): class DocMirroredPages(documentGenerator): # def __init__(self,tpageH,tpageW,tnbpages,tMargin=None,tRuling=None): def __init__(self,dConfig): - + # scanning = dConfig['scanning'] -# tpageH = dConfig["page"]['pageH'] -# tpageW = dConfig["page"]['pageW'] - tnbpages = dConfig["page"]['nbPages'] - self._nbpages = integerGenerator(tnbpages[0],tnbpages[1]) - self._nbpages.setLabel('nbpages') -# tMargin = (dConfig["page"]['margin'],dConf) -# tRuling = dConfig["page"]['grid'] + tpageH = dConfig['pageH'] + tpageW = dConfig['pageW'] + tnbpages = dConfig['nbPages'] + tMargin = (dConfig['lmargin'],dConfig['rmargin']) + tRuling = dConfig['grid'] -# documentGenerator.__init__(self,tpageH,tpageW,tnbpages,tMargin,tRuling) - documentGenerator.__init__(self,dConfig) - self.setConfig(dConfig) + documentGenerator.__init__(self,dConfig,tpageH,tpageW,tnbpages,tMargin,tRuling) + self.myConfig = dConfig + self._lmargin, self._rmargin = tMargin + self._ruling= tRuling + self.pageListGen = listGenerator(doublePageGenerator,self._nbpages,tpageH,tpageW,(self._lmargin,self._rmargin),self._ruling,dConfig) -# self._lmargin, self._rmargin = tMargin -# self._ruling= tRuling - self.pageListGen = listGenerator(dConfig,doublePageGenerator,self._nbpages) self.pageListGen.setLabel('pages') self._structure = [ #firstSofcover (first and second) @@ -1138,6 +1012,14 @@ def docm(): tGrid = ( 'regular',(2,0),(0,0) ) + #self.nbLines = integerGenerator(40, 5) +# self.nbLines.setLabel('nbLines') +# self.LineListGen = listGenerator(LineGenerator, self.nbLines ,None) +# self.LineListGen.setLabel("colLine") +# self._mygrid = None +# self.leading= 12 + + Config = { "scanning": pageScanning, "pageH": (700, 10), @@ -1160,256 +1042,39 @@ def docm(): # print etree.tostring(docDom,encoding="utf-8", pretty_print=True) docDom.write("tmp.ds_xml",encoding='utf-8',pretty_print=True) -def StAZHDataset(nbpages): +def tableDataset(nbpages): """ - page header (centered) - page number (mirrored: yes and no) - catch word (bottom right) - marginalia (left margin; mirrored also?) - - """ - tlMarginGen = ((50, 5),(50, 5),(50, 5),(50, 5)) - trMarginGen = ((50, 5),(50, 5),(50, 5),(50, 5)) - - tGrid = ( 'regular',(1,0),(0,0) ) - - Config = { - "page":{ - "scanning": None - ,"pageH": (780, 50) - ,"pageW": (500, 50) - ,"nbPages": (nbpages,0) - ,"margin": [tlMarginGen, trMarginGen] - ,'pnum' :{'position':"left"} - ,"pnumZone": 0 - ,"grid" : tGrid - } - #column? - ,"line":{ - "leading": (15+5,1) - ,"lineHeight": (15,1) - ,"justification":'left' - ,'marginalia':[marginaliaGenerator,10] - ,'marginalialineHeight':10 - } - - ,"colStruct": (listGenerator,LineGenerator,(20,0)) -# ,'table':{ -# "nbRows": (40,0) -# ,"nbCols": (5,0) -# ,"rowHeightVariation":(0,0) -# ,"columnWidthVariation":(0,0) -# ,'column':{'header':{'colnumber':1,'justification':'centered'}} -# ,'row':{"sameRowHeight": True } -# ,'cell':{'justification':'right','line':{"leading":(14,0)}} -# } - } - mydoc = DocMirroredPages(Config) - mydoc.instantiate() - mydoc.generate() - gt = mydoc.exportAnnotatedData(()) -# print gt - docDom = mydoc.XMLDSFormatAnnotatedData(gt) - return docDom - - -def ABPRegisterDataset(nbpages): - """ - ABP register - + keep all parameters in the synthetic object!! 
""" - tlMarginGen = ((50, 5),(50, 5),(50, 5),(50, 5)) - trMarginGen = ((50, 5),(50, 5),(50, 5),(50, 5)) + tlMarginGen = ((50, 5),(50, 5),(50, 10),(50, 10)) + trMarginGen = ((50, 5),(50, 5),(50, 10),(50, 10)) tGrid = ( 'regular',(1,0),(0,0) ) - # should be replaced by an object? - ABPREGConfig = { - "page":{ - "scanning": None - ,"pageH": (780, 50) - ,"pageW": (500, 50) - ,"nbPages": (nbpages,0) - ,"margin": [tlMarginGen, trMarginGen] - ,'pnum' :{'position':"left"} # also ramdom? - ,"pnumZone": 0 - ,"grid" : tGrid - } - #column? - ,"line":{ - "leading": (5,4) - ,"lineHeight": (18,2) - ,"justification":'left' - } - - ,"colStruct": (tableGenerator,1,nbpages) - ,'table':{ - "nbRows": (30,2) - ,"nbCols": (5,1) - ,"rowHeightVariation":(0,0) - ,"columnWidthVariation":(0,0) - ,'column':{'header':{'colnumber':1,'justification':'centered'}} - ,'row':{"sameRowHeight": True } - ,'cell':{'justification':'right','line':{"leading":(14,0)}} - } - } - - Config=ABPREGConfig - mydoc = DocMirroredPages(Config) - mydoc.instantiate() - mydoc.generate() - gt = mydoc.exportAnnotatedData(()) -# print gt - docDom = mydoc.XMLDSFormatAnnotatedData(gt) - return docDom - -def NAFDataset(nbpages): - tlMarginGen = ((50, 5),(50, 5),(50, 5),(50, 5)) - trMarginGen = ((50, 5),(50, 5),(50, 5),(50, 5)) - - tGrid = ( 'regular',(1,0),(0,0) ) - #for NAF!: how to get the column width??? - NAFConfig = { - "page":{ - "scanning": None, - "pageH": (780, 50), - "pageW": (500, 50), - "nbPages": (nbpages,0), - "margin": [tlMarginGen, trMarginGen], - 'pnum' :True, - "pnumZone": 0, - "grid" : tGrid - } - #column? - ,"line":{ - "leading": (5,4) - ,"lineHeight": (10,1) - ,"justification":'left' - } - - ,"colStruct": (tableGenerator,1,nbpages) - ,'table':{ - "nbRows": (35,10) - ,"nbCols": (5,0) - ,"rowHeightVariation":(20,5) - ,"columnWidthVariation":(0,0) - # proportion of col width known - ,'column':{'header':{'colnumber':1,'justification':'centered'} - ,'widths':(0.01,0.05,0.05,0.5,0.2,0.05,0.05,0.05,0.05,0.05,0.05) - #nb textlines - ,'nbLines':((1,0.1),(1,0.1),(1,0.1),(4,1),(3,1),(1,1),(1,0.5),(1,1),(1,0.5),(1,0.5),(1,0.5)) - - } - ,'row':{"sameRowHeight": False } - ,'cell':{'justification':'right','line':{"leading":(14,0)}} - } - } - Config=NAFConfig - mydoc = DocMirroredPages(Config) - mydoc.instantiate() - mydoc.generate() - gt = mydoc.exportAnnotatedData(()) -# print gt - docDom = mydoc.XMLDSFormatAnnotatedData(gt) - return docDom - - -def NAHDataset(nbpages): - """ - @todo: need to put H centered lines - """ - - tlMarginGen = ((50, 5),(50, 5),(50, 5),(50, 5)) - trMarginGen = ((50, 5),(50, 5),(50, 5),(50, 5)) - - tGrid = ( 'regular',(1,0),(0,0) ) - #for NAF!: how to get the column width??? - NAFConfig = { - "page":{ - "scanning": None, - "pageH": (780, 50), - "pageW": (500, 50), - "nbPages": (nbpages,0), - "margin": [tlMarginGen, trMarginGen], - 'pnum' :True, - "pnumZone": 0, - "grid" : tGrid - } - #column? 
- ,"line":{ - "leading": (5,4) - ,"lineHeight": (10,1) - ,"vjustification":verticalTypoGenerator([0.5,0.25,0.25]) - # 0: top - ,'hjustification':horizontalTypoGenerator([0.33,0.33,0.33]) - } - - ,"colStruct": (tableGenerator,1,nbpages) - ,'table':{ - "nbRows": (35,10) - ,"nbCols": (5,0) - ,"rowHeightVariation":(20,5) - ,"columnWidthVariation":(0,0) - # proportion of col width known - ,'column':{'header':{'colnumber':1,'vjustification':verticalTypoGenerator([0,1,0])} - ,'widths':(0.01,0.05,0.05,0.5,0.2,0.05,0.05,0.05,0.05,0.05,0.05) - #nb textlines - ,'nbLines':((1,0.1),(1,0.1),(1,0.1),(4,1),(3,1),(1,1),(1,0.5),(1,1),(1,0.5),(1,0.5),(1,0.5)) - - } - ,'row':{"sameRowHeight": False } - ,'cell':{'hjustification':horizontalTypoGenerator([0.75,0.25,0.0]),'vjustification':verticalTypoGenerator([0,0,1]),'line':{"leading":(14,0)}} - } - } - Config=NAFConfig - mydoc = DocMirroredPages(Config) - mydoc.instantiate() - mydoc.generate() - gt = mydoc.exportAnnotatedData(()) -# print gt - docDom = mydoc.XMLDSFormatAnnotatedData(gt) - return docDom - -def BARDataset(nbpages): - """ - ABP register + Config = { + "scanning": None, + "pageH": (780, 50), + "pageW": (1000, 50), + "nbPages": (nbpages,0), + "lmargin": tlMarginGen, + "rmargin": trMarginGen, + 'pnum' :True, + "pnumZone": 0, + "grid" : tGrid, + "leading": (12,1), + "lineHeight":(10,1), + "colStruct": (tableGenerator,1,nbpages,((9,3),(12,5))) - """ - tlMarginGen = ((50, 5),(50, 5),(50, 5),(50, 5)) - trMarginGen = ((50, 5),(50, 5),(50, 5),(50, 5)) - - tGrid = ( 'regular',(1,0),(0,0) ) - - # should be replaced by an object? - BARConfig = { - "page":{ - "scanning": None, - "pageH": (780, 50), - "pageW": (500, 50), - "nbPages": (nbpages,0), - "margin": [tlMarginGen, trMarginGen], - 'pnum' :True, - "pnumZone": 0, - "grid" : tGrid - } - #column? - ,"line":{ - "leading": (15,1) - ,"lineHeight": (15,1) - ,"justification":'left' - } - - ,"colStruct": (listGenerator,LineGenerator,(2,0)) } - Config=BARConfig mydoc = DocMirroredPages(Config) + mydoc.myConfig = Config mydoc.instantiate() mydoc.generate() gt = mydoc.exportAnnotatedData(()) # print gt docDom = mydoc.XMLDSFormatAnnotatedData(gt) - return docDom + return docDom if __name__ == "__main__": @@ -1421,13 +1086,7 @@ def BARDataset(nbpages): try: nbpages = int(sys.argv[1]) except IndexError as e: nbpages = 1 -# dom1 = ABPRegisterDataset(nbpages) -# dom1 = NAFDataset(nbpages) - dom1 = NAHDataset(nbpages) - -# dom1 = StAZHDataset(nbpages) -# dom1 = BARDataset(nbpages) - + dom1 = tableDataset(nbpages) dom1.write(outfile,xml_declaration=True,encoding='utf-8',pretty_print=True) print("saved in %s"%outfile) diff --git a/TranskribusDU/dataGenerator/listGenerator.py b/TranskribusDU/dataGenerator/listGenerator.py index a41a8a5..c4e2f42 100644 --- a/TranskribusDU/dataGenerator/listGenerator.py +++ b/TranskribusDU/dataGenerator/listGenerator.py @@ -11,40 +11,28 @@ copyright Xerox 2017 READ project - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . 
Developed for the EU project READ. The READ project has received funding from the European Union's Horizon 2020 research and innovation programme under grant agreement No 674943. """ -from __future__ import absolute_import -from __future__ import print_function -from __future__ import unicode_literals -from dataGenerator.generator import Generator -try:basestring -except NameError:basestring = str + + +from .generator import Generator + class listGenerator(Generator): """ a generator for list """ - def __init__(self,config,objGen,nbMaxGen): - Generator.__init__(self,config) + def __init__(self,objGen,nbMaxGen,*objParam): + Generator.__init__(self) self.myObjectGen = objGen + self.objParams = objParam self.nbMax = nbMaxGen def getValuedNb(self): return self.nbMax._generation @@ -54,13 +42,7 @@ def instantiate(self): self._instance = [] self.nbMax.generate() for i in range(self.nbMax._generation): - try: - o = self.myObjectGen(self.getConfig()) - o.instantiate() - except TypeError: - o = self.myObjectGen(*self.getConfig()) - o.instantiate() - + o = self.myObjectGen(*self.objParams).instantiate() o.setNumber(i) self._instance.append(o) return self @@ -69,7 +51,7 @@ def exportAnnotatedData(self,foo): self._GT=[] for obj in self._generation: - if type(obj._generation) == basestring: + if type(obj._generation) == unicode: self._GT.append((obj._generation,[obj.getLabel()])) elif type(obj) == int: self._GT.append((obj._generation,[obj.getLabel()])) @@ -80,9 +62,9 @@ def exportAnnotatedData(self,foo): if __name__ == "__main__": - from dataGenerator.numericalGenerator import integerGenerator - integerGenerator(10,0) - lG = listGenerator((5,4),integerGenerator, integerGenerator(10,0)) + from .numericalGenerator import integerGenerator + + lG =listGenerator(integerGenerator,integerGenerator(10,0),5,4) lG.instantiate() lG.generate() print(lG._generation) \ No newline at end of file diff --git a/TranskribusDU/dataGenerator/noiseGenerator.py b/TranskribusDU/dataGenerator/noiseGenerator.py index e868bd6..bbaac9f 100644 --- a/TranskribusDU/dataGenerator/noiseGenerator.py +++ b/TranskribusDU/dataGenerator/noiseGenerator.py @@ -11,27 +11,16 @@ copyright Naver labs Europe 2017 READ project - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. The READ project has received funding from the European Union's Horizon 2020 research and innovation programme under grant agreement No 674943. 
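# --- Editor's sketch: after the listGenerator change above, the constructor is
# listGenerator(objGen, nbMaxGen, *objParam): an element class, a count
# generator, and the arguments forwarded to every element. The stand-in classes
# below are illustrative only, mirroring the __main__ example in that file:
import random

class IntGen:                       # stand-in for integerGenerator(mean, std)
    def __init__(self, mean, std): self.mean, self.std = mean, std
    def generate(self): self.value = int(round(random.gauss(self.mean, self.std)))

def make_list(elt_cls, count_gen, *elt_params):
    count_gen.generate()            # decide how many elements to build
    return [elt_cls(*elt_params) for _ in range(count_gen.value)]

random.seed(1)
print(len(make_list(IntGen, IntGen(10, 0), 5, 4)))   # -> 10 elements of IntGen(5, 4)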
""" -from __future__ import absolute_import -from __future__ import print_function -from __future__ import unicode_literals + + + from .generator import Generator diff --git a/TranskribusDU/dataGenerator/numericalGenerator.py b/TranskribusDU/dataGenerator/numericalGenerator.py index a8ac2b4..5d6eed7 100644 --- a/TranskribusDU/dataGenerator/numericalGenerator.py +++ b/TranskribusDU/dataGenerator/numericalGenerator.py @@ -11,31 +11,20 @@ copyright Xerox 2017 READ project - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . Developed for the EU project READ. The READ project has received funding from the European Union's Horizon 2020 research and innovation programme under grant agreement No 674943. """ -from __future__ import absolute_import -from __future__ import print_function -from __future__ import unicode_literals + + + import random -from dataGenerator.generator import Generator +from .generator import Generator """ see http://www.southampton.ac.uk/~fangohr/blog/physical-quantities-numerical-value-with-units-in-python.html @@ -49,9 +38,9 @@ class numericalGenerator(Generator): """ generic numerical Generator - """ + """ def __init__(self,mean,sd): - Generator.__init__(self,None) + Generator.__init__(self) self._name='num' self._mean = mean self._std= sd @@ -70,11 +59,7 @@ def setUple(self,ms): self._std = s def exportAnnotatedData(self,lLabels): - -# lLabels.append(self.getLabel()) - for i,ltype in enumerate(self.lClassesToBeLearnt): - if self.getName() in ltype: - lLabels[i]= self.getName() + lLabels.append(self.getLabel()) self._GT = [(self._generation,lLabels[:])] return self._GT diff --git a/TranskribusDU/dataGenerator/textGenerator.py b/TranskribusDU/dataGenerator/textGenerator.py index 0dc660a..70085a8 100644 --- a/TranskribusDU/dataGenerator/textGenerator.py +++ b/TranskribusDU/dataGenerator/textGenerator.py @@ -11,27 +11,16 @@ copyright Xerox 2017 READ project - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . Developed for the EU project READ. The READ project has received funding from the European Union's Horizon 2020 research and innovation programme under grant agreement No 674943. 
""" -from __future__ import absolute_import -from __future__ import print_function -from __future__ import unicode_literals + + + import pickle @@ -42,7 +31,7 @@ import random -from dataGenerator.generator import Generator +from .generator import Generator class textGenerator(Generator): @@ -57,7 +46,7 @@ class textGenerator(Generator): def __init__(self,lang): self.lang = lang locale.setlocale(locale.LC_TIME, self.lang) - Generator.__init__(self,{}) + Generator.__init__(self) # list of content self._value = None # reference to the list of content (stored) @@ -122,7 +111,12 @@ def getValue(self): return self._value - +# def instantiate(self): +# """ +# for terminal elements (from a list): nothing to do +# """ +# return [] + def generate(self): """ need to take into account element frequency! done in getRandomElt @@ -134,37 +128,19 @@ def generate(self): while len(self._generation.strip()) == 0: self._generation = self.getRandomElt(self._value) + # create the noiseGenerrator? return self - def delCharacter(self,s,th): + def generateNoise(self): """ - delete characters + add noise to pureGen """ + #use textnoiseGen to determine if noise will be generated? +# if self.getNoiseGen() is not None: - ns="" - for i in range(len(s)): - generateProb = random.uniform(0,100) - if generateProb >= th: - ns+=s[i] - - # at least one char - if ns=="":ns=s[0] - return ns - - def replaceCharacter(self,s,th): - """ - add noise (replace char) to pureGen - """ - ns="" - for i in range(len(s)): - generateProb = random.uniform(0,100) - if generateProb < th: - ns+=chr(int(random.uniform(65,240))) - else: ns+=s[i] - return ns def noiseSplit(self,lGTTokens): """ @@ -192,44 +168,6 @@ def noiseSplit(self,lGTTokens): return lGTTokens - - def TypedBIES(self,lList): - """ - fixed length of types - - ABPRecordGenerator Maria ['PersonName2', 'firstNameGenerator'] - ABPRecordGenerator Pfeiffer ['PersonName2', 'lastNameGenerator'] - ABPRecordGenerator Forster [None, 'professionGenerator'] - - """ - lNewList = [] - for pos,(token,llabels) in enumerate(lList): - # need to copy while we update llabels but need to keep the original version for the prev/next test - lNewList.append((token,llabels[:])) - for type in range(len(self.lClassesToBeLearnt)): - isAsPrev = False - isAsNext = False - bies="??" 
- if pos > 0: - isAsPrev = llabels[type] ==lList[pos-1][1][type] -# print (llabels[type],lList[pos-1][1][type],llabels[type] == lList[pos-1][1][type]) - if pos < len(lList) -1 : - isAsNext = llabels[type] ==lList[pos+1][1][type] - if isAsPrev and isAsNext: - bies= 'I_' - elif not isAsPrev and not isAsNext: - bies= 'S_' - elif isAsPrev and not isAsNext: - bies= 'E_' - elif not isAsPrev and isAsNext: - bies='B_' - else: - pass - #update - if lNewList[-1][1][type] != None: - lNewList[-1][1][type]= bies+ llabels[type] - return lNewList - def hierarchicalBIES(self,lList): """ add BIES to labels @@ -321,40 +259,6 @@ def tokenizerString(self,s): """ - - def formatFairSeqWord(self,gtdata): - """ - FairSeq Format at character level - C C C C \t BIESO - """ - lnewGT=[] - # duplicate labels for multitoken - for token,label in gtdata: - # should be replace by self.tokenizer(token) - if isinstance(token, str) : #type(token) == unicode: - ltoken = token.split(" ") - elif type(token) in [float,int ]: - ltoken = [token] - - if len(ltoken) == 1: - lnewGT.append((token,label)) - else: - for tok in ltoken: - lnewGT.append((tok,label[:])) - - # compute BIES - assert lnewGT != [] - lnewGT = self.hierarchicalBIES(lnewGT) - - #output for GT - sSource = "" - sTarget = "" - for token, labels in lnewGT: - sTarget += labels[-1] + " " - sSource += str(token) + " " - return sSource, sTarget - - def formatAnnotatedData(self,gtdata,mode=2): """ format with bIES hierarchically @@ -371,43 +275,34 @@ def formatAnnotatedData(self,gtdata,mode=2): if isinstance(token, str) : #type(token) == unicode: ltoken= token.split(" ") elif type(token) in [float,int ]: - ltoken= [str(token)] + ltoken= [token] if len(ltoken) == 1: - # token is a str hereafter - lnewGT.append((str(token),label)) + lnewGT.append((token,label)) else: for tok in ltoken: lnewGT.append((tok,label[:])) # compute BIES assert lnewGT != [] - lnewGT = self.TypedBIES(lnewGT) + lnewGT = self.hierarchicalBIES(lnewGT) # noise here? # lnewGT = self.noiseSplit(lnewGT) #output for GT - sReturn = "" for token, labels in lnewGT: - assert type(token) != int - if len(str(token)) > 0: - uLabels = '\t'.join(labels) - if self.getNoiseType() in [1]: - token = self.delCharacter(token,self.getNoiseLevel()) - uString = "%s\t%s" % (token,uLabels) - sReturn +=uString+'\n' - sReturn+="EOS\n" - return sReturn + uLabels = '\t'.join(labels) + uString = "%s\t%s" % (token,uLabels) + print(uString) + print ("EOS") def exportAnnotatedData(self,lLabels): # export (generated value, label) for terminal self._GT = [] - # here test if the label has to be in the classes to be learned - for i,ltype in enumerate(self.lClassesToBeLearnt): - if self.getName() in ltype: - lLabels[i]=self.getName() + lLabels.append(self.getName()) + if isinstance(self._generation, str) : #type(self._generation) == unicode: self._GT.append((self._generation,lLabels[:])) elif type(self._generation) == int: diff --git a/TranskribusDU/dataGenerator/textNoiseGenerator.py b/TranskribusDU/dataGenerator/textNoiseGenerator.py index ebbee89..22864e1 100644 --- a/TranskribusDU/dataGenerator/textNoiseGenerator.py +++ b/TranskribusDU/dataGenerator/textNoiseGenerator.py @@ -11,32 +11,18 @@ copyright Naver labs Europe 2017 READ project - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. 
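# --- Editor's sketch: the (Typed)BIES logic above tags each token by comparing
# its label with its neighbours': B_ opens a run of identical labels, I_ is
# inside one, E_ closes one, S_ marks a singleton. A compact restatement:
def bies_tags(labels):
    tags = []
    for i, lab in enumerate(labels):
        same_prev = i > 0 and labels[i - 1] == lab
        same_next = i < len(labels) - 1 and labels[i + 1] == lab
        if same_prev and same_next:   tags.append("I_" + lab)
        elif same_prev:               tags.append("E_" + lab)
        elif same_next:               tags.append("B_" + lab)
        else:                         tags.append("S_" + lab)
    return tags

print(bies_tags(["name", "name", "profession"]))
# -> ['B_name', 'E_name', 'S_profession']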
- This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . Developed for the EU project READ. The READ project has received funding from the European Union's Horizon 2020 research and innovation programme under grant agreement No 674943. """ -from __future__ import absolute_import -from __future__ import print_function -from __future__ import unicode_literals -try:basestring -except NameError:basestring = str -from dataGenerator.noiseGenerator import noiseGenerator + + +from .noiseGenerator import noiseGenerator class textNoiseGenerator(noiseGenerator): def __init__(self,tuplesplit, tupleNsplit,tuplechar): @@ -65,7 +51,7 @@ def generate(self,gtdata): for token,label in gtdata: # should be replace by self.tokenizer(token) - if type(token) == basestring: + if type(token) == unicode: ltoken= token.split(" ") elif type(token) in [float,int ]: ltoken= [token] diff --git a/TranskribusDU/dataGenerator/textRandomGenerator.py b/TranskribusDU/dataGenerator/textRandomGenerator.py index 6eeef40..bb922d1 100644 --- a/TranskribusDU/dataGenerator/textRandomGenerator.py +++ b/TranskribusDU/dataGenerator/textRandomGenerator.py @@ -11,31 +11,20 @@ copyright Xerox 2017 READ project - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . Developed for the EU project READ. The READ project has received funding from the European Union's Horizon 2020 research and innovation programme under grant agreement No 674943. 
""" -from __future__ import absolute_import -from __future__ import print_function -from __future__ import unicode_literals + + + import random import string -from dataGenerator.textGenerator import textGenerator +from .textGenerator import textGenerator class textRandomGenerator(textGenerator): """ @@ -51,14 +40,9 @@ def __init__(self,length,sd): def exportAnnotatedData(self,lLabels): - #lLabels.append(self.getName()) - for i,ltype in enumerate(self.lClassesToBeLearnt): - if self.getName() in ltype: - lLabels[i]=self.getName() - + lLabels.append(self.getName()) self._GT = [(self._generation,lLabels[:])] return self._GT - def generate(self): self._generation="" for i in range(int(round(random.gauss(self._length,self._std)))): @@ -73,7 +57,7 @@ def serialize(self): def noiseSplit(self): textGenerator.noiseSplit(self) -class textletterRandomGenerator(textRandomGenerator): +class textletterstRandomGenerator(textRandomGenerator): def generate(self): self._generation="" for i in range(int(round(random.gauss(self._length,self._std)))): diff --git a/TranskribusDU/dataGenerator/typoGenerator.py b/TranskribusDU/dataGenerator/typoGenerator.py index 995abe7..cef5b75 100644 --- a/TranskribusDU/dataGenerator/typoGenerator.py +++ b/TranskribusDU/dataGenerator/typoGenerator.py @@ -1,30 +1,13 @@ # -*- coding: utf-8 -*- """ - - typoGenerator.py create (generate) textual annotated data H. Déjean - copyright Xerox 2017 READ project - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . - - Developed for the EU project READ. The READ project has received funding from the European Union's Horizon 2020 research and innovation programme under grant agreement No 674943. diff --git a/TranskribusDU/gcn/DU_Model_ECN.py b/TranskribusDU/gcn/DU_Model_ECN.py index 65e2680..4a24725 100644 --- a/TranskribusDU/gcn/DU_Model_ECN.py +++ b/TranskribusDU/gcn/DU_Model_ECN.py @@ -6,18 +6,7 @@ Copyright NAVER(C) 2018, 2019 Stéphane Clinchant, Animesh Prasad Hervé Déjean, Jean-Luc Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. 
The READ project has received funding @@ -45,7 +34,7 @@ from common.TestReport import TestReport from common.LabelBinarizer2 import LabelBinarizer2 -from graph.GraphModel import GraphModel, GraphModelException +from graph.GraphModel import GraphModel, GraphModelException, GraphModelNoEdgeException from graph.Graph import Graph import gcn.gcn_models as gcn_models from gcn.gcn_datasets import GCNDataset @@ -67,12 +56,18 @@ def __init__(self, sName, sModelDir): # A binarizer that uses 2 columns for binary classes (instead of only 1) self.labelBinarizer = LabelBinarizer2() + def getMetadataComment(self): + s = super().getMetadataComment() + s += "\n" + repr(self.model_config) + return s + @staticmethod def getBaselineConfig(): ''' Return A Baseline Edge Conv Net Configuration with 3 layers and 10 convolutions per layer :return: ''' + assert False, "code to be checked" config={} config['name'] = '3Layers-10conv-stack' config['nb_iter'] = 2000 @@ -99,37 +94,20 @@ def configureLearner(self,**kwargs): #Pass arguments self.model_config=kwargs - ''' - lr=0.001, - stack_convolutions=True, - mu=0.0, - num_layers=3, - node_indim=-1, - nconv_edge=10, - fast_convolve = True, - dropout_rate_edge = 0.0, - dropout_rate_edge_feat = 0.0, - dropout_rate_node = 0.0, - nb_iter=2000, - ratio_train_val=0.15, - activation=tf.nn.tanh, - ): - ''' - def getModelFilename(self): - return os.path.join(self.sDir, self.sName+"."+self.sSurname+".bestmodel.ckpt") + return os.path.join(self.sDir, self.sName+"._."+self.sSurname+".bestmodel.ckpt") def getTmpModelFilename(self): - return os.path.join(self.sDir, self.sName+"."+self.sSurname+".tmpmodel.ckpt") + return os.path.join(self.sDir, self.sName+"._."+self.sSurname+".tmpmodel.ckpt") def getValScoreFilename(self): - return os.path.join(self.sDir, self.sName+"."+self.sSurname+'.validation_scores.pkl') + return os.path.join(self.sDir, self.sName+"._."+self.sSurname+'.validation_scores.pkl') def getlabelBinarizerFilename(self): - return os.path.join(self.sDir, self.sName+"."+self.sSurname+'.label_binarizer.pkl') + return os.path.join(self.sDir, self.sName+"._."+self.sSurname+'.label_binarizer.pkl') def getModelConfigFilename(self): - return os.path.join(self.sDir, self.sName+"."+self.sSurname+'.model_config.pkl') + return os.path.join(self.sDir, self.sName+"._."+self.sSurname+'.model_config.pkl') @classmethod @@ -171,7 +149,7 @@ def get_lX(self, lGraph): # ECN need directed edges, but for the conjugate graph, we proceed differently # because the features remain the same in both directions. - if self.bConjugate: + if g.bConjugate: [(node_features, new_edges, new_edges_feat)] = self.addConjugateRevertedEdge([g], [X]) else: # in conjugate graph mode, we "duplicate" the edges in the conjugate graph @@ -269,6 +247,14 @@ def convert_X_to_GCNDataset(self, X): return graph + def _init_TF_graph(self): + tf_graph = tf.Graph() + with tf_graph.as_default(): + self._init_model() + #This create a session containing the correct model + self.restore() + self.tf_graph = tf_graph + def _init_model(self): ''' Create the tensorflow graph. 
@@ -284,30 +270,30 @@ def _init_model(self): node_indim=self.model_config['node_indim'], nconv_edge=self.model_config['nconv_edge'], ) - ''' - self.gcn_model.stack_instead_add = self.model_config['stack_convolutions'] - #TODO Clean Constructor - if 'activation' in self.model_config: - self.gcn_model.activation = self.model_config['activation'] - - if 'fast_convolve' in self.model_config: - self.gcn_model.fast_convolve = self.model_config['fast_convolve'] - - if 'dropout_rate_edge' in self.model_config: - self.gcn_model.dropout_rate_edge = self.model_config['dropout_rate_edge'] - print('Dropout Edge', self.gcn_model.dropout_rate_edge) - - if 'dropout_rate_edge_feat' in self.model_config: - self.gcn_model.dropout_rate_edge_feat = self.model_config['dropout_rate_edge_feat'] - print('Dropout Edge', self.gcn_model.dropout_rate_edge_feat) - - if 'dropout_rate_node' in self.model_config: - self.gcn_model.dropout_rate_node = self.model_config['dropout_rate_node'] - print('Dropout Node', self.gcn_model.dropout_rate_node) - ''' self.gcn_model.set_learning_options(self.model_config) self.gcn_model.create_model() +# self.gcn_model.stack_instead_add = self.model_config['stack_convolutions'] +# #TODO Clean Constructor +# if 'activation' in self.model_config: +# self.gcn_model.activation = self.model_config['activation'] +# +# if 'fast_convolve' in self.model_config: +# self.gcn_model.fast_convolve = self.model_config['fast_convolve'] +# +# if 'dropout_rate_edge' in self.model_config: +# self.gcn_model.dropout_rate_edge = self.model_config['dropout_rate_edge'] +# print('Dropout Edge', self.gcn_model.dropout_rate_edge) +# +# if 'dropout_rate_edge_feat' in self.model_config: +# self.gcn_model.dropout_rate_edge_feat = self.model_config['dropout_rate_edge_feat'] +# print('Dropout Edge', self.gcn_model.dropout_rate_edge_feat) +# +# if 'dropout_rate_node' in self.model_config: +# self.gcn_model.dropout_rate_node = self.model_config['dropout_rate_node'] +# print('Dropout Node', self.gcn_model.dropout_rate_node) + + def _cleanTmpCheckpointFiles(self): ''' When a model is trained, tensorflow checkpoint files are created every 10 epochs @@ -326,16 +312,11 @@ def _cleanTmpCheckpointFiles(self): return nb_clean - def train(self, lGraph, lGraph_vld, bWarmStart=True, expiration_timestamp=None, verbose=0): + def _prepare_for_train(self, lGraph, lGraph_vld): """ - Return a model trained using the given labelled graphs. - The train method is expected to save the model into self.getModelFilename(), at least at end of training - If bWarmStart==True, The model is loaded from the disk, if any, and if fresher than given timestamp, and training restarts - - if some baseline model(s) were set, they are also trained, using the node features - + Prepare for training eCN or EnsembleECN """ - traceln('ECN Training',self.sName) + traceln('ECN Training ', self.sName) traceln("\t- computing features on training set") traceln("\t\t #nodes=%d #edges=%d " % Graph.getNodeEdgeTotalNumber(lGraph)) chronoOn() @@ -345,10 +326,7 @@ def train(self, lGraph, lGraph_vld, bWarmStart=True, expiration_timestamp=None, # self._tNF_EF contains the number of node features and edge features traceln("\t\t %s" % self._getNbFeatureAsText()) traceln("\t [%.1fs] done\n" % chronoOff()) - if self.bConjugate: - nb_class = len(lGraph[0].getEdgeLabelNameList()) #Is it better to do Y.shape ? - else: - nb_class = len(lGraph[0].getLabelNameList()) #Is it better to do Y.shape ? + nb_class = len(lGraph[0].getLabelNameList()) #Is it better to do Y.shape ? 
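# --- Editor's sketch: the new _init_TF_graph() helper (and the training code in
# this file) builds each DU model inside its own tf.Graph, so several models,
# e.g. the sub-models of an ECN ensemble, can coexist in one process. A minimal
# TF 1.x-style illustration of that pattern:
import tensorflow as tf   # TF 1.x, as used by this repository

def init_private_graph(build_fn):
    """Build a model inside a private tf.Graph; return (graph, session)."""
    g = tf.Graph()
    with g.as_default():               # ops created here belong to g only
        build_fn()
        sess = tf.Session(graph=g)
        sess.run(tf.global_variables_initializer())
    return g, sess

# two independent models with clashing variable names, no conflict:
g1, s1 = init_private_graph(lambda: tf.Variable(0.0, name="w"))
g2, s2 = init_private_graph(lambda: tf.Variable(1.0, name="w"))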
traceln("\t- %d classes" % nb_class) traceln("\t- retrieving or creating model...") @@ -361,15 +339,6 @@ def train(self, lGraph, lGraph_vld, bWarmStart=True, expiration_timestamp=None, with open ('linear_reg', 'wb') as save_file: pickle.dump((lX,lY), save_file, pickle.HIGHEST_PROTOCOL) - #This call the ECN internal constructor and defines the tensorflow graph - self._init_model() - - #This call the ECN internal constructor and defines the tensorflow graph - tf_graph=tf.Graph() - with tf_graph.as_default(): - self._init_model() - self.tf_graph=tf_graph - #This converts the lX,lY in the format necessary for GCN Models gcn_graph = self.convert_lX_lY_to_GCNDataset(lX,lY,training=True) @@ -377,10 +346,6 @@ def train(self, lGraph, lGraph_vld, bWarmStart=True, expiration_timestamp=None, fd_lb =open(self.getlabelBinarizerFilename(),'wb') pickle.dump(self.labelBinarizer,fd_lb) fd_lb.close() - #Save the model config in order to restore the model later - fd_mc = open(self.getModelConfigFilename(), 'wb') - pickle.dump(self.model_config, fd_mc) - fd_mc.close() #TODO Save the validation set too to reproduce experiments random.shuffle(gcn_graph) @@ -392,16 +357,47 @@ def train(self, lGraph, lGraph_vld, bWarmStart=True, expiration_timestamp=None, del lX_vld, lY_vld else: #Get a validation set from the training set - split_idx = int(self.model_config['ratio_train_val'] * len(gcn_graph)) + split_idx = max(1, int(self.model_config['ratio_train_val'] * len(gcn_graph))) + traceln(" - using %d train graphs as validation graphs" % split_idx) gcn_graph_train = [] gcn_graph_val = [] gcn_graph_val.extend(gcn_graph[:split_idx]) gcn_graph_train.extend(gcn_graph[split_idx:]) - + traceln("%d training graphs -- %d validation graphs"%(len(gcn_graph_train), len(gcn_graph_val))) self._cleanTmpCheckpointFiles() + return gcn_graph_train, gcn_graph_val + + def train(self, lGraph, lGraph_vld, bWarmStart=True, expiration_timestamp=None, verbose=0): + """ + Return a model trained using the given labelled graphs. 
+        The train method is expected to save the model into self.getModelFilename(), at least at the end of training
+        If bWarmStart==True, the model is loaded from the disk, if any, and if fresher than the given timestamp, and training restarts
+
+        if some baseline model(s) were set, they are also trained, using the node features
+
+        """
+
+        gcn_graph_train, gcn_graph_val = self._prepare_for_train(lGraph, lGraph_vld)
+
+        self._train(gcn_graph_train, gcn_graph_val)
+
+    def _train(self, gcn_graph_train, gcn_graph_val):
+        """
+        In ECN Ensemble mode, many steps have already been performed by the ECNEnsemble model
+        """
         patience = self.model_config['patience'] if 'patience' in self.model_config else self.model_config['nb_iter']

+        #Save the model config in order to restore the model later
+        fd_mc = open(self.getModelConfigFilename(), 'wb')
+        pickle.dump(self.model_config, fd_mc)
+        fd_mc.close()
+
+        #This calls the ECN internal constructor and defines the tensorflow graph
+        tf_graph=tf.Graph()
+        with tf_graph.as_default():
+            self._init_model()
+        self.tf_graph = tf_graph
 # SC        with tf.Session() as session:
 # Animesh
@@ -422,7 +418,6 @@ def train(self, lGraph, lGraph_vld, bWarmStart=True, expiration_timestamp=None,
             #We reopen a session here and load the selected model if we need one
             self.restore()

-
     def _getBestModelVal(self):
         val_pickle = self.getValScoreFilename()
         traceln("\t- reading training info from...",val_pickle)
@@ -509,15 +504,7 @@ def load(self, expiration_timestamp=None):
             self.labelBinarizer = pickle.load(fd_lb)
             fd_lb.close()

-
-        # SC     self._init_model()
-        # Animesh
-        tf_graph = tf.Graph()
-        with tf_graph.as_default():
-            self._init_model()
-        #This create a session containing the correct model
-        self.restore()
-        self.tf_graph = tf_graph
+        self._init_TF_graph()

         return self

@@ -605,6 +592,19 @@ def check (i, l1, l2):



+    def _prepare_for_test(self, lGraph):
+
+        traceln("\t- computing features on test set")
+        traceln("\t\t #nodes=%d #edges=%d " % Graph.getNodeEdgeTotalNumber(lGraph))
+        chronoOn()
+
+        lX, lY = self.get_lX_lY(lGraph)
+        traceln("\t [%.1fs] done\n" % chronoOff())
+
+        gcn_graph_test = self.convert_lX_lY_to_GCNDataset(lX, lY, training=False,test=True)
+
+        return gcn_graph_test, lX, lY
+
     def test(self, lGraph
              , lsDocName=None
              , predict_proba=False
@@ -617,21 +617,11 @@ def test(self, lGraph
         Return a Report object
         """
         #Assume the model was created or loaded
-        assert lGraph
-        if self.bConjugate:
-            lLabelName = lGraph[0].getEdgeLabelNameList()
-        else:
-            lLabelName = lGraph[0].getLabelNameList()
-        traceln("\t- computing features on test set")
-        traceln("\t\t #nodes=%d #edges=%d " % Graph.getNodeEdgeTotalNumber(lGraph))
-        chronoOn()
-
-        lX, lY = self.get_lX_lY(lGraph)
-        traceln("\t [%.1fs] done\n" % chronoOff())
-
-        gcn_graph_test = self.convert_lX_lY_to_GCNDataset(lX, lY, training=False,test=True)
+        gcn_graph_test, lX, lY = self._prepare_for_test(lGraph)
+
+        lLabelName = lGraph[0].getLabelNameList()
         chronoOn("test2")
         session=self.tf_session
@@ -690,12 +680,12 @@ def testFiles(self, lsFilename, loadFun,bBaseLine=False):
             lg = loadFun(sFilename)  # returns a singleton list
             for g in lg:
-                if self.bConjugate: g.computeEdgeLabels()
+                if g.bConjugate: g.computeEdgeLabels()
                 [X], [Y] = self.get_lX_lY([g])
                 gcn_graph_test = self.convert_lX_lY_to_GCNDataset([X], [Y], training=False, test=True)
                 if lLabelName == None:
-                    lLabelName = g.getEdgeLabelNameList() if self.bConjugate else g.getLabelNameList()
+                    lLabelName = g.getEdgeLabelNameList() if g.bConjugate else g.getLabelNameList()
                 traceln("\t #nodes=%d #edges=%d " %
Graph.getNodeEdgeTotalNumber([g])) tNF_EF = (X[0].shape[1], X[2].shape[1]) traceln("node-dim,edge-dim", tNF_EF) @@ -710,7 +700,6 @@ def testFiles(self, lsFilename, loadFun,bBaseLine=False): lX.append(X) lY.append(Y) -# g.detachFromDOM() del g # this can be very large gc.collect() @@ -737,8 +726,8 @@ def restore(self): session=tf.Session(graph=self.tf_graph) session.run(self.gcn_model.init) self.gcn_model.restore_model(session, self.getModelFilename()) - traceln(" ... done loaded",self.sName) - self.tf_session=session + traceln(" ... done loaded ",self.sName) + self.tf_session = session return session @@ -748,6 +737,7 @@ def predict(self, g, bProba=False): return a numpy array, which is a 1-dim array of size the number of nodes of the graph. """ [X] = self.get_lX([g]) + if X[1].shape[0] == 0: raise GraphModelNoEdgeException # no edge in this graph! gcn_graph_test = self.convert_X_to_GCNDataset(X) #lY_pred = self.gcn_model.predict_lG(self.tf_session, gcn_graph_test, verbose=False) @@ -808,10 +798,10 @@ def getModelInfo(self): class DU_Ensemble_ECN(DU_Model_ECN): - sSurname = "ensemble_ecn_" + sSurname = "ecn_ensemble" def __init__(self, sName, sModelDir): - super(DU_Ensemble_ECN,self).__init__(sName,sModelDir) + super(DU_Ensemble_ECN,self).__init__(sName, sModelDir) self.tf_graphs=[] self.models=[] @@ -823,69 +813,52 @@ def _init_model(self): This function is called in the train operation and in the load function for testing on new documents :return: ''' - traceln('Ensemble config') - traceln(self.model_config) - for model_config in self.model_config['ecn_ensemble']: - traceln(model_config) + traceln(' ---- Ensemble config') + assert self.model_config['ratio_train_val'], "The ensemble as a whole must have one ratio 'ratio_train_val'" + for iMdl, model_config in enumerate(self.model_config['ecn_ensemble']): if model_config['type']=='ecn': - sName= model_config['name'] - du_model = DU_Model_ECN(sName,self.sDir) - #Propagate the values of node dim edge dim nb_class - model_config['node_dim']=self.model_config['node_dim'] - model_config['edge_dim'] = self.model_config['edge_dim'] - model_config['nb_class'] = self.model_config['nb_class'] - + sName = self.getSubModelName(iMdl, model_config['name']) + + traceln(' ---- Configuration of ', sName) + du_model = DU_Model_ECN(sName, self.sDir) + du_model.configureLearner(**model_config) - du_model.setTranformers(self.getTransformers()) - du_model.saveTransformers() - #TODO Unclear why this is not set bye default - du_model.setNbClass(self.getNbClass()) + + du_model.model_config['name'] = sName + try: + traceln(" -- Max iteration forced to %d"%du_model.model_config['nb_iter']) + du_model.model_config['nb_iter'] = self.model_config['nb_iter'] + except KeyError: pass + du_model.model_config['ratio_train_val'] = self.model_config['ratio_train_val'] + du_model.model_config['node_dim'] = self.model_config['node_dim'] + du_model.model_config['edge_dim'] = self.model_config['edge_dim'] + du_model.model_config['nb_class'] = self.model_config['nb_class'] + + du_model._init_model() + self.models.append(du_model) else: raise Exception('Invalid ECN Model') - - def train(self, lGraph, bWarmStart=True, expiration_timestamp=None,verbose=0): - traceln('Ensemble ECN Training') - traceln("\t- computing features on training set") - traceln("\t\t #nodes=%d #edges=%d " % Graph.getNodeEdgeTotalNumber(lGraph)) - chronoOn() - lX, lY = self.get_lX_lY(lGraph) - - self._computeModelCaracteristics(lX) # we discover here dynamically the number of features of nodes and edges - # 
self._tNF_EF contains the number of node features and edge features - traceln("\t\t %s" % self._getNbFeatureAsText()) - traceln("\t [%.1fs] done\n" % chronoOff()) - - nb_class = self.getNbClass() # Is it better to do Y.shape ? - traceln('nb_class', nb_class) - - self.model_config['node_dim'] = self._tNF_EF[0] - self.model_config['edge_dim'] = self._tNF_EF[1] - self.model_config['nb_class'] = nb_class - traceln("\t- creating the sub-models") - - # TODO - # This converts the lX,lY in the format necessary for GCN Models - #DO we need that , can we share the label binarizer and so on ... - #This sets the label binarizer - gcn_graph = self.convert_lX_lY_to_GCNDataset(lX, lY, training=True) - - # Save the label Binarizer for prediction usage - fd_lb = open(self.getlabelBinarizerFilename(), 'wb') - pickle.dump(self.labelBinarizer, fd_lb) - fd_lb.close() - # Save the model config in order to restore the model later - fd_mc = open(self.getModelConfigFilename(), 'wb') - pickle.dump(self.model_config, fd_mc) - fd_mc.close() - - #This would create all the DU_MODEL + traceln('Ensemble config done: %d models initialized' % len(self.models)) + + def getSubModelName(self, iMdl, sSubModelName): + """ + Make a name that is globally unique and art of the family of names for this ECN Ensemble + """ + return self.sName + "._.%d_" % iMdl + sSubModelName + + def train(self, lGraph, lGraph_vld, bWarmStart=True, expiration_timestamp=None, verbose=0): + + gcn_graph_train, gcn_graph_val = self._prepare_for_train(lGraph, lGraph_vld) + self._init_model() - + + traceln(' ---- Ensemble ECN Training') for du_model in self.models: - #The train will create a tf graph and create the model - du_model.train(lGraph,bWarmStart=bWarmStart) - #TODO assert label binarizer are the same + traceln(" ---- training %s" % du_model.model_config['name']) + du_model._tNF_EF = self._tNF_EF # not sure it is required + du_model._train(gcn_graph_train, gcn_graph_val) +# #TODO assert label binarizer are the same def load(self, expiration_timestamp=None): """ @@ -905,7 +878,7 @@ def load(self, expiration_timestamp=None): fd_mc = open(self.getModelConfigFilename(), 'rb') self.model_config=pickle.load(fd_mc) fd_mc.close() - + fd_lb = open(self.getlabelBinarizerFilename(), 'rb') self.labelBinarizer = pickle.load(fd_lb) fd_lb.close() @@ -915,11 +888,20 @@ def load(self, expiration_timestamp=None): #Still unclear if the load should load all the submodels #In principle yes self._init_model() - for du_model in self.models: + + for iMdl, du_model in enumerate(self.models): #Load each model, init and restore the checkpoint # Create also a corresponding tf.Session - du_model.load() + # du_model.load() the load method does too many inapropriate things + du_model.setTranformers(self.getTransformers()) + du_model.labelBinarizer = self.labelBinarizer + fd_mc = open(du_model.getModelConfigFilename(), 'rb') + du_model.model_config=pickle.load(fd_mc) + fd_mc.close() + + du_model._init_TF_graph() + return self @@ -928,7 +910,7 @@ def getModelInfo(self): Get some basic model info Return a textual report """ - return "Ensemble_Model" + return "Ensemble_ECN_Model" @@ -943,23 +925,25 @@ def test(self, lGraph,lsDocName=None,predict_proba=False): #Assume the model was created or loaded assert lGraph - lLabelName = lGraph[0].getEdgeLabelNameList() if self.bConjugate else lGraph[0].getLabelNameList() - traceln("\t- computing features on test set") - traceln("\t\t #nodes=%d #edges=%d " % Graph.getNodeEdgeTotalNumber(lGraph)) - chronoOn() - lY = self.get_lY(lGraph) + + 
gcn_graph_test, lX, lY = self._prepare_for_test(lGraph) + + lLabelName = lGraph[0].getLabelNameList() lY_pred_proba=[] for du_model in self.models: - model_pred=du_model.test(lGraph,lsDocName=lsDocName,predict_proba=True) + model_pred = du_model.gcn_model.predict_prob_lG(du_model.tf_session + , gcn_graph_test + , verbose=False) lY_pred_proba.append(model_pred) - traceln('Number of Models',len(lY_pred_proba)) + traceln(' -- Number of Models : ',len(lY_pred_proba)) + lY_pred,_ = DU_Ensemble_ECN.average_prediction(lY_pred_proba) tstRpt = TestReport(self.sName, lY_pred, lY, lLabelName, lsDocName=lsDocName) # do some garbage collection - del lY + del lX, lY gc.collect() return tstRpt @@ -996,79 +980,73 @@ def testFiles(self, lsFilename, loadFun, bBaseLine=False): Return a Report object """ - raise NotImplementedError - lX, lY, lY_pred = [], [], [] - lLabelName = None traceln("- predicting on test set") chronoOn("testFiles") + lX, lY, lY_pred = [], [], [] + lLabelName = None - # ? Iterate over files or over models - - - for du_model in self.models: - #du_model.load() - - m_pred=[] - #with tf.Session(graph=du_model.tf_graph) as session: - #session.run(du_model.gcn_model.init) - #du_model.gcn_model.restore_model(session, du_model.getModelFilename()) - - for sFilename in lsFilename: - [g] = loadFun(sFilename) # returns a singleton list - if self.bConjugate: g.computeEdgeLabels() + for sFilename in lsFilename: + + lg = loadFun(sFilename) # returns a singleton list + for g in lg: + if g.bConjugate: g.computeEdgeLabels() [X], [Y] = self.get_lX_lY([g]) + + [gcn_graph_test] = self.convert_lX_lY_to_GCNDataset([X], [Y], training=False, test=True) - gcn_graph_test = self.convert_lX_lY_to_GCNDataset([X], [Y], training=False, test=True) if lLabelName == None: - lLabelName = g.getEdgeLabelNameList() if self.bConjugate else g.getLabelNameList() + lLabelName = g.getLabelNameList() traceln("\t #nodes=%d #edges=%d " % Graph.getNodeEdgeTotalNumber([g])) tNF_EF = (X[0].shape[1], X[2].shape[1]) traceln("node-dim,edge-dim", tNF_EF) -# else: -# assert lLabelName == g.getLabelNameList(), "Inconsistency among label spaces" - - model_pred = du_model.test(gcn_graph_test, predict_proba=True) - - # binary only m_pred.append(model_pred[0]) - m_pred.append(model_pred.argmax(axis=1)) + m_pred=[] + for du_model in self.models: + [Y_pred] = du_model.gcn_model.predict_prob_lG(du_model.tf_session, [gcn_graph_test], verbose=False) + m_pred.append([Y_pred]) + + [Y_pred], [_Y_pred_proba] = DU_Ensemble_ECN.average_prediction(m_pred) - if bBaseLine: lX.append(X) + #lX.append(X) lY.append(Y) - g.detachFromDOM() + lY_pred.append(Y_pred) + del _Y_pred_proba + g.detachFromDoc() del g # this can be very large - gc.collect() - lY_pred.append(model_pred) + gc.collect() - lY_pred,_ = DU_Ensemble_ECN.average_prediction(lY_pred) traceln("[%.1fs] done\n" % chronoOff("testFiles")) tstRpt = TestReport(self.sName, lY_pred, lY, lLabelName, lsDocName=lsFilename) - if bBaseLine: - lBaselineTestReport = self._testBaselinesEco(lX, lY, lLabelName, lsDocName=lsFilename) - tstRpt.attach(lBaselineTestReport) - - del lX, lY + del lX, lY, lY_pred gc.collect() return tstRpt - def predict(self, g): + def predict(self, g, bProba=False): """ predict the class of each node of the graph return a numpy array, which is a 1-dim array of size the number of nodes of the graph. """ + + [X] = self.get_lX([g]) + if X[1].shape[0] == 0: raise GraphModelNoEdgeException # no edge in this graph! 
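# --- Editor's sketch: average_prediction() (called above, defined outside this
# hunk) averages the per-class probabilities of the ensemble members, then takes
# the argmax. A numpy sketch matching the ([y_pred], [y_pred_proba]) shape seen
# at the call sites; the actual implementation may differ:
import numpy as np

def average_prediction(m_pred):
    """m_pred: for each model, a list of per-graph probability arrays."""
    n_graphs = len(m_pred[0])
    lY_proba = [np.mean([model[i] for model in m_pred], axis=0)
                for i in range(n_graphs)]
    lY_pred = [p.argmax(axis=1) for p in lY_proba]   # class index per node/edge
    return lY_pred, lY_proba

# two models, one graph, three nodes, two classes:
m = [[np.array([[.9, .1], [.2, .8], [.6, .4]])],
     [np.array([[.7, .3], [.4, .6], [.2, .8]])]]
print(average_prediction(m)[0])   # [array([0, 1, 1])]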
+ + gcn_graph_test = self.convert_X_to_GCNDataset(X) + m_pred=[] for du_model in self.models: - #du_model.load() - ly_pred = du_model.test([g],predict_proba=True) - m_pred.append(ly_pred) - #traceln('Ensemble_predict',m_pred) - y_pred,_ = DU_Ensemble_ECN.average_prediction(m_pred) + assert not du_model.tf_session is None + [Y_pred] = du_model.gcn_model.predict_prob_lG(du_model.tf_session, [gcn_graph_test], verbose=False) + m_pred.append([Y_pred]) + + [y_pred], [y_pred_proba] = DU_Ensemble_ECN.average_prediction(m_pred) - #traceln(y_pred) - #return y_pred[0] - return y_pred.argmax(axis=1) + if bProba: + return y_pred_proba + else: + return y_pred + diff --git a/TranskribusDU/gcn/gcn_models.py b/TranskribusDU/gcn/gcn_models.py index be60b6d..a9d2383 100755 --- a/TranskribusDU/gcn/gcn_models.py +++ b/TranskribusDU/gcn/gcn_models.py @@ -71,8 +71,8 @@ def test_lG(self, session, gcn_graph_test, verbose=True): node_acc = acc_tp / nb_node_total if verbose: - traceln('\tMean Graph Accuracy', '%.4f' % g_acc) - traceln('\tMean Node Accuracy', '%.4f' % node_acc) + traceln('\t -- Mean Graph Accuracy', '%.4f' % g_acc) + traceln('\t -- Mean Node Accuracy', '%.4f' % node_acc) return g_acc,node_acc @@ -119,7 +119,7 @@ def get_nb_params(self): variable_parameters *= dim.value #traceln(variable_parameters) total_parameters += variable_parameters - traceln(total_parameters) + return total_parameters def train_with_validation_set(self,session,graph_train,graph_val,max_iter,eval_iter=10,patience=7,graph_test=None,save_model_path=None): ''' @@ -153,13 +153,13 @@ def train_with_validation_set(self,session,graph_train,graph_val,max_iter,eval_i break if i % eval_iter == 0: - traceln('\nEpoch ', i,' Patience ', wait) + traceln('\n -- Epoch ', i,' Patience ', wait) _, tr_acc = self.test_lG(session, graph_train, verbose=False) traceln(' Train Acc ', '%.4f' % tr_acc) train_accuracies.append(tr_acc) _, node_acc = self.test_lG(session, graph_val, verbose=False) - traceln(' Valid Acc ', '%.4f' % node_acc) + traceln(' -- Valid Acc ', '%.4f' % node_acc) validation_accuracies.append(node_acc) if save_model_path: @@ -167,7 +167,7 @@ def train_with_validation_set(self,session,graph_train,graph_val,max_iter,eval_i if graph_test: test_graph_acc,test_acc = self.test_lG(session, graph_test, verbose=False) - traceln(' Test Acc ', '%.4f' % test_acc,' %.4f' % test_graph_acc) + traceln(' -- Test Acc ', '%.4f' % test_acc,' %.4f' % test_graph_acc) test_accuracies.append(test_acc) if node_acc > best_val_acc: @@ -183,14 +183,14 @@ def train_with_validation_set(self,session,graph_train,graph_val,max_iter,eval_i for g in graph_train: self.train(session, g, n_iter=1) #Final Save - traceln('Stopped Model Training after',stopped_iter) - traceln('Validation Accuracies',['%.4f' % (100*sx) for sx in validation_accuracies]) + traceln(' -- Stopped Model Training after : ',stopped_iter) + traceln(' -- Validation Accuracies : ',['%.4f' % (100*sx) for sx in validation_accuracies]) #traceln('Final Training Accuracy') _,node_train_acc = self.test_lG(session, graph_train) - traceln('Final Training Accuracy','%.4f' % node_train_acc) + traceln(' -- Final Training Accuracy','%.4f' % node_train_acc) - traceln('Final Valid Acc') + traceln(' -- Final Valid Acc') self.test_lG(session, graph_val) R = {} @@ -204,15 +204,15 @@ def train_with_validation_set(self,session,graph_train,graph_val,max_iter,eval_i if graph_test: _, final_test_acc = self.test_lG(session, graph_test) - traceln('Final Test Acc','%.4f' % final_test_acc) + traceln(' -- Final Test Acc','%.4f' % 
final_test_acc) R['final_test_acc'] = final_test_acc val = R['val_acc'] - traceln('Validation scores', val) + traceln(' -- Validation scores', val) epoch_index = np.argmax(val) - traceln('Best performance on val set: Epoch', epoch_index,val[epoch_index]) - traceln('Test Performance from val', test_accuracies[epoch_index]) + traceln(' -- Best performance on val set: Epoch', epoch_index,val[epoch_index]) + traceln(' -- Test Performance from val', test_accuracies[epoch_index]) return R @@ -258,8 +258,8 @@ def test_lG(self, session, gcn_graph_test, verbose=True): #node_acc = acc_tp / nb_node_total if verbose: - traceln('Mean Graph Accuracy', '%.4f' % g_acc) - traceln('Mean Node Accuracy', '%.4f' % node_acc) + traceln(' -- Mean Graph Accuracy', '%.4f' % g_acc) + traceln(' -- Mean Node Accuracy', '%.4f' % node_acc) return g_acc, node_acc @@ -350,17 +350,16 @@ def create_model(self): self.init = tf.global_variables_initializer() self.saver= tf.train.Saver(max_to_keep=5) - traceln('Number of Params:') - self.get_nb_params() + traceln(' -- Number of Params: ', self.get_nb_params()) def save_model(self, session, model_filename): - traceln("Saving Model") + traceln(" -- Saving Model") save_path = self.saver.save(session, model_filename) def restore_model(self, session, model_filename): self.saver.restore(session, model_filename) - traceln("Model restored.") + traceln(" -- Model restored.") def train(self,session,graph,verbose=False,n_iter=1): @@ -385,7 +384,7 @@ def train(self,session,graph,verbose=False,n_iter=1): Ops =session.run([self.train_step,self.loss], feed_dict=feed_batch) if verbose: - traceln('Training Loss',Ops[1]) + traceln(' -- Training Loss',Ops[1]) @@ -407,7 +406,7 @@ def test(self,session,graph,verbose=True): Ops =session.run([self.loss,self.accuracy], feed_dict=feed_batch) if verbose: - traceln('Test Loss',Ops[0],' Test Accuracy:',Ops[1]) + traceln(' -- Test Loss',Ops[0],' Test Accuracy:',Ops[1]) return Ops[1] @@ -425,7 +424,7 @@ def predict(self,session,graph,verbose=True): } Ops = session.run([self.pred], feed_dict=feed_batch) if verbose: - traceln('Got Prediction for:',Ops[0].shape) + traceln(' -- Got Prediction for:',Ops[0].shape) return Ops[0] @@ -488,7 +487,7 @@ def set_learning_options(self,dict_model_config): :param kwargs: :return: """ - traceln(dict_model_config) + #traceln( -- dict_model_config) for attrname,val in dict_model_config.items(): #We treat the activation function differently as we can not pickle/serialiaze python function if attrname=='activation_name': @@ -502,7 +501,7 @@ def set_learning_options(self,dict_model_config): self.stack_instead_add=val if attrname not in self._setter_variables: try: - traceln('set',attrname,val) + traceln(' -- set ',attrname,val) setattr(self,attrname,val) except AttributeError: warnings.warn("Ignored options for ECN"+attrname+':'+val) @@ -543,12 +542,12 @@ def fastconvolve(self,Wedge,Bedge,F,S,T,H,nconv,Sshape,nb_edge,dropout_p_edge,dr elif use_edge_mlp: #Wedge mlp is a shared variable across layer which project edge in a lower dim FW0 = tf.nn.tanh( tf.matmul(F,Wedge_mlp) +Bedge_mlp ) - traceln('FW0', FW0.get_shape()) + traceln(' -- FW0', FW0.get_shape()) FW = tf.matmul(FW0, Wedge, transpose_b=True) + Bedge - traceln('FW', FW.get_shape()) + traceln(' -- FW', FW.get_shape()) else: FW = tf.matmul(F, Wedge, transpose_b=True) + Bedge - traceln('FW', FW.get_shape()) + traceln(' -- FW', FW.get_shape()) self.conv =tf.unstack(FW,axis=1) @@ -630,7 +629,7 @@ def create_model_stack_convolutions(self): tf.random_uniform((self.node_indim * 
self.nconv_edge + self.node_indim, self.node_indim), -1.0 / math.sqrt(self.node_indim), 1.0 / math.sqrt(self.node_indim)), name='Wnl', dtype=tf.float32) - traceln('Wnli shape', Wnli.get_shape()) + traceln(' -- Wnli shape', Wnli.get_shape()) Bnli = tf.Variable(tf.zeros([self.node_indim]), name='Bnl' + str(i), dtype=tf.float32) Weli = init_glorot([int(self.nconv_edge), int(self.edge_dim)], name='Wel_') @@ -664,11 +663,11 @@ def create_model_stack_convolutions(self): self.ND = tf.diag(self.node_dropout_ind) edge_dropout = self.dropout_rate_edge > 0.0 or self.dropout_rate_edge_feat > 0.0 - traceln('Edge Dropout', edge_dropout, self.dropout_rate_edge, self.dropout_rate_edge_feat) + traceln(' -- Edge Dropout', edge_dropout, self.dropout_rate_edge, self.dropout_rate_edge_feat) if self.num_layers == 1: self.H = self.activation(tf.add(tf.matmul(self.node_input, self.Wnl0), self.Bnl0)) self.hidden_layers = [self.H] - traceln("H shape", self.H.get_shape()) + traceln(" -- H shape", self.H.get_shape()) P = self.fastconvolve(self.Wel0, self.Bel0, self.F, self.Ssparse, self.Tsparse, self.H, self.nconv_edge, self.Sshape, self.nb_edge, @@ -783,11 +782,11 @@ def create_model_sum_convolutions(self): self.ND = tf.diag(self.node_dropout_ind) edge_dropout = self.dropout_rate_edge> 0.0 or self.dropout_rate_edge_feat > 0.0 - traceln('Edge Dropout',edge_dropout, self.dropout_rate_edge,self.dropout_rate_edge_feat) + traceln(' -- Edge Dropout',edge_dropout, self.dropout_rate_edge,self.dropout_rate_edge_feat) if self.num_layers==1: self.H = self.activation(tf.add(tf.matmul(self.node_input, self.Wnl0), self.Bnl0)) self.hidden_layers = [self.H] - traceln("H shape",self.H.get_shape()) + traceln(" -- H shape",self.H.get_shape()) P = self.fastconvolve(self.Wel0,self.Bel0,self.F,self.Ssparse,self.Tsparse,self.H,self.nconv_edge,self.Sshape,self.nb_edge, @@ -898,7 +897,7 @@ def create_model(self): self.Bel0 = tf.Variable(0.01 * tf.ones([self.nconv_edge]), name='Bel0', dtype=tf.float32) - traceln('Wel0', self.Wel0.get_shape()) + traceln(' -- Wel0', self.Wel0.get_shape()) self.train_var.extend([self.Wnl0, self.Bnl0]) self.train_var.append(self.Wel0) @@ -927,8 +926,7 @@ def create_model(self): self.init = tf.global_variables_initializer() self.saver = tf.train.Saver(max_to_keep=0) - traceln('Number of Params:') - self.get_nb_params() + traceln(' -- Number of Params: ', self.get_nb_params()) def create_model_old(self): ''' @@ -996,7 +994,7 @@ def create_model_old(self): #RF self.zel0 = tf.Variable(tf.ones([self.nconv_edge]), name='zel0' , dtype=tf.float32) #RF self.zH = tf.Variable(tf.ones([self.num_layers]),name='zH',dtype=tf.float32) - traceln('Wel0',self.Wel0.get_shape()) + traceln(' -- Wel0',self.Wel0.get_shape()) self.train_var.extend([self.Wnl0,self.Bnl0]) self.train_var.append(self.Wel0) @@ -1027,7 +1025,7 @@ def create_model_old(self): Wnli =tf.Variable(tf.random_uniform( (self.node_indim*self.nconv_edge+self.node_indim, self.node_indim), -1.0 / math.sqrt(self.node_indim), 1.0 / math.sqrt(self.node_indim)),name='Wnl',dtype=tf.float32) - traceln('Wnli shape',Wnli.get_shape()) + traceln(' -- Wnli shape',Wnli.get_shape()) elif self.use_conv_weighted_avg: Wnli = tf.Variable( @@ -1037,7 +1035,7 @@ def create_model_old(self): #Wnli = tf.eye(self.node_dim,dtype=tf.float32) - traceln('Wnli shape', Wnli.get_shape()) + traceln(' -- Wnli shape', Wnli.get_shape()) else: @@ -1121,7 +1119,7 @@ def create_model_old(self): self.ND = tf.diag(self.node_dropout_ind) edge_dropout = self.dropout_rate_edge> 0.0 or self.dropout_rate_edge_feat 
> 0.0 - traceln('Edge Dropout',edge_dropout, self.dropout_rate_edge,self.dropout_rate_edge_feat) + traceln(' -- Edge Dropout',edge_dropout, self.dropout_rate_edge,self.dropout_rate_edge_feat) if self.num_layers==1: self.H = self.activation(tf.add(tf.matmul(self.node_input, self.Wnl0), self.Bnl0)) self.hidden_layers = [self.H] @@ -1180,8 +1178,8 @@ def create_model_old(self): self.Hnode_layers.append(Hi_) - # traceln('Hi_shape',Hi_.get_shape()) - # traceln('Hi prevous shape',self.hidden_layers[-1].get_shape()) + # traceln(' -- Hi_shape',Hi_.get_shape()) + # traceln(' -- Hi prevous shape',self.hidden_layers[-1].get_shape()) P = self.fastconvolve(self.Wed_layers[i],self.Bed_layers[i], self.F, self.Ssparse, self.Tsparse, Hi_, self.nconv_edge,self.Sshape, self.nb_edge, self.dropout_p_edge,self.dropout_p_edge_feat, stack=self.stack_instead_add, use_dropout=edge_dropout, @@ -1236,8 +1234,7 @@ def create_model_old(self): self.init = tf.global_variables_initializer() self.saver= tf.train.Saver(max_to_keep=0) - traceln('Number of Params:') - self.get_nb_params() + traceln(' -- Number of Params: ', self.get_nb_params()) def save_model(self, session, model_filename): @@ -1279,8 +1276,8 @@ def train(self,session,graph,verbose=False,n_iter=1): ''' #TrainEvalSet Here for i in range(n_iter): - #traceln('Train',X.shape,EA.shape) - #traceln('DropoutEdges',self.dropout_rate_edge) + #traceln(' -- Train',X.shape,EA.shape) + #traceln(' -- DropoutEdges',self.dropout_rate_edge) feed_batch = { self.nb_node: graph.X.shape[0], @@ -1300,7 +1297,7 @@ def train(self,session,graph,verbose=False,n_iter=1): Ops =session.run([self.train_step,self.loss], feed_dict=feed_batch) if verbose: - traceln('Training Loss',Ops[1]) + traceln(' -- Training Loss',Ops[1]) @@ -1334,7 +1331,7 @@ def test(self,session,graph,verbose=True): Ops =session.run([self.loss,self.accuracy], feed_dict=feed_batch) if verbose: - traceln('Test Loss',Ops[0],' Test Accuracy:',Ops[1]) + traceln(' -- Test Loss',Ops[0],' Test Accuracy:',Ops[1]) return Ops[1] @@ -1365,7 +1362,7 @@ def predict(self,session,graph,verbose=True): } Ops = session.run([self.pred], feed_dict=feed_batch) if verbose: - traceln('Got Prediction for:',Ops[0].shape) + traceln(' -- Got Prediction for:',Ops[0].shape) print(str(Ops)) return Ops[0] @@ -1396,7 +1393,7 @@ def prediction_prob(self,session,graph,verbose=True): } Ops = session.run([self.predict_proba], feed_dict=feed_batch) if verbose: - traceln('Got Prediction for:',Ops[0].shape) + traceln(' -- Got Prediction for:',Ops[0].shape) return Ops[0] @@ -1437,13 +1434,13 @@ def train_All_lG(self,session,graph_train,graph_val, max_iter, eval_iter = 10, p break if i % eval_iter == 0: - traceln('\nEpoch', i) + traceln('\n -- Epoch', i) _, tr_acc = self.test_lG(session, graph_train, verbose=False) - traceln(' Train Acc', '%.4f' % tr_acc) + traceln(' -- Train Acc', '%.4f' % tr_acc) train_accuracies.append(tr_acc) _, node_acc = self.test_lG(session, graph_val, verbose=False) - traceln(' Valid Acc', '%.4f' % node_acc) + traceln(' -- Valid Acc', '%.4f' % node_acc) validation_accuracies.append(node_acc) if save_model_path: @@ -1451,7 +1448,7 @@ def train_All_lG(self,session,graph_train,graph_val, max_iter, eval_iter = 10, p if graph_test: _, test_acc = self.test_lG(session, graph_test, verbose=False) - traceln(' Test Acc', '%.4f' % test_acc) + traceln(' -- Test Acc', '%.4f' % test_acc) test_accuracies.append(test_acc) @@ -1485,14 +1482,14 @@ def train_All_lG(self,session,graph_train,graph_val, max_iter, eval_iter = 10, p # save_path = 
self.saver.save(session, save_model_path, global_step=i) # TODO Add the final step mean_acc = [] - traceln('Stopped Model Training after', stopped_iter) - traceln('Val Accuracies', validation_accuracies) + traceln(' -- Stopped Model Training after ', stopped_iter) + traceln(' -- Val Accuracies ', validation_accuracies) - traceln('Final Training Accuracy') + traceln(' -- Final Training Accuracy') _, node_train_acc = self.test_lG(session, graph_train) - traceln('Train Mean Accuracy', '%.4f' % node_train_acc) + traceln(' -- Train Mean Accuracy', '%.4f' % node_train_acc) - traceln('Final Valid Acc') + traceln(' -- Final Valid Acc') self.test_lG(session, graph_val) R = {} @@ -1504,7 +1501,7 @@ def train_All_lG(self,session,graph_train,graph_val, max_iter, eval_iter = 10, p # R['W_edge'] =self.get_Wedge(session) if graph_test: _, final_test_acc = self.test_lG(session, graph_test) - traceln('Final Test Acc', '%.4f' % final_test_acc) + traceln(' -- Final Test Acc', '%.4f' % final_test_acc) R['final_test_acc'] = final_test_acc return R @@ -1641,8 +1638,7 @@ def create_model(self): self.train_step = self.optalg.apply_gradients(self.grads_and_vars) - traceln('Number of Parameters') - self.get_nb_params() + traceln(' -- Number of Params: ', self.get_nb_params()) # Add ops to save and restore all the variables. self.init = tf.global_variables_initializer() @@ -1660,7 +1656,7 @@ def train(self,session,g,n_iter=1,verbose=False): } Ops =session.run([self.train_step,self.loss], feed_dict=feed_batch) if verbose: - traceln('Training Loss',Ops[1]) + traceln(' -- Training Loss',Ops[1]) @@ -1675,7 +1671,7 @@ def test(self,session,g,verbose=True): } Ops =session.run([self.loss,self.accuracy], feed_dict=feed_batch) if verbose: - traceln('Test Loss',Ops[0],' Test Accuracy:',Ops[1]) + traceln(' -- Test Loss',Ops[0],' Test Accuracy:',Ops[1]) return Ops[1] @@ -1695,7 +1691,7 @@ def train(self,session,graph,verbose=False,n_iter=1): ''' #TrainEvalSet Here for i in range(n_iter): - #traceln('Train',X.shape,EA.shape) + #traceln(' -- Train',X.shape,EA.shape) nb_edge =graph.E.shape[0] half_edge =nb_edge/2 @@ -1712,7 +1708,7 @@ def train(self,session,graph,verbose=False,n_iter=1): Ops =session.run([self.train_step,self.loss], feed_dict=feed_batch) if verbose: - traceln('Training Loss',Ops[1]) + traceln(' -- Training Loss',Ops[1]) @@ -1741,7 +1737,7 @@ def test(self,session,graph,verbose=True): Ops =session.run([self.loss,self.accuracy], feed_dict=feed_batch) if verbose: - traceln('Test Loss',Ops[0],' Test Accuracy:',Ops[1]) + traceln(' -- Test Loss',Ops[0],' Test Accuracy:',Ops[1]) return Ops[1] @@ -1767,7 +1763,7 @@ def predict(self,session,graph,verbose=True): } Ops = session.run([self.pred], feed_dict=feed_batch) if verbose: - traceln('Got Prediction for:',Ops[0].shape) + traceln(' -- Got Prediction for:',Ops[0].shape) return Ops[0] @@ -1852,7 +1848,7 @@ def set_learning_options(self,dict_model_config): self.stack_instead_add=val if attrname not in self._setter_variables: try: - traceln('set',attrname,val) + traceln(' -- set',attrname,val) setattr(self,attrname,val) except AttributeError: warnings.warn("Ignored options for ECN"+attrname+':'+val) @@ -1899,7 +1895,7 @@ def dense_graph_attention_layer(self,H,W,A,nb_node,dropout_attention,dropout_nod # dropout is done after the softmax if use_dropout: - traceln('... using dropout for attention layer') + traceln(' -- ... 
using dropout for attention layer') alphasD = tf.nn.dropout(alphas, 1.0 - dropout_attention) P_D = tf.nn.dropout(P, 1.0 - dropout_node) alphasP = tf.matmul(alphasD, P_D) @@ -1951,7 +1947,7 @@ def simple_graph_attention_layer(self,H,W,A,S,T,Adjind,Sshape,nb_edge, SD = tf.sparse_reorder(SD) #shape,(nb_edge,nb_node) # This tensor has shape(nb_edge,in_dim) and contains the node source projection, ie Wh SP = tf.sparse_tensor_dense_matmul(tf.sparse_transpose(SD), P,name='SP') #shape(nb_edge,in_dim) - #traceln('SP', SP.get_shape()) + #traceln(' -- SP', SP.get_shape()) #Deprecated @@ -2008,7 +2004,7 @@ def simple_graph_attention_layer(self,H,W,A,S,T,Adjind,Sshape,nb_edge, #dropout is done after the softmax if use_dropout: - traceln('... using dropout for attention layer') + traceln(' -- ... using dropout for attention layer') alphasD = tf.SparseTensor(indices=alphas.indices,values=tf.nn.dropout(alphas.values, 1.0 - dropout_attention),dense_shape=alphas.dense_shape) P_D =tf.nn.dropout(P,1.0-dropout_node) alphasP = tf.sparse_tensor_dense_matmul(alphasD, P_D) @@ -2287,9 +2283,7 @@ def create_model(self): self.init = tf.global_variables_initializer() self.saver= tf.train.Saver(max_to_keep=0) - traceln('Number of Params:') - self.get_nb_params() - + traceln(' -- Number of Params: ', self.get_nb_params()) #TODO Move in MultigraphNN @@ -2312,8 +2306,8 @@ def train(self,session,graph,verbose=False,n_iter=1): ''' #TrainEvalSet Here for i in range(n_iter): - #traceln('Train',X.shape,EA.shape) - #traceln('DropoutEdges',self.dropout_rate_edge) + #traceln(' -- Train',X.shape,EA.shape) + #traceln(' -- DropoutEdges',self.dropout_rate_edge) Aind = np.array(np.stack([graph.Sind[:, 0], graph.Tind[:, 1]], axis=-1), dtype='int64') feed_batch = { @@ -2333,7 +2327,7 @@ def train(self,session,graph,verbose=False,n_iter=1): Ops =session.run([self.train_step,self.loss], feed_dict=feed_batch) if verbose: - traceln('Training Loss',Ops[1]) + traceln(' -- Training Loss',Ops[1]) @@ -2367,7 +2361,7 @@ def test(self,session,graph,verbose=True): Ops =session.run([self.loss,self.accuracy], feed_dict=feed_batch) if verbose: - traceln('Test Loss',Ops[0],' Test Accuracy:',Ops[1]) + traceln(' -- Test Loss',Ops[0],' Test Accuracy:',Ops[1]) return Ops[1] @@ -2397,7 +2391,7 @@ def predict(self,session,graph,verbose=True): } Ops = session.run([self.pred], feed_dict=feed_batch) if verbose: - traceln('Got Prediction for:',Ops[0].shape) + traceln(' -- Got Prediction for:',Ops[0].shape) return Ops[0] def prediction_prob(self, session, graph, verbose=True): @@ -2426,7 +2420,7 @@ def prediction_prob(self, session, graph, verbose=True): } Ops = session.run([self.predict_proba], feed_dict=feed_batch) if verbose: - traceln('Got Prediction for:', Ops[0].shape) + traceln(' -- Got Prediction for:', Ops[0].shape) return Ops[0] #TODO Move that MultiGraphNN @@ -2467,13 +2461,13 @@ def train_All_lG(self,session,graph_train,graph_val, max_iter, eval_iter = 10, p break if i % eval_iter == 0: - traceln('\nEpoch', i) + traceln('\n -- Epoch', i) _, tr_acc = self.test_lG(session, graph_train, verbose=False) - traceln(' Train Acc', '%.4f' % tr_acc) + traceln(' -- Train Acc', '%.4f' % tr_acc) train_accuracies.append(tr_acc) _, node_acc = self.test_lG(session, graph_val, verbose=False) - traceln(' Valid Acc', '%.4f' % node_acc) + traceln(' -- Valid Acc', '%.4f' % node_acc) validation_accuracies.append(node_acc) if save_model_path: @@ -2481,7 +2475,7 @@ def train_All_lG(self,session,graph_train,graph_val, max_iter, eval_iter = 10, p if graph_test: _, test_acc = 
self.test_lG(session, graph_test, verbose=False) - traceln(' Test Acc', '%.4f' % test_acc) + traceln(' -- Test Acc', '%.4f' % test_acc) test_accuracies.append(test_acc) # TODO min_delta @@ -2504,14 +2498,14 @@ def train_All_lG(self,session,graph_train,graph_val, max_iter, eval_iter = 10, p # save_path = self.saver.save(session, save_model_path, global_step=i) # TODO Add the final step mean_acc = [] - traceln('Stopped Model Training after', stopped_iter) - traceln('Val Accuracies', validation_accuracies) + traceln(' -- Stopped Model Training after', stopped_iter) + traceln(' -- Val Accuracies', validation_accuracies) - traceln('Final Training Accuracy') + traceln(' -- Final Training Accuracy') _, node_train_acc = self.test_lG(session, graph_train) - traceln('Train Mean Accuracy', '%.4f' % node_train_acc) + traceln(' -- Train Mean Accuracy', '%.4f' % node_train_acc) - traceln('Final Valid Acc') + traceln(' -- Final Valid Acc') self.test_lG(session, graph_val) R = {} @@ -2523,7 +2517,7 @@ def train_All_lG(self,session,graph_train,graph_val, max_iter, eval_iter = 10, p # R['W_edge'] =self.get_Wedge(session) if graph_test: _, final_test_acc = self.test_lG(session, graph_test) - traceln('Final Test Acc', '%.4f' % final_test_acc) + traceln(' -- Final Test Acc', '%.4f' % final_test_acc) R['final_test_acc'] = final_test_acc return R diff --git a/TranskribusDU/gcn/tests/test_argmax.py b/TranskribusDU/gcn/tests/test_argmax.py new file mode 100644 index 0000000..991f35f --- /dev/null +++ b/TranskribusDU/gcn/tests/test_argmax.py @@ -0,0 +1,14 @@ +import numpy as np + +# ----- TESTS --------------------- +def test_argmax(): + a = np.array([ [0, 1, 2] + , [2, 0, 0] + , [0, 1, 0]]) + v = a.argmax(axis=1) + assert v.tolist() == [2, 0, 1], v + assert (v == np.array([2, 0, 1])).all() + +if __name__ == "__main__": + test_argmax() + \ No newline at end of file diff --git a/TranskribusDU/graph/BaselineModel.py b/TranskribusDU/graph/BaselineModel.py index f12b393..40bf3ee 100644 --- a/TranskribusDU/graph/BaselineModel.py +++ b/TranskribusDU/graph/BaselineModel.py @@ -5,18 +5,7 @@ Copyright NAVER(C) 2018, 2019 Hervé Déjean, Jean-Luc Meunier, Animesh Prasad - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. 
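The new test_argmax.py above pins down the numpy behaviour that DU_Ensemble_ECN.predict depends on: argmax(axis=1) maps an (n_nodes, n_classes) score matrix to one label per node. For illustration, a minimal sketch of the surrounding ensemble logic, assuming average_prediction is a plain per-class mean over the models' probability matrices (the project's actual implementation may differ):

    import numpy as np

    def average_prediction_sketch(m_pred):
        """m_pred: one (n_nodes, n_classes) probability matrix per model.
        Return (y_pred, y_pred_proba): argmax labels and averaged probabilities."""
        y_pred_proba = np.mean(np.stack(m_pred, axis=0), axis=0)  # mean over models
        return y_pred_proba.argmax(axis=1), y_pred_proba

    # two 2-class models voting on two nodes
    y_pred, y_proba = average_prediction_sketch(
        [np.array([[0.2, 0.8], [0.9, 0.1]]),
         np.array([[0.4, 0.6], [0.7, 0.3]])])
    assert y_pred.tolist() == [1, 0]
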
The READ project has received funding @@ -98,7 +87,7 @@ def testFiles(self, lsFilename, loadFun): for sFilename in lsFilename: [g] = loadFun(sFilename) #returns a singleton list - if self.bConjugate: g.computeEdgeLabels() + if g.bConjugate: g.computeEdgeLabels() X, Y = self.transformGraphs([g], True) if lLabelName == None: @@ -115,7 +104,7 @@ def testFiles(self, lsFilename, loadFun): lY_pred.append(Y_pred[0]) #Choose with Y_pred is a list of predictions of feach model - g.detachFromDOM() + g.detachFromDoc() del g #this can be very large del X,X_node diff --git a/TranskribusDU/graph/Block.py b/TranskribusDU/graph/Block.py index 9f354d8..da4df60 100644 --- a/TranskribusDU/graph/Block.py +++ b/TranskribusDU/graph/Block.py @@ -33,7 +33,10 @@ """ class Block: - + + # when creating edges, blocks are aligned on a grid. This is the grid size + iGRID = 2 + def __init__(self, page, tXYWH, text, orientation, cls, nodeType, domnode=None, domid=None): """ pnum is an int @@ -75,6 +78,19 @@ def __init__(self, page, tXYWH, text, orientation, cls, nodeType, domnode=None, def setFontSize(self, fFontSize): self.fontsize = fFontSize + @classmethod + def setGrid(cls, iGrid): + """ + Blocks are aligned on a grid when computing the edges + """ + assert iGrid > 0 + cls.iGRID = iGrid + + def setShape(self,s): + self.shape = s + + def getShape(self): return self.shape + def detachFromDOM(self): """ Erase any pointer to the DOM so that we can free it. @@ -106,7 +122,12 @@ def getCenter(self): def area(self): return(self.x2-self.x1) * (self.y2-self.y1) - + def scale(self, scale_h, scale_v): + dx = (self.x2-self.x1) * (1-scale_h) / 2 + self.x1, self.x2 = self.x1 + dx, self.x2 - dx + dy = (self.y2-self.y1) * (1-scale_v) / 2 + self.y1, self.y2 = self.y1 + dy, self.y2 - dy + def setThickBox(self, f): """make the box border thicker """ self.x1 = self.x1 - f @@ -380,8 +401,8 @@ def findConsecPageOverlapEdges(cls, lPrevPageEdgeBlk, lPageBlk, bMirror=True, ep # # dBlk_by_alpha = collections.defaultdict(list) # x1-y1 --> [ list of block having that x1-y1 ] # for blk in lPageBlk: -# rx1 = cls.epsilonRound(blk.x1, epsilon) -# ry1 = cls.epsilonRound(blk.y1, epsilon) +# rx1 = cls.gridRound(blk.x1, epsilon) +# ry1 = cls.gridRound(blk.y1, epsilon) # #rx1, ry1 = blk.x1, blk.y1 # #OK assert abs(ry1-b.y1) < epsilon # dBlk_by_alpha[rx1 - ry1].append(blk) @@ -417,9 +438,9 @@ def rotateMinus90deg(self): def rotatePlus90deg(self): self.x1, self.y1, self.x2, self.y2 = self.y1, -self.x2, self.y2, -self.x1 - def epsilonRound(cls, f, epsilon): - return int(round(f / epsilon, 0)*epsilon) - epsilonRound = classmethod(epsilonRound) + def gridRound(cls, f, iGrid): + return int(round(f / iGrid, 0)*iGrid) + gridRound = classmethod(gridRound) def XXOverlap(cls, tAx1_Ax2, tBx1_Bx2): #overlap if the max is smaller than the min Ax1, Ax2 = tAx1_Ax2 @@ -430,16 +451,16 @@ def XXOverlap(cls, tAx1_Ax2, tBx1_Bx2): #overlap if the max is smaller than the @classmethod - def _findVerticalNeighborEdges_init(cls, lBlk, epsilon): - assert type(epsilon) is int, repr(epsilon) + def _findVerticalNeighborEdges_init(cls, lBlk, iGrid): + assert type(iGrid) is int, repr(iGrid) #index along the y axis based on y1 and y2 dBlk_Y1 = collections.defaultdict(list) # y1 --> [ list of block having that y1 ] setY2 = set() # set of (unique) y2 for blk in lBlk: - ry1 = cls.epsilonRound(blk.y1, epsilon) - ry2 = cls.epsilonRound(blk.y2, epsilon) - #OK assert abs(ry1-b.y1) < epsilon + ry1 = cls.gridRound(blk.y1, iGrid) + ry2 = cls.gridRound(blk.y2, iGrid) + #OK assert abs(ry1-b.y1) < 
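gridRound, which replaces epsilonRound in the hunks above, snaps a coordinate to the nearest multiple of the grid size (Block.iGRID, 2 by default, adjustable via Block.setGrid), so that block borders differing by less than half a grid step compare equal when edges are computed. A standalone restatement, with example values chosen to avoid float-rounding surprises:

    def grid_round(f, i_grid=2):
        # snap f to the nearest multiple of i_grid
        return int(round(f / i_grid, 0) * i_grid)

    assert grid_round(10.9) == 10   # 10.9 / 2 = 5.45 -> 5 -> 10
    assert grid_round(11.2) == 12   # 11.2 / 2 = 5.6  -> 6 -> 12
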
iGrid dBlk_Y1[ry1].append(blk) setY2.add(ry2) @@ -458,7 +479,7 @@ def _findVerticalNeighborEdges_init(cls, lBlk, epsilon): return n1, lY1, dBlk_Y1, di1_by_y2 @classmethod - def _findVerticalNeighborEdges_g1(cls, lBlk, EdgeClass, bShortOnly=False, epsilon = 2): + def _findVerticalNeighborEdges_g1(cls, lBlk, EdgeClass, bShortOnly=False, iGrid = None): """ any dimension smaller than 5 is zero, we assume that no block are narrower than this value @@ -468,20 +489,21 @@ def _findVerticalNeighborEdges_g1(cls, lBlk, EdgeClass, bShortOnly=False, epsilo """ if not lBlk: return [] + if iGrid is None: iGrid = Block.iGRID #look for vertical neighbors lVEdge = list() - n1, lY1, dBlk_Y1, di1_by_y2 = cls._findVerticalNeighborEdges_init(lBlk, epsilon) + n1, lY1, dBlk_Y1, di1_by_y2 = cls._findVerticalNeighborEdges_init(lBlk, iGrid) - epsilon2 = 2*epsilon - epsilon2 = 0 # back to old version for being able to compare results + epsilon = 2*iGrid + epsilon = 0 # back to old version for being able to compare results for i1,y1 in enumerate(lY1): #start with the block(s) with lowest y1 # (they should not overlap horizontally and cannot be vertical neighbors to each other) for A in dBlk_Y1[y1]: - Ax1,Ay1, Ax2,Ay2 = map(cls.epsilonRound, A.getBB(), [epsilon, epsilon, epsilon, epsilon]) + Ax1,Ay1, Ax2,Ay2 = map(cls.gridRound, A.getBB(), [iGrid, iGrid, iGrid, iGrid]) A_height = A.y2 - A.y1 #why were we accessing the DOM?? float(A.node.prop("height")) assert Ay2 >= Ay1 lOx1x2 = list() #list of observed overlaps for current block A @@ -492,10 +514,10 @@ def _findVerticalNeighborEdges_g1(cls, lBlk, EdgeClass, bShortOnly=False, epsilo for j1 in range(jstart, n1): #take in turn all Y1 below A By1 = lY1[j1] for B in dBlk_Y1[By1]: #all block starting at that y1 - Bx1,By1, Bx2,_ = map(cls.epsilonRound, B.getBB(), [epsilon, epsilon, epsilon, epsilon]) + Bx1,By1, Bx2,_ = map(cls.gridRound, B.getBB(), [iGrid, iGrid, iGrid, iGrid]) #ovABx1, ovABx2 = cls.XXOverlap( (Ax1,Ax2), (Bx1, Bx2) ) ovABx1, ovABx2 = max(Ax1,Bx1), min(Ax2, Bx2) - if ovABx2 - ovABx1 > epsilon2: # significantoverlap + if ovABx2 - ovABx1 > epsilon: # significantoverlap #we now check if that B block is not partially hidden by a previous overlapping block bVisible = True for ovOx1, ovOx2 in lOx1x2: @@ -525,7 +547,7 @@ def _findVerticalNeighborEdges_g1(cls, lBlk, EdgeClass, bShortOnly=False, epsilo return lVEdge @classmethod - def _findVerticalNeighborEdges_g2(cls, lBlk, EdgeClass, bShortOnly=False, epsilon = 2): + def _findVerticalNeighborEdges_g2(cls, lBlk, EdgeClass, bShortOnly=False, iGrid = None): """ the masking is done properly. @@ -534,26 +556,27 @@ def _findVerticalNeighborEdges_g2(cls, lBlk, EdgeClass, bShortOnly=False, epsilo return a list of pair of block """ if not lBlk: return [] + if iGrid is None: iGrid = Block.iGRID + #look for vertical neighbors lVEdge = list() - n1, lY1, dBlk_Y1, di1_by_y2 = cls._findVerticalNeighborEdges_init(lBlk, epsilon) + n1, lY1, dBlk_Y1, _di1_by_y2 = cls._findVerticalNeighborEdges_init(lBlk, iGrid) + # we do not use _di1_by_y2 because we want to include vertically overlapping block in our search for i1,y1 in enumerate(lY1): #start with the block(s) with lowest y1 # (they should not overlap horizontally and cannot be vertical neighbors to each other) for A in dBlk_Y1[y1]: - Ax1,Ay1, Ax2,Ay2 = map(cls.epsilonRound, A.getBB(), [epsilon, epsilon, epsilon, epsilon]) + Ax1,Ay1, Ax2,Ay2 = map(cls.gridRound, A.getBB(), [iGrid, iGrid, iGrid, iGrid]) A_height = A.y2 - A.y1 #why were we accessing the DOM?? 
float(A.node.prop("height"))
                 lViewA = [(Ax1, Ax2)]  # what A can view (or what it covers horizontally)
                 assert Ay2 >= Ay1
-                jstart = di1_by_y2[Ay2]  #index of y1 in lY1 of next block below A (because its y1 is larger than A.y2)
-                jstart = jstart - 1 #because some block overlap each other, we try the previous index (if it is not the current index)
-                jstart = max(jstart, i1+1) # but for sure we want the next group of y1
+                jstart = i1 + 1   # consider all blocks slightly below the current one
 
                 for j1 in range(jstart, n1):
                     #take in turn all Y1 below A
                     By1 = lY1[j1]
                     for B in dBlk_Y1[By1]:  #all block starting at that y1
-                        Bx1,By1, Bx2,_ = map(cls.epsilonRound, B.getBB(), [epsilon, epsilon, epsilon, epsilon])
+                        Bx1,By1, Bx2,_ = map(cls.gridRound, B.getBB(), [iGrid, iGrid, iGrid, iGrid])
 
                         lNewViewA, ovrl = applyMask2(lViewA, [(Bx1, Bx2)])
                         # what remains of A views...
                         if lNewViewA == lViewA:
@@ -606,3 +629,17 @@ def __init__(self, blk):
 
     def getOrigBlock(self):
         return self._blk
+
+
+
+def test_scale():
+    class Page:
+        pnum = 0
+    b = Block(Page(), (1, 10, 100, 1000), "", 0, None, None)
+    ref = b.getBB()
+    b.scale(0.1, 0.1)
+    assert b.getBB() == (46, 460, 56, 560)
+    b.scale(10, 10)
+    assert b.getBB() == ref
+
+    
\ No newline at end of file
diff --git a/TranskribusDU/graph/Cluster.py b/TranskribusDU/graph/Cluster.py
new file mode 100644
index 0000000..0daed96
--- /dev/null
+++ b/TranskribusDU/graph/Cluster.py
@@ -0,0 +1,56 @@
+# coding: utf8
+
+"""
+Cluster class for storing the result of segmentation (by Conjugate graphs for now)
+
+JL Meunier, Nov 2019
+
+
+Copyright Naver 2019
+"""
+
+class ClusterList(list):
+    """
+    A list of Cluster
+    """
+    def __init__(self, lItem=list(), sAlgo="_undefined_"):
+        super(ClusterList, self).__init__(lItem)
+        self.sAlgo = sAlgo
+
+
+class Cluster(set):
+    """
+    A cluster is a set of objects or object ids, with a few specific attributes and methods
+    """
+    def __init__(self, lItem=set(), fProba=1.0):
+        super(Cluster, self).__init__(lItem)
+        self.fProba = fProba
+
+
+def test_ClusterList():
+    cl = ClusterList()
+    assert cl == []
+
+    cl = ClusterList([1, 99, 23], "toto")
+    assert cl == [1, 99, 23]
+
+    assert cl.sAlgo == "toto"
+
+    cl.append(0)
+    assert cl == [1, 99, 23, 0]
+    cl.sort()
+    assert cl == [0, 1, 23, 99]
+
+
+def test_Cluster():
+    c = Cluster()
+    assert c == set([])
+
+    c = Cluster([1, 99, 23])
+    assert c == set([1, 23, 99])
+
+    assert c.fProba == 1.0
+
+    c.add(0)
+    assert c == set([1, 23, 99, 0])
+    
\ No newline at end of file
diff --git a/TranskribusDU/graph/Edge.py b/TranskribusDU/graph/Edge.py
index 0ea0079..dfb3e68 100644
--- a/TranskribusDU/graph/Edge.py
+++ b/TranskribusDU/graph/Edge.py
@@ -39,7 +39,27 @@ def revertDirection(self):
         revert the direction of the edge
         """
         self.A, self.B = self.B, self.A
-        
+
+
+    def computeOverlap(self):
+        """
+        compute the overlap between the two nodes
+        return 0 or a positive number in case of overlap
+        """
+        return 0
+
+    def computeOverlapPosition(self):
+        """
+        compute the overlap between the two nodes and its position
+        relative to each node.
+        The overlap is a length
+        The position is a number in [-1, +1] relative to the center of a node.
+        -1 denotes left or top, +1 denotes right or bottom.
+        The position is the position of the center of overlap.
+        return a tuple: (overlap, pos_on_A, pos_on_B)
+        """
+        return 0, 0, 0
+
     # ------------------------------------------------------------------------------------------------------------------------------------
     #specific code for the CRF graph
     def computeEdges(cls, lPrevPageEdgeBlk, lPageBlk, iGraphMode, bShortOnly=False):
@@ -117,6 +137,35 @@ def __init__(self, A, B, length, overlap):
         except ZeroDivisionError:
             self.iou = 0
 
+    def computeOverlap(self):
+        """
+        compute the vertical overlap between the two nodes
+        return 0 or a positive number in case of overlap
+        """
+        return max(0, min(self.A.y2, self.B.y2) - max(self.A.y1, self.B.y1))
+
+    def computeOverlapPosition(self):
+        """
+        compute the vertical overlap between the two nodes and its position
+        relative to each node.
+        The overlap is a length
+        The position is a number in [-1, +1] relative to the center of a node.
+        -1 denotes left or top, +1 denotes right or bottom.
+        The position is the position of the center of overlap.
+        return a tuple: (overlap, pos_on_A, pos_on_B)
+        """
+        y1 = max(self.A.y1, self.B.y1)
+        y2 = min(self.A.y2, self.B.y2)
+        ovrl = max(0, y2 - y1)
+        if ovrl > 0:
+            m = (y1 + y2) / 2.0
+            pA = (m + m - self.A.y1 - self.A.y2) / abs(self.A.y2 - self.A.y1)
+            pB = (m + m - self.B.y1 - self.B.y2) / abs(self.B.y2 - self.B.y1)
+            return (ovrl, pA, pB)
+        else:
+            return 0, 0, 0
+
+
 class VerticalEdge(SamePageEdge):
     def __init__(self, A, B, length, overlap):
         SamePageEdge.__init__(self, A, B, length, overlap)
@@ -125,4 +174,32 @@ def __init__(self, A, B, length, overlap):
         except ZeroDivisionError:
             self.iou = 0
 
+    def computeOverlap(self):
+        """
+        compute the horizontal overlap between the two nodes
+        return 0 or a positive number in case of overlap
+        """
+        return max(0, min(self.A.x2, self.B.x2) - max(self.A.x1, self.B.x1))
+
+    def computeOverlapPosition(self):
+        """
+        compute the horizontal overlap between the two nodes and its position
+        relative to each node.
+        The overlap is a length
+        The position is a number in [-1, +1] relative to the center of a node.
+        -1 denotes left or top, +1 denotes right or bottom.
+        The position is the position of the center of overlap.
+        return a tuple: (overlap, pos_on_A, pos_on_B)
+        """
+        x1 = max(self.A.x1, self.B.x1)
+        x2 = min(self.A.x2, self.B.x2)
+        ovrl = max(0, x2 - x1)
+        if ovrl > 0:
+            m = (x1 + x2) / 2.0
+            pA = (m + m - self.A.x1 - self.A.x2) / abs(self.A.x2 - self.A.x1)
+            pB = (m + m - self.B.x1 - self.B.x2) / abs(self.B.x2 - self.B.x1)
+            return (ovrl, pA, pB)
+        else:
+            return 0, 0, 0
+
diff --git a/TranskribusDU/graph/FeatureDefinition.py b/TranskribusDU/graph/FeatureDefinition.py
index ebadd30..1a2659b 100644
--- a/TranskribusDU/graph/FeatureDefinition.py
+++ b/TranskribusDU/graph/FeatureDefinition.py
@@ -6,18 +6,7 @@
 
     Copyright Xerox(C) 2016 JL. Meunier
 
-    This program is free software: you can redistribute it and/or modify
-    it under the terms of the GNU General Public License as published by
-    the Free Software Foundation, either version 3 of the License, or
-    (at your option) any later version.
-
-    This program is distributed in the hope that it will be useful,
-    but WITHOUT ANY WARRANTY; without even the implied warranty of
-    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-    GNU General Public License for more details.
-
-    You should have received a copy of the GNU General Public License
-    along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
     Developed for the EU project READ.
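A worked check of computeOverlapPosition above, reduced to one dimension: the overlap is the length of the shared interval, and each position is twice the offset of the overlap's centre from a node's own centre, divided by that node's extent, so it falls in [-1, +1]:

    def overlap_position_1d(a1, a2, b1, b2):
        """1-D analogue of computeOverlapPosition for intervals [a1,a2], [b1,b2]."""
        lo, hi = max(a1, b1), min(a2, b2)
        ovrl = max(0, hi - lo)
        if ovrl == 0:
            return 0, 0, 0
        m = (lo + hi) / 2.0                       # centre of the overlap
        pA = (m + m - a1 - a2) / abs(a2 - a1)     # = 2*(m - centre_A) / extent_A
        pB = (m + m - b1 - b2) / abs(b2 - b1)
        return ovrl, pA, pB

    # B overlaps the bottom half of A, and A overlaps the top half of B:
    assert overlap_position_1d(0, 10, 5, 15) == (5, 0.5, -0.5)
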
The READ project has received funding
diff --git a/TranskribusDU/graph/FeatureDefinition_Generic.py b/TranskribusDU/graph/FeatureDefinition_Generic.py
new file mode 100644
index 0000000..9ff40f2
--- /dev/null
+++ b/TranskribusDU/graph/FeatureDefinition_Generic.py
@@ -0,0 +1,124 @@
+# -*- coding: utf-8 -*-
+
+"""
+    Standard PageXml features:
+    - not using the page information
+    - using a QuantileTransformer for numerical features instead of a StandardScaler
+
+    No link with DOM or JSON => named GENERIC
+
+    Copyright Xerox(C) 2016, 2019 JL. Meunier
+
+
+
+
+    Developed for the EU project READ. The READ project has received funding
+    from the European Union's Horizon 2020 research and innovation programme
+    under grant agreement No 674943.
+
+"""
+import numpy as np
+
+from sklearn.pipeline import Pipeline, FeatureUnion
+from sklearn.feature_extraction.text import TfidfVectorizer
+
+#not robust to empty arrays, so use our robust intermediary class instead
+#from sklearn.preprocessing import StandardScaler
+from .Transformer import SparseToDense
+from .FeatureDefinition import FeatureDefinition
+from .Transformer import EmptySafe_QuantileTransformer as QuantileTransformer
+
+from .Transformer_Generic import NodeTransformerTextEnclosed
+from .Transformer_Generic import NodeTransformerTextLen
+from .Transformer_Generic import NodeTransformerXYWH
+from .Transformer_Generic import NodeTransformerNeighbors
+from .Transformer_Generic import EdgeBooleanAlignmentFeatures
+from .Transformer_Generic import EdgeNumericalSelector_noText
+
+
+class FeatureDefinition_Generic(FeatureDefinition):
+
+    n_QUANTILES = 64
+
+    def __init__(self
+                 , n_tfidf_node=None, t_ngrams_node=None, b_tfidf_node_lc=None
+                 #, n_tfidf_edge=None, t_ngrams_edge=None, b_tfidf_edge_lc=None
+                 ):
+        FeatureDefinition.__init__(self)
+
+        self.n_tfidf_node, self.t_ngrams_node, self.b_tfidf_node_lc = n_tfidf_node, t_ngrams_node, b_tfidf_node_lc
+        # self.n_tfidf_edge, self.t_ngrams_edge, self.b_tfidf_edge_lc = n_tfidf_edge, t_ngrams_edge, b_tfidf_edge_lc
+
+        tdifNodeTextVectorizer = TfidfVectorizer(lowercase=self.b_tfidf_node_lc
+                                                 , max_features=self.n_tfidf_node
+                                                 , analyzer = 'char'
+                                                 , ngram_range=self.t_ngrams_node #(2,6)
+                                                 , dtype=np.float64)
+
+        node_transformer = FeatureUnion( [  #CAREFUL IF YOU CHANGE THIS - see cleanTransformers method!!!!
+                                    ("text", Pipeline([
+                                                 ('selector', NodeTransformerTextEnclosed()),
+#                                                 ('tfidf', TfidfVectorizer(lowercase=self.b_tfidf_node_lc, max_features=self.n_tfidf_node
+#                                                                  , analyzer = 'char', ngram_range=self.tNODE_NGRAMS #(2,6)
+#                                                                  , dtype=np.float64)),
+                                                 ('tfidf', tdifNodeTextVectorizer), #we can use it separately from the pipeline once fitted
+                                                 ('todense', SparseToDense()) #pystruct needs an array, not a sparse matrix
+                                                 ])
+                                     )
+                                    , ("textlen", Pipeline([
+                                                 ('selector', NodeTransformerTextLen()),
+                                                 ('textlen', QuantileTransformer(n_quantiles=self.n_QUANTILES, copy=False)) #use in-place scaling
+                                                 ])
+                                     )
+                                    , ("xywh", Pipeline([
+                                                 ('selector', NodeTransformerXYWH()),
+                                                 #v1 ('xywh', StandardScaler(copy=False, with_mean=True, with_std=True)) #use in-place scaling
+                                                 ('xywh', QuantileTransformer(n_quantiles=self.n_QUANTILES, copy=False)) #use in-place scaling
+                                                 ])
+                                     )
+                                    , ("neighbors", Pipeline([
+                                                 ('selector', NodeTransformerNeighbors()),
+                                                 #v1 ('neighbors', StandardScaler(copy=False, with_mean=True, with_std=True)) #use in-place scaling
+                                                 ('neighbors', QuantileTransformer(n_quantiles=self.n_QUANTILES, copy=False)) #use in-place scaling
+                                                 ])
+                                     )
+                                    ])
+
+        lEdgeFeature = [  #CAREFUL IF YOU CHANGE THIS - see cleanTransformers method!!!!
+                          ("boolean", Pipeline([
+                                         ('boolean', EdgeBooleanAlignmentFeatures())
+                                         ])
+                          )
+                        , ("numerical", Pipeline([
+                                         ('selector', EdgeNumericalSelector_noText()),
+                                         #v1 ('numerical', StandardScaler(copy=False, with_mean=True, with_std=True)) #use in-place scaling
+                                         ('numerical', QuantileTransformer(n_quantiles=self.n_QUANTILES, copy=False)) #use in-place scaling
+                                         ])
+                          )
+                        ]
+
+        edge_transformer = FeatureUnion( lEdgeFeature )
+
+        #return _node_transformer, _edge_transformer, tdifNodeTextVectorizer
+        self._node_transformer = node_transformer
+        self._edge_transformer = edge_transformer
+        self.tfidfNodeTextVectorizer = None #tdifNodeTextVectorizer
+
+    def cleanTransformers(self):
+        """
+        the TFIDF transformers are keeping the stop words => huge pickled file!!!
+
+        Here the fix is a bit rough. There are better ways....
+        JL
+        """
+        self._node_transformer.transformer_list[0][1].steps[1][1].stop_words_ = None   #is 1st in the union...
+
+#         if self.bMirrorPage:
+#             imax = 9
+#         else:
+#             imax = 7
+#         for i in range(3, imax):
+#             self._edge_transformer.transformer_list[i][1].steps[1][1].stop_words_ = None   #are 3rd and 4th in the union....
+        return self._node_transformer, self._edge_transformer
+
+
diff --git a/TranskribusDU/graph/FeatureDefinition_Generic_noText.py b/TranskribusDU/graph/FeatureDefinition_Generic_noText.py
new file mode 100644
index 0000000..12798c1
--- /dev/null
+++ b/TranskribusDU/graph/FeatureDefinition_Generic_noText.py
@@ -0,0 +1,87 @@
+# -*- coding: utf-8 -*-
+
+"""
+    Standard PageXml features:
+    - not using the page information
+    - using a QuantileTransformer for numerical features instead of a StandardScaler
+
+    No link with DOM or JSON => named GENERIC
+
+    Copyright Xerox(C) 2016, 2019 JL. Meunier
+
+
+
+
+    Developed for the EU project READ. The READ project has received funding
+    from the European Union's Horizon 2020 research and innovation programme
+    under grant agreement No 674943.
+ +""" + +from sklearn.pipeline import Pipeline, FeatureUnion + +from .FeatureDefinition import FeatureDefinition +from .Transformer import EmptySafe_QuantileTransformer as QuantileTransformer + +from .Transformer_Generic import NodeTransformerXYWH +from .Transformer_Generic import NodeTransformerNeighbors +from .Transformer_Generic import EdgeBooleanAlignmentFeatures +from .Transformer_Generic import EdgeNumericalSelector_noText + + +class FeatureDefinition_Generic_noText(FeatureDefinition): + + n_QUANTILES = 16 + + def __init__(self): + FeatureDefinition.__init__(self) + + node_transformer = FeatureUnion( [ #CAREFUL IF YOU CHANGE THIS - see cleanTransformers method!!!! + ("xywh", Pipeline([ + ('selector', NodeTransformerXYWH()), + #v1 ('xywh', StandardScaler(copy=False, with_mean=True, with_std=True)) #use in-place scaling + ('xywh', QuantileTransformer(n_quantiles=self.n_QUANTILES, copy=False)) #use in-place scaling + ]) + ) + , ("neighbors", Pipeline([ + ('selector', NodeTransformerNeighbors()), + #v1 ('neighbors', StandardScaler(copy=False, with_mean=True, with_std=True)) #use in-place scaling + ('neighbors', QuantileTransformer(n_quantiles=self.n_QUANTILES, copy=False)) #use in-place scaling + ]) + ) + ]) + + lEdgeFeature = [ #CAREFUL IF YOU CHANGE THIS - see cleanTransformers method!!!! + ("boolean", Pipeline([ + ('boolean', EdgeBooleanAlignmentFeatures()) + ]) + ) + , ("numerical", Pipeline([ + ('selector', EdgeNumericalSelector_noText()), + #v1 ('numerical', StandardScaler(copy=False, with_mean=True, with_std=True)) #use in-place scaling + ('numerical', QuantileTransformer(n_quantiles=self.n_QUANTILES, copy=False)) #use in-place scaling + ]) + ) + ] + + edge_transformer = FeatureUnion( lEdgeFeature ) + + #return _node_transformer, _edge_transformer, tdifNodeTextVectorizer + self._node_transformer = node_transformer + self._edge_transformer = edge_transformer + self.tfidfNodeTextVectorizer = None #tdifNodeTextVectorizer + + +# def cleanTransformers(self): +# """ +# the TFIDF transformers are keeping the stop words => huge pickled file!!! +# +# Here the fix is a bit rough. There are better ways.... +# JL +# """ +# self._node_transformer.transformer_list[0][1].steps[1][1].stop_words_ = None #is 1st in the union... +# for i in [2, 3, 4, 5, 6, 7]: +# self._edge_transformer.transformer_list[i][1].steps[1][1].stop_words_ = None #are 3rd and 4th in the union.... +# return self._node_transformer, self._edge_transformer + + diff --git a/TranskribusDU/graph/FeatureDefinition_PageXml_FeatSelect.py b/TranskribusDU/graph/FeatureDefinition_PageXml_FeatSelect.py index d38d942..dd38118 100644 --- a/TranskribusDU/graph/FeatureDefinition_PageXml_FeatSelect.py +++ b/TranskribusDU/graph/FeatureDefinition_PageXml_FeatSelect.py @@ -5,18 +5,7 @@ Copyright Xerox(C) 2016 JL. Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. 
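Both generic feature definitions above assemble the same scikit-learn machinery: a FeatureUnion of Pipelines, each selecting raw values from the node (or edge) objects and rescaling them with the empty-safe QuantileTransformer. A toy end-to-end sketch, with a hypothetical selector and node class standing in for the project's Transformer_Generic classes:

    import numpy as np
    from sklearn.base import BaseEstimator, TransformerMixin
    from sklearn.pipeline import Pipeline, FeatureUnion
    from sklearn.preprocessing import QuantileTransformer

    class ToyXYWHSelector(BaseEstimator, TransformerMixin):
        """Stand-in for NodeTransformerXYWH: node objects -> (n_nodes, 4) array."""
        def fit(self, l_node, y=None):
            return self
        def transform(self, l_node):
            return np.array([[n.x1, n.y1, n.x2, n.y2] for n in l_node], dtype=np.float64)

    class ToyNode:
        def __init__(self, x1, y1, x2, y2):
            self.x1, self.y1, self.x2, self.y2 = x1, y1, x2, y2

    node_transformer = FeatureUnion([
        ("xywh", Pipeline([
            ("selector", ToyXYWHSelector()),
            ("scale", QuantileTransformer(n_quantiles=4, copy=False)),  # rank-based scaling to [0,1]
        ])),
    ])

    l_node = [ToyNode(0, 0, 10, 5), ToyNode(5, 2, 20, 9),
              ToyNode(8, 1, 30, 4), ToyNode(2, 3, 15, 7)]
    X = node_transformer.fit_transform(l_node)   # one scaled feature row per node
    assert X.shape == (4, 4)
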
The READ project has received funding diff --git a/TranskribusDU/graph/FeatureDefinition_PageXml_NoNodeFeat_v3.py b/TranskribusDU/graph/FeatureDefinition_PageXml_NoNodeFeat_v3.py index 65f250f..0673ecb 100644 --- a/TranskribusDU/graph/FeatureDefinition_PageXml_NoNodeFeat_v3.py +++ b/TranskribusDU/graph/FeatureDefinition_PageXml_NoNodeFeat_v3.py @@ -6,18 +6,7 @@ Copyright Xerox(C) 2017 JL. Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. The READ project has received funding diff --git a/TranskribusDU/graph/FeatureDefinition_PageXml_logit.py b/TranskribusDU/graph/FeatureDefinition_PageXml_logit.py index 6ce31d8..0094024 100644 --- a/TranskribusDU/graph/FeatureDefinition_PageXml_logit.py +++ b/TranskribusDU/graph/FeatureDefinition_PageXml_logit.py @@ -1,7 +1,8 @@ # -*- coding: utf-8 -*- """ - Logit-based PageXml features + Logit-based PageXml features, but using a QuantileTransformer for numerical features instead of a StandardScaler + After discussion with Stéphane Clinchant and Hervé Déjean, we will use the score of several logit multiclass classifiers instead of selecting ngrams. @@ -16,18 +17,7 @@ Copyright Xerox(C) 2017 JL. Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - You should have received a copy of the GNU General Public License - along with this program. If not, see . Developed for the EU project READ. 
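The idea behind the logit-based features described above, restated as a runnable sketch with plain scikit-learn parts: fit a multiclass logistic classifier on character n-grams of the node text, then use its per-class scores as a dense, nbClass-wide feature block. The real NodeTransformerLogit additionally handles graph-level fitting and the edge-side features; labels and texts here are illustrative:

    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.linear_model import LogisticRegression
    from sklearn.pipeline import make_pipeline

    texts  = ["page 3", "Chapter I", "page 4", "Introduction"]
    labels = [0, 1, 0, 1]   # e.g. 0=page-number, 1=heading

    clf = make_pipeline(
        TfidfVectorizer(analyzer="char", ngram_range=(2, 3)),
        LogisticRegression())
    clf.fit(texts, labels)

    # The per-class probabilities become the node features:
    node_feat = clf.predict_proba(["page 5", "Chapter II"])
    assert node_feat.shape == (2, 2)   # (n_nodes, nbClass)
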
The READ project has received funding @@ -41,17 +31,16 @@ from sklearn.pipeline import Pipeline, FeatureUnion -#not robust to empty arrays, so use our robust intermediary class instead -#from sklearn.preprocessing import StandardScaler -from .Transformer import RobustStandardScaler as StandardScaler -from .Transformer_PageXml import NodeTransformerTextLen, NodeTransformerXYWH, NodeTransformerNeighbors, Node1HotFeatures -from .Transformer_PageXml import Edge1HotFeatures, EdgeBooleanFeatures, EdgeNumericalSelector +from .Transformer import EmptySafe_QuantileTransformer as QuantileTransformer +from .Transformer_PageXml import NodeTransformerTextLen, NodeTransformerXYWH_v2, NodeTransformerNeighbors, Node1HotFeatures +from .Transformer_PageXml import Edge1HotFeatures, EdgeBooleanFeatures_v2, EdgeNumericalSelector from .PageNumberSimpleSequenciality import PageNumberSimpleSequenciality from .FeatureDefinition import FeatureDefinition from .Transformer_Logit import NodeTransformerLogit, EdgeTransformerLogit -class FeatureDefinition_PageXml_LogitExtractor(FeatureDefinition): +class FeatureDefinition_PageXml_LogitExtractor_v3(FeatureDefinition): + n_QUANTILES = 16 """ We will fit a logistic classifier @@ -60,8 +49,8 @@ def __init__(self, nbClass=None , n_feat_node=None, t_ngrams_node=None, b_node_lc=None , n_feat_edge=None, t_ngrams_edge=None, b_edge_lc=None , n_jobs=1): - FeatureDefinition.__init__(self, nbClass) - assert nbClass, "Error: indicate the numbe of classes" + FeatureDefinition.__init__(self) + assert nbClass, "Error: indicate the number of classes" self.nbClass = nbClass self.n_feat_node, self.t_ngrams_node, self.b_node_lc = n_feat_node, t_ngrams_node, b_node_lc self.n_feat_edge, self.t_ngrams_edge, self.b_edge_lc = n_feat_edge, t_ngrams_edge, b_edge_lc @@ -70,27 +59,34 @@ def __init__(self, nbClass=None # , analyzer = 'char', ngram_range=self.t_ngrams_node #(2,6) # , dtype=np.float64) """ - I tried to parallelize this code but I'm getting an error on Windows: - - File "c:\Local\meunier\git\TranskribusDU\src\crf\FeatureDefinition_PageXml_logit.py", line 144, in fitTranformers + - loading pre-computed data from: CV_5/model_A_fold_1_transf.pkl + no such file : CV_5/model_A_fold_1_transf.pkl +Traceback (most recent call last): + File "/opt/project/read/jl_git/TranskribusDU/src/tasks/DU_GTBooks_5labels.py", line 216, in + oReport = doer._nfold_RunFoldFromDisk(options.iFoldRunNum, options.warm) + File "/opt/project/read/jl_git/TranskribusDU/src/tasks/DU_CRF_Task.py", line 481, in _nfold_RunFoldFromDisk + oReport = self._nfold_RunFold(iFold, ts_trn, lFilename_trn, train_index, test_index, bWarm=bWarm) + File "/opt/project/read/jl_git/TranskribusDU/src/tasks/DU_CRF_Task.py", line 565, in _nfold_RunFold + fe.fitTranformers(lGraph_trn) + File "/opt/project/read/jl_git/TranskribusDU/src/crf/FeatureDefinition_PageXml_logit_v2.py", line 141, in fitTranformers self._node_transformer.fit(lAllNode) - File "C:\Anaconda2\lib\site-packages\sklearn\pipeline.py", line 709, in fit + File "/opt/project/read/VIRTUALENV_PYTHON_FULL_type/lib/python2.7/site-packages/sklearn/pipeline.py", line 712, in fit for _, trans, _ in self._iter()) - File "C:\Anaconda2\lib\site-packages\sklearn\externals\joblib\parallel.py", line 768, in __call__ + File "/opt/project/read/VIRTUALENV_PYTHON_FULL_type/lib/python2.7/site-packages/sklearn/externals/joblib/parallel.py", line 768, in __call__ self.retrieve() - File "C:\Anaconda2\lib\site-packages\sklearn\externals\joblib\parallel.py", line 719, in retrieve + File 
"/opt/project/read/VIRTUALENV_PYTHON_FULL_type/lib/python2.7/site-packages/sklearn/externals/joblib/parallel.py", line 719, in retrieve raise exception -TypeError: can't pickle PyCapsule objects - -(virtual_python_pystruct) (C:\Anaconda2) c:\tmp_READ\tuto>python -c "import sklearn; print sklearn.__version__" -0.18.1 - => I force n_jobs to 1 - +RuntimeError: maximum recursion depth exceeded +""" """ - n_jobs = 1 - + I guess this is due to the cyclic links to node's neighbours. + But joblib.Parallel uses cPickle, so we cannot specialize the serialization of the Block objects. - n_jobs_NodeTransformerLogit = max(1, n_jobs/2) #half of the jobs for the NodeTransformerLogit, the rets for the others + JLM April 2017 + """ + n_jobs_from_graph = 1 #we cannot pickl the list of graph, so n_jobs = 1 for this part! +# n_jobs_NodeTransformerLogit = max(1, n_jobs/2) #half of the jobs for the NodeTransformerLogit, the rets for the others + n_jobs_NodeTransformerLogit = max(1, n_jobs - 1) #we keep a ref onto it because its fitting needs not only all the nodes, but also additional info, available on the graph objects self._node_transf_logit = NodeTransformerLogit(nbClass, self.n_feat_node, self.t_ngrams_node, self.b_node_lc, n_jobs=n_jobs_NodeTransformerLogit) @@ -100,17 +96,20 @@ def __init__(self, nbClass=None , ("textlen", Pipeline([ ('selector', NodeTransformerTextLen()), - ('textlen', StandardScaler(copy=False, with_mean=True, with_std=True)) #use in-place scaling + #v2 ('textlen', StandardScaler(copy=False, with_mean=True, with_std=True)) #use in-place scaling + ('textlen', QuantileTransformer(n_quantiles=self.n_QUANTILES, copy=False)) #use in-place scaling ]) ) , ("xywh", Pipeline([ - ('selector', NodeTransformerXYWH()), - ('xywh', StandardScaler(copy=False, with_mean=True, with_std=True)) #use in-place scaling + ('selector', NodeTransformerXYWH_v2()), + #v2 ('xywh', StandardScaler(copy=False, with_mean=True, with_std=True)) #use in-place scaling + ('xywh', QuantileTransformer(n_quantiles=self.n_QUANTILES, copy=False)) #use in-place scaling ]) ) , ("neighbors", Pipeline([ ('selector', NodeTransformerNeighbors()), - ('neighbors', StandardScaler(copy=False, with_mean=True, with_std=True)) #use in-place scaling + #v2 ('neighbors', StandardScaler(copy=False, with_mean=True, with_std=True)) #use in-place scaling + ('neighbors', QuantileTransformer(n_quantiles=self.n_QUANTILES, copy=False)) #use in-place scaling ]) ) , ("1hot", Pipeline([ @@ -130,7 +129,7 @@ def __init__(self, nbClass=None # #THIS ONE MUST BE LAST, because it include a placeholder column for the doculent-level tfidf # ]) # ) - ], n_jobs=max(1, n_jobs - n_jobs_NodeTransformerLogit)) + ], n_jobs=n_jobs_from_graph) lEdgeFeature = [ #CAREFUL IF YOU CHANGE THIS - see cleanTransformers method!!!! 
("1hot", Pipeline([ @@ -138,26 +137,24 @@ def __init__(self, nbClass=None ]) ) , ("boolean", Pipeline([ - ('boolean', EdgeBooleanFeatures()) + ('boolean', EdgeBooleanFeatures_v2()) ]) ) , ("numerical", Pipeline([ ('selector', EdgeNumericalSelector()), - ('numerical', StandardScaler(copy=False, with_mean=True, with_std=True)) #use in-place scaling + #v2 ('numerical', StandardScaler(copy=False, with_mean=True, with_std=True)) #use in-place scaling + ('numerical', QuantileTransformer(n_quantiles=self.n_QUANTILES, copy=False)) #use in-place scaling ]) ) , ("nodetext", EdgeTransformerLogit(nbClass, self._node_transf_logit)) ] - edge_transformer = FeatureUnion( lEdgeFeature, n_jobs=n_jobs ) + edge_transformer = FeatureUnion( lEdgeFeature, n_jobs=n_jobs_from_graph ) #return _node_transformer, _edge_transformer, tdifNodeTextVectorizer self._node_transformer = node_transformer self._edge_transformer = edge_transformer -# #dirty trick to enable testing the logit models -# self._node_transformer._testable_extractor_ = self._node_transf_logit - def fitTranformers(self, lGraph): """ Fit the transformers using the graphs @@ -194,3 +191,4 @@ def cleanTransformers(self): # self._edge_transformer.transformer_list[i][1].steps[1][1].stop_words_ = None #are 3rd and 4th in the union.... return self._node_transformer, self._edge_transformer + diff --git a/TranskribusDU/graph/FeatureDefinition_PageXml_std.py b/TranskribusDU/graph/FeatureDefinition_PageXml_std.py index 491972a..7de5c9e 100644 --- a/TranskribusDU/graph/FeatureDefinition_PageXml_std.py +++ b/TranskribusDU/graph/FeatureDefinition_PageXml_std.py @@ -5,18 +5,7 @@ Copyright Xerox(C) 2016 JL. Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . Developed for the EU project READ. 
The READ project has received funding @@ -37,8 +26,8 @@ #from sklearn.preprocessing import StandardScaler from .Transformer import EmptySafe_QuantileTransformer as QuantileTransformer from .Transformer import SparseToDense -from .Transformer_PageXml import NodeTransformerXYWH_v2, NodeTransformerNeighbors, Node1HotFeatures -from .Transformer_PageXml import Edge1HotFeatures, EdgeBooleanFeatures_v2, EdgeNumericalSelector +from .Transformer_PageXml import NodeTransformerXYWH, NodeTransformerNeighbors, Node1HotFeatures +from .Transformer_PageXml import Edge1HotFeatures, EdgeBooleanFeatures, EdgeNumericalSelector from .Transformer_PageXml import NodeTransformerTextEnclosed, NodeTransformerTextLen from .Transformer_PageXml import EdgeTransformerSourceText, EdgeTransformerTargetText from .PageNumberSimpleSequenciality import PageNumberSimpleSequenciality @@ -79,7 +68,7 @@ def __init__(self, n_tfidf_node=None, t_ngrams_node=None, b_tfidf_node_lc=None ]) ) , ("xywh", Pipeline([ - ('selector', NodeTransformerXYWH_v2()), + ('selector', NodeTransformerXYWH()), #v1 ('xywh', StandardScaler(copy=False, with_mean=True, with_std=True)) #use in-place scaling ('xywh', QuantileTransformer(n_quantiles=self.n_QUANTILES, copy=False)) #use in-place scaling ]) @@ -119,7 +108,7 @@ def __init__(self, n_tfidf_node=None, t_ngrams_node=None, b_tfidf_node_lc=None ]) ) , ("boolean", Pipeline([ - ('boolean', EdgeBooleanFeatures_v2()) + ('boolean', EdgeBooleanFeatures()) ]) ) , ("numerical", Pipeline([ diff --git a/TranskribusDU/graph/FeatureDefinition_PageXml_std_noText.py b/TranskribusDU/graph/FeatureDefinition_PageXml_std_noText.py index 42415d8..d4a582e 100644 --- a/TranskribusDU/graph/FeatureDefinition_PageXml_std_noText.py +++ b/TranskribusDU/graph/FeatureDefinition_PageXml_std_noText.py @@ -1,22 +1,11 @@ # -*- coding: utf-8 -*- """ - Standard PageXml features + Standard PageXml features, but using a QuantileTransformer for numerical features instead of a StandardScaler Copyright Xerox(C) 2016 JL. Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . Developed for the EU project READ. 
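EdgeBooleanAlignmentFeatures / EdgeBooleanFeatures, swapped in above, encode whether the two blocks of an edge are aligned. As an illustration only, one plausible set of such booleans, under the assumption that alignment means matching box borders or centres within a small tolerance:

    import numpy as np
    from types import SimpleNamespace

    def edge_boolean_alignment(A, B, tol=2):
        """Sketch: booleans describing how blocks A and B align (assumed semantics)."""
        return np.array([
            abs(A.x1 - B.x1) <= tol,                         # left-aligned
            abs(A.x2 - B.x2) <= tol,                         # right-aligned
            abs((A.x1 + A.x2) - (B.x1 + B.x2)) <= 2 * tol,   # horizontally centre-aligned
            abs(A.y1 - B.y1) <= tol,                         # top-aligned
            abs(A.y2 - B.y2) <= tol,                         # bottom-aligned
            abs((A.y1 + A.y2) - (B.y1 + B.y2)) <= 2 * tol,   # vertically centre-aligned
        ], dtype=np.float64)

    A = SimpleNamespace(x1=10, y1=0, x2=60, y2=20)
    B = SimpleNamespace(x1=11, y1=30, x2=59, y2=50)   # same column, lower block
    print(edge_boolean_alignment(A, B))               # [1. 1. 1. 0. 0. 0.]
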
The READ project has received funding @@ -32,16 +21,17 @@ #not robust to empty arrays, so use our robust intermediary class instead #from sklearn.preprocessing import StandardScaler -from .Transformer import RobustStandardScaler as StandardScaler +from .Transformer import EmptySafe_QuantileTransformer as QuantileTransformer from .Transformer import TransformerListByType -from .Transformer_PageXml import NodeTransformerXYWH_v2, NodeTransformerNeighbors, Node1HotFeatures -from .Transformer_PageXml import Edge1HotFeatures, EdgeBooleanFeatures_v2, EdgeNumericalSelector -from .PageNumberSimpleSequenciality import PageNumberSimpleSequenciality +from .Transformer_PageXml import NodeTransformerXYWH, NodeTransformerNeighbors, Node1HotFeatures_noText +from .Transformer_PageXml import EdgeBooleanFeatures, EdgeNumericalSelector_noText from .FeatureDefinition import FeatureDefinition class FeatureDefinition_PageXml_StandardOnes_noText(FeatureDefinition): + n_QUANTILES = 16 + def __init__(self): FeatureDefinition.__init__(self) @@ -54,33 +44,32 @@ def __init__(self): node_transformer = FeatureUnion( [ #CAREFUL IF YOU CHANGE THIS - see cleanTransformers method!!!! ("xywh", Pipeline([ - ('selector', NodeTransformerXYWH_v2()), - ('xywh', StandardScaler(copy=False, with_mean=True, with_std=True)) #use in-place scaling + ('selector', NodeTransformerXYWH()), + #v1 ('xywh', StandardScaler(copy=False, with_mean=True, with_std=True)) #use in-place scaling + ('xywh', QuantileTransformer(n_quantiles=self.n_QUANTILES, copy=False)) #use in-place scaling ]) ) , ("neighbors", Pipeline([ ('selector', NodeTransformerNeighbors()), - ('neighbors', StandardScaler(copy=False, with_mean=True, with_std=True)) #use in-place scaling + #v1 ('neighbors', StandardScaler(copy=False, with_mean=True, with_std=True)) #use in-place scaling + ('neighbors', QuantileTransformer(n_quantiles=self.n_QUANTILES, copy=False)) #use in-place scaling ]) ) , ("1hot", Pipeline([ - ('1hot', Node1HotFeatures()) #does the 1-hot encoding directly + ('1hot', Node1HotFeatures_noText()) #does the 1-hot encoding directly ]) ) ]) lEdgeFeature = [ #CAREFUL IF YOU CHANGE THIS - see cleanTransformers method!!!! - ("1hot", Pipeline([ - ('1hot', Edge1HotFeatures(PageNumberSimpleSequenciality())) - ]) - ) - , ("boolean", Pipeline([ - ('boolean', EdgeBooleanFeatures_v2()) + ("boolean", Pipeline([ + ('boolean', EdgeBooleanFeatures()) ]) ) , ("numerical", Pipeline([ - ('selector', EdgeNumericalSelector()), - ('numerical', StandardScaler(copy=False, with_mean=True, with_std=True)) #use in-place scaling + ('selector', EdgeNumericalSelector_noText()), + #v1 ('numerical', StandardScaler(copy=False, with_mean=True, with_std=True)) #use in-place scaling + ('numerical', QuantileTransformer(n_quantiles=self.n_QUANTILES, copy=False)) #use in-place scaling ]) ) ] @@ -106,7 +95,7 @@ def __init__(self): # return self._node_transformer, self._edge_transformer -class FeatureDefinition_T_PageXml_StandardOnes_noText(FeatureDefinition): +class FeatureDefinition_T_PageXml_StandardOnes_noText_v4(FeatureDefinition): """ Multitype version: so the node_transformer actually is a list of node_transformer of length n_class @@ -114,41 +103,41 @@ class FeatureDefinition_T_PageXml_StandardOnes_noText(FeatureDefinition): We also inherit from FeatureDefinition_T !!! 
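In this multitype variant there is one node transformer per node type and one edge transformer per ordered pair of (source type, target type), which is why the TransformerListByType lists below are built with range(nbTypes) and range(nbTypes*nbTypes). Schematically:

    nbTypes = 3   # e.g. three node types
    # one per node type:
    node_transformers = ["node_transformer_%d" % i for i in range(nbTypes)]
    # one per ordered (source, target) type pair, indexed row-major:
    edge_transformers = ["edge_transformer_%d_%d" % (i // nbTypes, i % nbTypes)
                         for i in range(nbTypes * nbTypes)]
    assert len(node_transformers) == 3 and len(edge_transformers) == 9
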
""" + n_QUANTILES = 16 def __init__(self, **kwargs): - FeatureDefinition.__init__(self) + FeatureDefinition.__init__(self, **kwargs) nbTypes = self._getTypeNumber(kwargs) node_transformer = TransformerListByType([ FeatureUnion( [ #CAREFUL IF YOU CHANGE THIS - see cleanTransformers method!!!! ("xywh", Pipeline([ - ('selector', NodeTransformerXYWH_v2()), - ('xywh', StandardScaler(copy=False, with_mean=True, with_std=True)) #use in-place scaling + ('selector', NodeTransformerXYWH()), + #v1 ('xywh', StandardScaler(copy=False, with_mean=True, with_std=True)) #use in-place scaling + ('xywh', QuantileTransformer(n_quantiles=self.n_QUANTILES, copy=False)) #use in-place scaling ]) ) , ("neighbors", Pipeline([ ('selector', NodeTransformerNeighbors()), - ('neighbors', StandardScaler(copy=False, with_mean=True, with_std=True)) #use in-place scaling + #v1 ('neighbors', StandardScaler(copy=False, with_mean=True, with_std=True)) #use in-place scaling + ('neighbors', QuantileTransformer(n_quantiles=self.n_QUANTILES, copy=False)) #use in-place scaling ]) ) , ("1hot", Pipeline([ - ('1hot', Node1HotFeatures()) #does the 1-hot encoding directly + ('1hot', Node1HotFeatures_noText()) #does the 1-hot encoding directly ]) ) ]) for _i in range(nbTypes) ]) edge_transformer = TransformerListByType([ FeatureUnion( [ #CAREFUL IF YOU CHANGE THIS - see cleanTransformers method!!!! - ("1hot", Pipeline([ - ('1hot', Edge1HotFeatures(PageNumberSimpleSequenciality())) - ]) - ) - , ("boolean", Pipeline([ - ('boolean', EdgeBooleanFeatures_v2()) + ("boolean", Pipeline([ + ('boolean', EdgeBooleanFeatures()) ]) ) , ("numerical", Pipeline([ - ('selector', EdgeNumericalSelector()), - ('numerical', StandardScaler(copy=False, with_mean=True, with_std=True)) #use in-place scaling + ('selector', EdgeNumericalSelector_noText()), + #v1 ('numerical', StandardScaler(copy=False, with_mean=True, with_std=True)) #use in-place scaling + ('numerical', QuantileTransformer(n_quantiles=self.n_QUANTILES, copy=False)) #use in-place scaling ]) ) ] ) for _i in range(nbTypes*nbTypes) ]) diff --git a/TranskribusDU/graph/FeatureDefinition_Standard.py b/TranskribusDU/graph/FeatureDefinition_Standard.py index aee2d10..9e7ef55 100644 --- a/TranskribusDU/graph/FeatureDefinition_Standard.py +++ b/TranskribusDU/graph/FeatureDefinition_Standard.py @@ -7,18 +7,7 @@ Copyright NAVER(C) 2019 JL. Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . Developed for the EU project READ. The READ project has received funding @@ -250,7 +239,7 @@ class Selector(Transformer): a range of features is dedicated for each class of edges (depends on bMultiPage). (Vertical, horizontal, cross-page, ...) 
""" - N = 12 + N = 21 def transform(self, lEdge): a = np.zeros( ( len(lEdge), self.N ) , dtype=np.float64) for i, edge in enumerate(lEdge): @@ -261,8 +250,12 @@ def transform(self, lEdge): l_nv = l / float(edge.A.page.h) l_nh = l / float(edge.A.page.w) # normalized horizontally - # overlap + # overlap due to masking ovrl = edge.overlap + # overlap ignoring masking + # ovrl_max = edge.computeOverlap() # new 9/8/19 + ovrl_max, pA, pB = edge.computeOverlapPosition() # new 8/8/19 + r_ovrl = (ovrl+0.001) / (0.001+ovrl_max) # avoid zero div. # IoU iou = edge.iou @@ -275,8 +268,13 @@ def transform(self, lEdge): , l_nh , l_nh*l_nh , l_nv , l_nv*l_nv , ovrl , ovrl*ovrl + , ovrl_max , ovrl_max*ovrl_max # new 8/8/19 , iou , iou*iou + , r_ovrl , r_ovrl*r_ovrl # new 8/8/19 , space , space*space + , r_ovrl / max(l,1) # new 8/8/19 + , pA , pA*pA # new 9/8/19 + , pB , pB*pB # new 9/8/19 ) return a diff --git a/TranskribusDU/graph/Graph.py b/TranskribusDU/graph/Graph.py index a36847a..72b67bd 100644 --- a/TranskribusDU/graph/Graph.py +++ b/TranskribusDU/graph/Graph.py @@ -5,18 +5,7 @@ Copyright Xerox(C) 2016 JL. Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. The READ project has received funding @@ -25,20 +14,14 @@ """ - - - import collections -#import gc import numpy as np -from lxml import etree from common.trace import traceln -from xml_formats.PageXml import PageXmlException from . import Edge -from xml_formats.PageXml import PageXml + class GraphException(Exception): pass @@ -55,61 +38,71 @@ class Graph: _nbLabelTot = 0 #total number of labels iGraphMode = 1 # how to compute edges 1=historical 2=dealing properly with line-of-sight - + + bConjugate = False + + sIN_FORMAT = "undefined" # tell here which input format is expected, subclasses must specialise this + sDU = "_du" + sOUTPUT_EXT = ".out" + #--- CONSTRAINTS _lPageConstraintDef = None #optionnal page-level constraints - # do we use the conjugate graph? - bConjugate = False - def __init__(self, lNode = [], lEdge = []): self.lNode = lNode self.lEdge = lEdge - self.doc = None + self.doc = None # the document object (can be a DOM, or a JSON, or...) + self.lCluster = None # list of clusters (resulting from a segmentation task) self.aNeighborClassMask = None #did we compute the neighbor class mask already? self.aLabelCount = None #count of seen labels - - # --- CONJUGATE ----------------------------------------------------- + + # --- I/O @classmethod - def setConjugateMode(cls - , lEdgeLabel = None # list of labels (list of strings, or of int) - , computeEdgeLabels = None # to compute the edge labels - , exploitEdgeLabels = None # to use the predicted edge labels - ): - """ - learn and predict on the conjugate graph instead of the usual graph. 
-        1 - The usual graph is created as always
-        2 - the function computeEdgeLabels is called to compute the edge labels
-        3 - the conjugate is created and used for learning or predicting
-        4 - the function exploitEdgeLabels is called once the edge labels are predicted
-
-        The prototype of the functions are as shown in the code below.
+    def isOutputFilename(cls, sFilename):
+        return sFilename.endswith(cls.sDU+cls.sOUTPUT_EXT)
+
+    @classmethod
+    def getOutputFilename(cls, sFilename):
+        return sFilename[:sFilename.rindex(".")] + cls.sDU + cls.sOUTPUT_EXT
+
+
+    def parseDocFile(self, sFilename, iVerbose=0):
+        """
+        Load that document as a Graph.
+        Also set the self.doc variable!
-        NOTE: since the graph class may already be dedicated to the Conjugate mode,
-        we do not force to pass the list of edge labels and the 2 methods
-        But bConjugate must be True already!
-        """
-        if cls.bConjugate is True:
-            # then we accept to override some stuff...
-            if not lEdgeLabel is None:        cls.lEdgeLabel        = lEdgeLabel
-            if not computeEdgeLabels is None: cls.computeEdgeLabels = computeEdgeLabels
-            if not exploitEdgeLabels is None: cls.exploitEdgeLabels = exploitEdgeLabels
-        else:
-            if None in [lEdgeLabel, computeEdgeLabels, exploitEdgeLabels]:
-                raise GraphException("You must provide lEdgeLabel, computeEdgeLabels, and exploitEdgeLabels")
-            cls.bConjugate = True
-            cls.lEdgeLabel = lEdgeLabel
-            cls.computeEdgeLabels = computeEdgeLabels
-            cls.exploitEdgeLabels = exploitEdgeLabels
+        Return a Graph object
+        """
+        raise GraphException("You must specialise this class method")
+
+    @classmethod
+    def getDocInputFormat(cls):
+        """
+        return a human-readable string describing the expected input format
+        """
+        return cls.sIN_FORMAT
+
+    def detachFromDoc(self):
+        """
+        Graph and graph nodes may have kept a reference to other data.
+        Here we detach them
+        """
+        return
-        assert len(cls.lEdgeLabel) > 1, ("Invalid number of edge labels (graph conjugate mode)", lEdgeLabel)
-        traceln("SETUP: Conjugate mode: %s" % str(cls))
-
-        return cls.bConjugate
+    @classmethod
+    def saveDoc(cls, sFilename, doc, lG, sCreator, sComment):
+        """
+        sFilename is the input filename
+        doc is the input data (DOM, or JSON for now) possibly enriched by the
+        prediction, depending on the class of graph
+        lG is the list of graphs, possibly enriched by the prediction
+        """
+        raise GraphException("You must specialise this class method")
+
    # --- Graph Construction Mode
    @classmethod
    def getGraphMode(cls):
        return cls.iGraphMode
@@ -134,19 +127,14 @@ def computeEdgeLabels(self):
        return the set of observed class (set of integers in N+)
        """
-        raise GraphException("You must specialize this class method")
+        raise GraphException("You must specialise this class method")

-    def exploitEdgeLabels(self, Y_proba):
+    def form_cluster(self, Y_proba):
        """
        Do whatever is required on the (primal) graph, given the edge labels
        Y_proba is the predicted edge label probability array of shape (N_edges, N_labels)
-        return None
-
-        The node and edge indices corresponding to the order of the lNode and lEdge attribute of the graph object.
- - Here we choose to set an XML attribute DU_cluster="" + return a ClusterList of Cluster object """ raise GraphException("You must specialize this class method") @@ -205,12 +193,9 @@ def getLabelNameList(cls): """ Return the list of label names for all label sets """ -# if cls.bConjugate: -# return cls.getEdgeLabelNameList() -# else: return [sLabelName for lblSet in cls._lNodeType for sLabelName in lblSet.getLabelNameList()] - def parseDomLabels(self): + def parseDocLabels(self): """ Parse the label of the graph from the dataset, and set the node label return the set of observed class (set of integers in N+) @@ -219,10 +204,7 @@ def parseDomLabels(self): for nd in self.lNode: nodeType = nd.type #a LabelSet object knows how to parse a DOM node of a Graph object!! - try: - sLabel = nodeType.parseDomNodeLabel(nd.node) - except PageXmlException: - sLabel='TR_OTHER' + sLabel = nodeType.parseDocNodeLabel(nd) try: cls = self._dClsByLabel[sLabel] #Here, if a node is not labelled, and no default label is set, then KeyError!!! except KeyError: @@ -231,16 +213,15 @@ def parseDomLabels(self): setSeensLabels.add(cls) return setSeensLabels - def setDomLabels(self, Y): + def setDocLabels(self, Y): """ Set the labels of the graph nodes from the Y matrix - return the DOM """ for i,nd in enumerate(self.lNode): sLabel = self._dLabelByCls[ Y[i] ] - nd.type.setDomNodeLabel(nd.node, sLabel) - return self.doc - + nd.type.setDocNodeLabel(nd, sLabel) + return + # --- Constraints ----------------------------------------------------------- def setPageConstraint(cls, lPageConstraintDef): """ @@ -287,11 +268,10 @@ def loadGraphs(cls , cGraphClass # graph class (must be subclass) , lsFilename , bNeighbourhood=True # incident edges for each node, by type of edge - , bDetach=False # keep or free the DOM + , bDetach=False # keep or free the source data , bLabelled=False # do we read node labels? , iVerbose=0 , attachEdge=False # all incident edges for each node - , bConjugate=False # Conjugate mode ): """ Load one graph per file, and detach its DOM @@ -301,54 +281,42 @@ def loadGraphs(cls for sFilename in lsFilename: if iVerbose: traceln("\t%s"%sFilename) g = cGraphClass() - g.parseXmlFile(sFilename, iVerbose) + g.parseDocFile(sFilename, iVerbose) + g._index() if not g.isEmpty(): if attachEdge and bNeighbourhood: g.collectNeighbors(attachEdge=attachEdge) if bNeighbourhood: g.collectNeighbors() - if bLabelled: g.parseDomLabels() - if bDetach: g.detachFromDOM() - lGraph.append(g) + if bLabelled: g.parseDocLabels() + if bDetach: g.detachFromDoc() + lGraph.append(g) return lGraph - def parseXmlFile(self, sFilename, iVerbose=0): + @classmethod + def castGraphList(cls + , cGraphClass # graph class (must be subclass) + , lGraph + , iVerbose=0 + ): """ - Load that document as a CRF Graph. - Also set the self.doc variable! - - Return a CRF Graph object + Here we create an instance of graph that reuses the lists of nodes and edge from another graph """ - - self.doc = etree.parse(sFilename) - self.lNode, self.lEdge = list(), list() - #load the block of each page, keeping the list of blocks of previous page - lPrevPageNode = None - - for pnum, page, domNdPage in self._iter_Page_DomNode(self.doc): - #now that we have the page, let's create the node for each type! 
-            lPageNode = list()
-            setPageNdDomId = set()  #the set of DOM id
-            # because the node types are supposed to have an empty intersection
-
-            lPageNode = [nd for nodeType in self.getNodeTypeList() for nd in nodeType._iter_GraphNode(self.doc, domNdPage, page) ]
-
-            #check that each node appears once
-            setPageNdDomId = set([nd.domid for nd in lPageNode])
-            assert len(setPageNdDomId) == len(lPageNode), "ERROR: some nodes fit with multiple NodeTypes"
-
-
-            self.lNode.extend(lPageNode)
-
-            lPageEdge = Edge.Edge.computeEdges(lPrevPageNode, lPageNode, self.iGraphMode)
-
-            self.lEdge.extend(lPageEdge)
-            if iVerbose>=2: traceln("\tPage %5d %6d nodes %7d edges"%(pnum, len(lPageNode), len(lPageEdge)))
-
-            lPrevPageNode = lPageNode
-        if iVerbose: traceln("\t\t (%d nodes, %d edges)"%(len(self.lNode), len(self.lEdge)) )
-
-        return self
+        assert len(cGraphClass.getNodeTypeList()) == 1
+        new_ndType = cGraphClass.getNodeTypeList()[0]
+        lNewGraph = []
+        for g in lGraph:
+            new_g = cGraphClass()
+            new_g.doc = g.doc
-    def _iter_Page_DomNode(self, doc):
+            new_g.lNode = g.lNode
+            # we need to change the node type of all nodes...
+            # I always knew it was bad to have one type attribute on each node in single-type graphs...
+            for _nd in g.lNode: _nd.type = new_ndType
+
+            new_g.lEdge = g.lEdge
+            lNewGraph.append(new_g)
+        return lNewGraph
+
+    def _iter_Page_DocNode(self, doc):
        """
        Parse a Xml DOM, by page
@@ -358,8 +326,8 @@
        """
        raise Exception("Must be specialized")

-    def isEmpty(self): return self.lNode == []
-
+    def isEmpty(self):
+        return self.lNode == []

    def collectNeighbors(self,attachEdge=False):
        """
@@ -415,57 +383,13 @@
        return self.aNeighborClassMask

-    def detachFromDOM(self):
-        """
-        Detach the graph from the DOM node, which can then be freed
-        """
-        if self.doc != None:
-            for nd in self.lNode: nd.detachFromDOM()
-            self.doc = None
-            #gc.collect()
-
    def revertEdges(self):
        """
        revert the direction of each edge of the graph
        """
        for e in self.lEdge: e.revertDirection()

-    def addEdgeToDOM(self, Y=None):
-        """
-        To display the graph conveniently we add new Edge elements
-        """
-        ndPage = self.lNode[0].page.node
-        # w = int(ndPage.get("imageWidth"))
-        ndPage.append(etree.Comment("Edges added to the XML for convenience"))
-        for edge in self.lEdge:
-            A, B = edge.A , edge.B   #shape.centroid, edge.B.shape.centroid
-            ndEdge = PageXml.createPageXmlNode("Edge")
-            ndEdge.set("src", edge.A.node.get("id"))
-            ndEdge.set("tgt", edge.B.node.get("id"))
-            ndEdge.set("type", edge.__class__.__name__)
-            ndEdge.tail = "\n"
-            ndPage.append(ndEdge)
-            PageXml.setPoints(ndEdge, [(A.x1, A.y1), (B.x1, B.y1)])
-
-        return
-    # --- Numpy matrices --------------------------------------------------------
-    def getXY(self, node_transformer, edge_transformer):
-        """
-        return a tuple (X,Y) for the graph (X is a triplet)
-        """
-        self._index()
-
-        if self._bMultitype:
-            if self.bConjugate:
-                raise "Not yet implemented: conjugate of multitype graph"
-            X, Y = self._buildNodeEdgeLabelMatrices_T(node_transformer, edge_transformer, bY=True)
-        else:
-            X, Y = ( self.getX(node_transformer, edge_transformer) , self.getY() )
-
-        return (X, Y)
-
    def getX(self, node_transformer, edge_transformer):
        """
        make 1 node-feature matrix (or list of matrices for multitype graphs)
        and 1 edge-feature matrix (or list of matrices for multitype graphs)
        and 1 edge matrix (or list of matrices for multitype graphs)
        for the graph
        return X for the graph
        """
        self._index()
        if self._bMultitype:
-            if self.bConjugate: raise "Not yet implemented: conjugate of multitype graph"
            X =
self._buildNodeEdgeLabelMatrices_T(node_transformer, edge_transformer, bY=False) else: X = self._buildNodeEdgeMatrices_S(node_transformer, edge_transformer) - if self.bConjugate: - X = self.convert_X_to_LineDual(X) return X def getY(self): """ WARNING, in multitype graphs, the order of the Ys is bad """ - if self.bConjugate: - Y = self._buildLabelMatrix_S_Y() - else: - Y = self._buildLabelMatrix_S() + Y = np.fromiter( (nd.cls for nd in self.lNode), dtype=np.int, count=len(self.lNode)) return Y - # --- Conjugate -------------------------------------------------------- - def convert_X_to_LineDual(self, X): - """ - Convert to a dual graph - Animesh 2018 - Revisited by JL April 2019 - - NOTE: isolated nodes are not reflected in the dual. - Should we add a reflexive edge to have the node in the dual?? - """ - (nf, edge, ef) = X - - nb_edge = edge.shape[0] - - all_edges = [] # all edges created so far - - nf_dual = ef # edges become nodes - edge_dual = [] - ef_dual = [] - - for i in range(nb_edge): - edgei_from_idx, edgei_to_idx = edge[i] - - edge_from = set([edgei_from_idx, edgei_to_idx]) - for j in range(i+1, nb_edge): - edge_to = set([edge[j][0], edge[j][1]]) - edge_candidate = edge_from.symmetric_difference(edge_to) - # we should get 4, 2 or 0 primal nodes - if len(edge_candidate) == 2 and edge_candidate not in all_edges: - # edge_to and edge_from share 1 primal node => create dual edge! - all_edges.append(edge_candidate) - [shared_node_idx] = edge_from.intersection(edge_to) - shared_node_nf = nf[shared_node_idx] - ef_dual.append(shared_node_nf) - edge_dual.append([i, j]) - - nf_dual = np.array(nf_dual) - edge_dual = np.array(edge_dual) - ef_dual = np.array(ef_dual) - - assert (edge_dual.shape[0] == ef_dual.shape[0]) - - return (nf_dual, edge_dual, ef_dual) - #----- Indexing Graph Objects ----- def _index(self, bForce=False): """ @@ -551,17 +425,11 @@ def _index(self, bForce=False): bForce or self.__bNodeIndexed return False except AttributeError: - self._indexNodeTypes() - for i, nd in enumerate(self.lNode): nd._index = i + for i, nt in enumerate(self._lNodeType): nt._index = i + for i, nd in enumerate(self.lNode) : nd._index = i self.__bNodeIndexed = True return True - def _indexNodeTypes(self): - """ - add _index attribute to registered NodeType - """ - for i, nt in enumerate(self._lNodeType): nt._index = i - #----- SINGLE TYPE ----- def _buildNodeEdgeMatrices_S(self, node_transformer, edge_transformer): """ @@ -593,24 +461,6 @@ def _BuildEdgeMatrix_S(self): # edge = edges.reshape(len(self.lEdge), 2) return edges - def _buildLabelMatrix_S(self): - """ - Return the matrix of labels - """ - #better code based on fromiter is below (I think, JLM April 2017) - #Y = np.array( [nd.cls for nd in self.lNode] , dtype=np.uint8) - Y = np.fromiter( (nd.cls for nd in self.lNode), dtype=np.int, count=len(self.lNode)) - return Y - - def _buildLabelMatrix_S_Y(self): - """ - Return the matrix of labels of edges - """ - #better code based on fromiter is below (I think, JLM April 2017) - #Y = np.array( [nd.cls for nd in self.lNode] , dtype=np.uint8) - Y = np.fromiter( (e.cls for e in self.lEdge), dtype=np.int, count=len(self.lEdge)) - return Y - #----- MULTITYPE ----- def _buildNodeEdgeLabelMatrices_T(self, node_transformer, edge_transformer, bY=True): """ @@ -688,10 +538,4 @@ def getNodeIndexByPage(self): for pnum in sorted(dlIndexByPage.keys()): llIndexByPage.append( sorted(dlIndexByPage[pnum]) ) return llIndexByPage - - - - - - diff --git a/TranskribusDU/graph/GraphConjugate.py 
b/TranskribusDU/graph/GraphConjugate.py
new file mode 100644
index 0000000..18a8651
--- /dev/null
+++ b/TranskribusDU/graph/GraphConjugate.py
@@ -0,0 +1,102 @@
+# -*- coding: utf-8 -*-
+
+"""
+    Conjugate Graph Class
+
+    Copyright NAVER(C) 2019 JL. Meunier
+"""
+import numpy as np
+
+from graph.Graph import Graph
+
+class GraphConjugate(Graph):
+
+    bConjugate = True
+
+    def __init__(self, lNode = [], lEdge = []):
+        super(GraphConjugate, self).__init__(lNode, lEdge)
+
+    @classmethod
+    def getLabelNameList(cls):
+        """
+        We work on edges!
+        """
+        return cls.getEdgeLabelNameList()
+
+    def setDocLabels(self, Y):
+        """
+        Set the labels of the graph edges from the Y matrix
+        Is this used anywhere???
+        """
+        # lEdgeLabel = self.getEdgeLabelNameList()
+        for i,edge in enumerate(self.lEdge):
+            edge.cls = Y[i]   # might be a Numpy vector! (argmax not yet done)
+        return
+
+    # --- Numpy matrices --------------------------------------------------------
+    def getX(self, node_transformer, edge_transformer):
+        """
+        make 1 node-feature matrix (or list of matrices for multitype graphs)
+        and 1 edge-feature matrix (or list of matrices for multitype graphs)
+        and 1 edge matrix (or list of matrices for multitype graphs)
+        for the graph
+        return a triplet
+
+        return X for the graph
+        """
+        X = Graph.getX(self, node_transformer, edge_transformer)
+        X = self.convert_X_to_LineDual(X)
+        return X
+
+    def getY(self):
+        """
+        WARNING, in multitype graphs, the order of the Ys is bad
+        """
+        Y = np.fromiter( (e.cls for e in self.lEdge), dtype=np.int, count=len(self.lEdge))
+        return Y
+
+    # --- Conjugate --------------------------------------------------------
+    def convert_X_to_LineDual(self, X):
+        """
+        Convert to a dual graph
+        Animesh 2018
+        Revisited by JL April 2019
+
+        NOTE: isolated nodes are not reflected in the dual.
+        Should we add a reflexive edge to have the node in the dual??
+        """
+        if self._bMultitype:
+            raise NotImplementedError("Not yet implemented: conjugate of multitype graph")
+        (nf, edge, ef) = X
+
+        nb_edge = edge.shape[0]
+
+        all_edges = []    # all edges created so far
+
+        nf_dual   = ef    # edges become nodes
+        edge_dual = []
+        ef_dual   = []
+
+        for i in range(nb_edge):
+            edgei_from_idx, edgei_to_idx = edge[i]
+
+            edge_from = set([edgei_from_idx, edgei_to_idx])
+            for j in range(i+1, nb_edge):
+                edge_to = set([edge[j][0], edge[j][1]])
+                edge_candidate = edge_from.symmetric_difference(edge_to)
+                # we should get 4, 2 or 0 primal nodes
+                if len(edge_candidate) == 2 and edge_candidate not in all_edges:
+                    # edge_to and edge_from share 1 primal node => create dual edge!
+                    all_edges.append(edge_candidate)
+                    [shared_node_idx] = edge_from.intersection(edge_to)
+                    shared_node_nf = nf[shared_node_idx]
+                    ef_dual.append(shared_node_nf)
+                    edge_dual.append([i, j])
+
+        nf_dual   = np.array(nf_dual)
+        edge_dual = np.array(edge_dual)
+        ef_dual   = np.array(ef_dual)
+
+        assert (edge_dual.shape[0] == ef_dual.shape[0])
+
+        return (nf_dual, edge_dual, ef_dual)
diff --git a/TranskribusDU/graph/GraphConjugateSegmenter.py b/TranskribusDU/graph/GraphConjugateSegmenter.py
new file mode 100644
index 0000000..0b3027b
--- /dev/null
+++ b/TranskribusDU/graph/GraphConjugateSegmenter.py
@@ -0,0 +1,38 @@
+# -*- coding: utf-8 -*-
+
+"""
+    Conjugate Graph Class used for Segmentation
+
+    Copyright NAVER(C) 2019 JL.
Meunier +""" + + +from common.trace import traceln + +from graph.GraphConjugate import GraphConjugate + + +class GraphConjugateSegmenter(GraphConjugate): + + bConjugate = True + + def __init__(self, lNode = [], lEdge = []): + super(GraphConjugateSegmenter, self).__init__(lNode, lEdge) + + # --- Clusters ------------------------------------------------------- + def addClusterToDoc(self, lCluster): + """ + Do whatever is required with the ClusterList object + """ + raise Exception("To be specialised!") + + def setDocLabels(self, Y): + """ + Set the labels of the graph nodes from the Y matrix + Is this used anywhere??? + """ + GraphConjugate.setDocLabels(self, Y) + self.lCluster = self.form_cluster(Y) + + self.addClusterToDoc(self.lCluster) + traceln(" %d cluster(s) found" % (len(self.lCluster))) diff --git a/TranskribusDU/graph/GraphModel.py b/TranskribusDU/graph/GraphModel.py index 4f4273a..0599e24 100644 --- a/TranskribusDU/graph/GraphModel.py +++ b/TranskribusDU/graph/GraphModel.py @@ -7,18 +7,7 @@ Copyright NAVER(C) 2016-2019 JL. Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. The READ project has received funding @@ -51,14 +40,17 @@ class GraphModelException(Exception): """ pass +class GraphModelNoEdgeException(Exception): + """ + Exception specific to this class: absence of edge in the graph + """ + pass + class GraphModel: _balancedWeights = False # Uniform does same or better, in general - # conjugate mode - bConjugate = False - sSurname = "" #surname is added to each generated filename, e.g. crf, ecn, ... 
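For intuition about the flow in GraphConjugateSegmenter above (predicted edge labels in, node clusters out), here is a hedged sketch of one way form_cluster can be realised; the project's actual implementation is not shown in this diff. Every edge whose most probable label means "same cluster" links its two nodes, and the clusters are the connected components (union-find):

```python
import numpy as np

def clusters_from_edge_labels(n_nodes, lEdgeIndexPair, Y_proba, i_continue=0):
    """lEdgeIndexPair: (a, b) node index pairs, in the order of Y_proba.
    i_continue is an assumption: the edge class meaning 'same cluster'."""
    parent = list(range(n_nodes))
    def find(i):                        # union-find with path halving
        while parent[i] != i:
            parent[i] = parent[parent[i]]
            i = parent[i]
        return i
    for (a, b), proba in zip(lEdgeIndexPair, Y_proba):
        if int(np.argmax(proba)) == i_continue:
            parent[find(a)] = find(b)   # merge the two components
    dCluster = {}
    for i in range(n_nodes):
        dCluster.setdefault(find(i), []).append(i)
    return list(dCluster.values())      # one list of node indices per cluster
```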
def __init__(self, sName, sModelDir):
@@ -79,7 +71,6 @@ def __init__(self, sName, sModelDir):
        self.bTrainEdgeBaseline = False

        self._nbClass = None
-
    def configureLearner(self, **kwargs):
        """
@@ -87,44 +78,37 @@
        """
        raise Exception("Method must be overridden")

-    # --- Conjugate -------------------------------------------------------------
-    @classmethod
-    def setConjugateMode(cls):
-        cls.bConjugate = True
-        traceln("SETUP: Conjugate mode: %s" % str(cls))
-        return cls.bConjugate
+    def setName(self, sName):
+        self.sName = sName

    # --- Utilities -------------------------------------------------------------
+    def getMetadataComment(self):
+        """
+        Return an informative short string for storing a metadata comment in output XML
+        """
+        return "%s: %s (%s)" % (self.__class__.__name__, self.sName, os.path.abspath(self.sDir))
+
+
    def getModelFilename(self):
-        return os.path.join(self.sDir, self.sName+'.'+self.sSurname+".model.pkl")
+        return os.path.join(self.sDir, self.sName+'._.'+self.sSurname+".model.pkl")
    def getTransformerFilename(self):
-        return os.path.join(self.sDir, self.sName+ ".transf.pkl")
+        return os.path.join(self.sDir, self.sName+'._.'+ ".transf.pkl")
    def getConfigurationFilename(self):
-        return os.path.join(self.sDir, self.sName+'.'+self.sSurname+".config.json")
+        return os.path.join(self.sDir, self.sName+'._.'+self.sSurname+".config.json")
    def getBaselineFilename(self):
-        return os.path.join(self.sDir, self.sName+'.'+self.sSurname+".baselines.pkl")
+        return os.path.join(self.sDir, self.sName+'._.'+self.sSurname+".baselines.pkl")
    def getTrainDataFilename(self, name):
-        return os.path.join(self.sDir, self.sName+'.'+self.sSurname+".tlXlY_%s.pkl"%name)
+        return os.path.join(self.sDir, self.sName+'._.'+self.sSurname+".tlXlY_%s.pkl"%name)

    @classmethod
    def _getParamsFilename(cls, sDir, sName):
-        return os.path.join(sDir, sName+"_params.json")
+        return os.path.join(sDir, sName+"._."+"_params.json")

-    """
-    When some class is not represented on some graph, you must specify the number of class.
-    Otherwise pystruct will complain about the number of states differeing from the number of weights
-    """
    def setNbClass(self, lNbClass):
        """
        in multitype case we get a list of class number (one per type)
        """
        self._nbClass = lNbClass
-
-    def getNbClass(self):
-        if self.bConjugate:
-            return len(self.lEdgeLabel)
-        else:
-            return self._nbClass

    def _getNbFeatureAsText(self):
        """
diff --git a/TranskribusDU/graph/Graph_DOM.py b/TranskribusDU/graph/Graph_DOM.py
new file mode 100644
index 0000000..3134d6b
--- /dev/null
+++ b/TranskribusDU/graph/Graph_DOM.py
@@ -0,0 +1,204 @@
+# -*- coding: utf-8 -*-
+
+"""
+    Computing the graph for an XML document
+
+    Copyright Xerox(C) 2016, 2019 JL. Meunier
+
+
+
+
+    Developed for the EU project READ. The READ project has received funding
+    from the European Union's Horizon 2020 research and innovation programme
+    under grant agreement No 674943.
+
+"""
+import math
+
+from lxml import etree
+import shapely.geometry as geom
+
+from common.trace import traceln
+from util import XYcut
+from .Graph import Graph
+from .
import Edge +from xml_formats.PageXml import PageXml, MultiPageXml + + +class Graph_DOM(Graph): + """ + Graph for DOM input + """ + # --- NODE TYPES and LABELS + _lNodeType = [] #the list of node types for this class of graph + _bMultitype = False # equivalent to len(_lNodeType) > 1 + _dLabelByCls = None #dictionary across node types + _dClsByLabel = None #dictionary across node types + _nbLabelTot = 0 #total number of labels + + sIN_FORMAT = "XML" # tell here which input format is expected + sOUTPUT_EXT = ".mpxml" + + def __init__(self, lNode = [], lEdge = []): + Graph.__init__(self, lNode, lEdge) + + def parseDocFile(self, sFilename, iVerbose=0): + """ + Load that document as a CRF Graph. + Also set the self.doc variable! + + Return a CRF Graph object + """ + + self.doc = etree.parse(sFilename) + self.lNode, self.lEdge = list(), list() + #load the block of each page, keeping the list of blocks of previous page + lPrevPageNode = None + + for pnum, page, domNdPage in self._iter_Page_DocNode(self.doc): + #now that we have the page, let's create the node for each type! + lPageNode = list() + setPageNdDomId = set() #the set of DOM id + # because the node types are supposed to have an empty intersection + + lPageNode = [nd for nodeType in self.getNodeTypeList() for nd in nodeType._iter_GraphNode(self.doc, domNdPage, page) ] + + #check that each node appears once + setPageNdDomId = set([nd.domid for nd in lPageNode]) + assert len(setPageNdDomId) == len(lPageNode), "ERROR: some nodes fit with multiple NodeTypes" + + + self.lNode.extend(lPageNode) + + lPageEdge = Edge.Edge.computeEdges(lPrevPageNode, lPageNode, self.iGraphMode) + + self.lEdge.extend(lPageEdge) + if iVerbose>=2: traceln("\tPage %5d %6d nodes %7d edges"%(pnum, len(lPageNode), len(lPageEdge))) + + lPrevPageNode = lPageNode + if iVerbose: traceln("\t\t (%d nodes, %d edges)"%(len(self.lNode), len(self.lEdge)) ) + + return self + + def addEdgeToDoc(self, Y=None, ndPage=None): + """ + To display the graph conveniently we add new Edge elements + """ + assert Y == None + if self.lNode: + ndPage = self.lNode[0].page.node if ndPage is None else ndPage + ndPage.append(etree.Comment("Edges added to the XML for convenience")) + for edge in self.lEdge: + A, B = edge.A , edge.B #shape.centroid, edge.B.shape.centroid + ndEdge = PageXml.createPageXmlNode("Edge") + ndEdge.set("src", edge.A.node.get("id")) + ndEdge.set("tgt", edge.B.node.get("id")) + ndEdge.set("type", edge.__class__.__name__) + ndEdge.tail = "\n" + ndPage.append(ndEdge) + PageXml.setPoints(ndEdge, [(A.x1, A.y1), (B.x1, B.y1)]) + + return + + @classmethod + def saveDoc(cls, sFilename, doc, _lg, sCreator, sComment): + """ + _lg is not used since we have enriched the DOC (in doc parameter) + """ + # build a decent output filename + sDUFilename = cls.getOutputFilename(sFilename) + + MultiPageXml.setMetadata(doc, None, sCreator, sComment) + + doc.write(sDUFilename, + xml_declaration=True, + encoding="utf-8", + pretty_print=True + #compression=0, #0 to 9 + ) + return sDUFilename + + def detachFromDoc(self): + """ + Detach the graph from the DOM node, which can then be freed + """ + if self.doc != None: + for nd in self.lNode: nd.detachFromDOM() + self.doc = None + + @classmethod + def exportToDom(cls, lg, bGraph=False): + """ + export a set of graph as (Multi)PageXml + """ + #create document + pageDoc,pageNode = PageXml.createPageXmlDocument('graph2DOM', filename="",imgW=0, imgH=0) + pageW, pageH = 0,0 + for iG,g in enumerate(lg): + if iG > 0: + pageNode = 
PageXml.createPageXmlNode("Page")
+                pageDoc.getroot().append(pageNode)
+
+#             lRegions = g.lCluster
+#             for region in lRegions:
+#                 lNodes = [g.lNode[idx] for idx in region]
+
+            if g.lCluster:
+                lRegionNodes = [ [g.lNode[idx] for idx in region] for region in g.lCluster ]
+            else:
+                lRegionNodes = [ g.lNode ]
+            for iR, lNodes in enumerate(lRegionNodes):
+                #regionType = region.type
+                regionNode = PageXml.createPageXmlNode("TextRegion")
+                regionNode.set("id", "R%d"%iR)
+                pageNode.append(regionNode)
+
+                #regionNode.set("type",str(regionType))
+                # Region geometry
+                mp = geom.MultiPolygon(b.getShape() for b in lNodes)
+                xmin, ymin, xmax, ymax = mp.bounds
+                coordsNode = PageXml.createPageXmlNode('Coords')
+                regionNode.append(coordsNode)
+                PageXml.setPoints(coordsNode,[(xmin,ymin),(xmax,ymin),(xmax,ymax),(xmin,ymax)])
+
+                # update page dimension
+                pageW = max(pageW, xmax)
+                pageH = max(pageH, ymax)
+
+                lSegment = [(o.y1, o.y2, o) for o in lNodes]
+                lLines,_,_ = XYcut.mergeSegments(lSegment, 2)
+                for iL, (ymin,ymax,lw) in enumerate(lLines):
+                    xmin, ymin, xmax, ymax = geom.MultiPolygon(b.getShape() for b in lw).bounds
+                    textLineNode = PageXml.createPageXmlNode('TextLine')
+                    textLineNode.set("id", "R%d_L%d" % (iR, iL))
+
+                    coordsNode = PageXml.createPageXmlNode('Coords')
+                    textLineNode.append(coordsNode)
+                    PageXml.setPoints(coordsNode,[(xmin,ymin),(xmax,ymin),(xmax,ymax),(xmin,ymax)])
+                    regionNode.append(textLineNode)
+
+                    for iW, w in enumerate(lw):
+                        wordNode = PageXml.createPageXmlNode('Word')
+                        # standard Block attributes for DOM objects, so that addEdgeToDoc can work
+                        w.node = wordNode
+                        w.domid = "R%d_L%d_W%d" % (iR, iL, iW)
+                        wordNode.set("id", w.domid)
+                        textLineNode.append(wordNode)
+                        coordsNode = PageXml.createPageXmlNode('Coords')
+                        wordNode.append(coordsNode)
+                        PageXml.setPoints(coordsNode,w.shape.exterior.coords)
+
+                        textEquiv= PageXml.createPageXmlNode('TextEquiv')
+                        wordNode.append(textEquiv)
+                        unicode = PageXml.createPageXmlNode('Unicode')
+                        textEquiv.append(unicode)
+                        unicode.text=w.text
+            # end of page
+            pageNode.set('imageWidth' , str(math.ceil(pageW)))
+            pageNode.set('imageHeight', str(math.ceil(pageH)))
+            pageW, pageH = 0,0
+
+            if bGraph:
+                g.addEdgeToDoc(ndPage=pageNode)
+
+        return pageDoc
\ No newline at end of file
diff --git a/TranskribusDU/graph/Graph_DSXml.py b/TranskribusDU/graph/Graph_DSXml.py
index b1f415d..2fcd286 100644
--- a/TranskribusDU/graph/Graph_DSXml.py
+++ b/TranskribusDU/graph/Graph_DSXml.py
@@ -5,18 +5,7 @@
    Copyright Xerox(C) 2016 JL. Meunier

-    This program is free software: you can redistribute it and/or modify
-    it under the terms of the GNU General Public License as published by
-    the Free Software Foundation, either version 3 of the License, or
-    (at your option) any later version.

-    This program is distributed in the hope that it will be useful,
-    but WITHOUT ANY WARRANTY; without even the implied warranty of
-    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-    GNU General Public License for more details.
-
-    You should have received a copy of the GNU General Public License
-    along with this program. If not, see .

    Developed for the EU project READ.
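A side note on the geometry used by exportToDom above: the Coords of a generated region are simply the bounds of the member blocks' shapely shapes, emitted as a clockwise rectangle. A reduced sketch (the function name is illustrative; the shapes are assumed to be valid shapely polygons):

```python
import shapely.geometry as geom

def bounding_points(lShape):
    """lShape: the shapely polygons of the blocks of one cluster/region."""
    xmin, ymin, xmax, ymax = geom.MultiPolygon(list(lShape)).bounds
    return [(xmin, ymin), (xmax, ymin), (xmax, ymax), (xmin, ymax)]
```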
The READ project has received funding
@@ -28,31 +17,37 @@

-from .Graph import Graph
+from .Graph_DOM import Graph_DOM
from .Page import Page

-class Graph_DSXml(Graph):
+class Graph_DSXml(Graph_DOM):
    '''
    Computing the graph for a DS XML document
-
-    USAGE:
-    - call parseFile to load the DOM and create the nodes and edges
-    - call detachFromDOM before freeing the DOM
    '''
+    # --- NODE TYPES and LABELS
+    _lNodeType    = []     #the list of node types for this class of graph
+    _bMultitype   = False  # equivalent to len(_lNodeType) > 1
+    _dLabelByCls  = None   #dictionary across node types
+    _dClsByLabel  = None   #dictionary across node types
+    _nbLabelTot   = 0      #total number of labels
+
    #Namespace, of PageXml, at least
    dNS = {}

    #How to list the pages of a (Multi)PageXml doc
    sxpPage = "//PAGE"
+
+    sIN_FORMAT  = "DS_XML"   # tell here which input format is expected
+    sOUTPUT_EXT = ".ds.xml"

    sXmlFilenamePattern = "*_ds.xml"    #how to find the DS XML files

    def __init__(self, lNode = [], lEdge = []):
-        Graph.__init__(self, lNode, lEdge)
+        Graph_DOM.__init__(self, lNode, lEdge)

    # ---------------------------------------------------------------------------------------------------------
-    def _iter_Page_DomNode(self, doc):
+    def _iter_Page_DocNode(self, doc):
        """
        Parse a Multi-pageXml DOM, by page
@@ -73,5 +68,5 @@
            page = Page(pnum, pagecnt, iPageWidth, iPageHeight, cls=None, domnode=ndPage, domid=ndPage.get("id"))
            yield (pnum, page, ndPage)

-        raise StopIteration()
+        return
diff --git a/TranskribusDU/graph/Graph_JsonOCR.py b/TranskribusDU/graph/Graph_JsonOCR.py
new file mode 100644
index 0000000..9bedc6b
--- /dev/null
+++ b/TranskribusDU/graph/Graph_JsonOCR.py
@@ -0,0 +1,101 @@
+# -*- coding: utf-8 -*-
+
+"""
+    Computing the graph for a json file, which is the output of OCR
+    @author: Nitin Choudhary
+
+"""
+import io
+import json
+
+from common.trace import traceln
+
+from .Graph import Graph
+from .
import Edge + + +class Graph_JsonOCR(Graph): + ''' + Computing the graph for a json file + + ''' + # --- NODE TYPES and LABELS + _lNodeType = [] #the list of node types for this class of graph + _bMultitype = False # equivalent to len(_lNodeType) > 1 + _dLabelByCls = None #dictionary across node types + _dClsByLabel = None #dictionary across node types + _nbLabelTot = 0 #total number of labels + + sIN_FORMAT = "JSON_OCR" # tell here which input format is expected + sOUTPUT_EXT = ".json" + + def __init__(self, lNode=[], lEdge=[]): + Graph.__init__(self, lNode, lEdge) + + @classmethod + def loadGraphs(cls + , cGraphClass # graph class (must be subclass) + , lsFilename + , bNeighbourhood=True + , bDetach=False + , bLabelled=False + , iVerbose=0 + , attachEdge=False # all incident edges for each node + ): + """ + Load one graph per file, and detach its DOM + return the list of loaded graphs + """ + lGraph = [] + for sFilename in lsFilename: + if iVerbose: traceln("\t%s" % sFilename) + [g] = cls.getSinglePages(cGraphClass, sFilename, bNeighbourhood, bDetach, bLabelled, + iVerbose) + g._index() + if not g.isEmpty(): + if attachEdge and bNeighbourhood: g.collectNeighbors(attachEdge=attachEdge) + if bNeighbourhood: g.collectNeighbors() + if bLabelled: g.parseDocLabels() + if bDetach: g.detachFromDoc() + lGraph.append(g) + return lGraph + + + @classmethod + def getSinglePages(cls + , cGraphClass # graph class (must be subclass) + , sFilename + , bNeighbourhood=True + , bDetach=False + , bLabelled=False + , iVerbose=0): + """ + load a json + Return a Graph object + """ + lGraph = [] + if isinstance(sFilename, io.IOBase): + # we got a file-like object (e.g. in server mode) + doc = json.load(sFilename) + else: + with open(sFilename, encoding='utf-8') as fd: + doc = json.load(fd) + + g = cGraphClass() + g.doc = doc + + # g.lNode, g.lEdge = list(), list() + # now that we have the page, let's create the node for each type! + assert len(g.getNodeTypeList()) == 1, "Not yet implemented" + + # we skip the loop on pages since we always have 1 page for now from the OCR + g.lNode = [nd for nodeType in g.getNodeTypeList() for nd in nodeType._iter_GraphNode(g.doc, sFilename) ] + g.lEdge = Edge.Edge.computeEdges(None, g.lNode, g.iGraphMode) + + if iVerbose >= 2: traceln("\tPage %5d %6d nodes %7d edges" % (1, len(g.lNode), len(g.lEdge))) + + return [g] + + @classmethod + def saveDoc(cls, sFilename, doc, lG, sCreator, sComment): + print("SaveDoc not done ", sFilename) diff --git a/TranskribusDU/graph/Graph_MultiPageXml.py b/TranskribusDU/graph/Graph_MultiPageXml.py index e4a939a..60ad638 100644 --- a/TranskribusDU/graph/Graph_MultiPageXml.py +++ b/TranskribusDU/graph/Graph_MultiPageXml.py @@ -5,18 +5,7 @@ Copyright Xerox(C) 2016 JL. Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. 
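Graph_JsonOCR.getSinglePages above accepts either a filename or an already-open stream (as passed by the new --server mode). The input-handling idiom, reduced to its core as a sketch:

```python
import io
import json

def load_json_doc(src):
    """src: a filename, or a file-like object (e.g. in server mode)."""
    if isinstance(src, io.IOBase):
        return json.load(src)               # stream: parse it directly
    with open(src, encoding="utf-8") as fd:
        return json.load(fd)                # path: open, then parse
```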
The READ project has received funding @@ -25,28 +14,31 @@ """ - - - from lxml import etree from common.trace import traceln from xml_formats.PageXml import PageXml -from .Graph import Graph +from .Graph_DOM import Graph_DOM from .Transformer_PageXml import EdgeTransformerClassShifter from .Block import Block, BlockShallowCopy from .Edge import Edge, HorizontalEdge, VerticalEdge, CrossPageEdge, CrossMirrorContinuousPageVerticalEdge from .Page import Page -class Graph_MultiPageXml(Graph): + +class Graph_MultiPageXml(Graph_DOM): ''' Computing the graph for a MultiPageXml document - - USAGE: - - call parseFile to load the DOM and create the nodes and edges - - call detachFromDOM before freeing the DOM ''' + # --- NODE TYPES and LABELS + _lNodeType = [] #the list of node types for this class of graph + _bMultitype = False # equivalent to len(_lNodeType) > 1 + _dLabelByCls = None #dictionary across node types + _dClsByLabel = None #dictionary across node types + _nbLabelTot = 0 #total number of labels + + sIN_FORMAT = "(Multi)PageXML" # tell here which input format is expected + #Namespace, of PageXml, at least dNS = {"pc":PageXml.NS_PAGE_XML} @@ -54,10 +46,10 @@ class Graph_MultiPageXml(Graph): sxpPage = "//pc:Page" def __init__(self, lNode = [], lEdge = []): - Graph.__init__(self, lNode, lEdge) + Graph_DOM.__init__(self, lNode, lEdge) # --------------------------------------------------------------------------------------------------------- - def _iter_Page_DomNode(self, doc): + def _iter_Page_DocNode(self, doc): """ Parse a Multi-pageXml DOM, by page @@ -78,7 +70,7 @@ def _iter_Page_DomNode(self, doc): page = Page(pnum, pagecnt, iPageWidth, iPageHeight, cls=None, domnode=ndPage, domid=ndPage.get("id")) yield (pnum, page, ndPage) - raise StopIteration() + return # ------------------------------------------------------------------------------------------------------------------------------------------------ @@ -120,7 +112,7 @@ def computeContinuousPageEdges(self, lPrevPageEdgeBlk, lPageBlk, bMirror=True): for blk in lNextHalfPage: blk.mirrorHorizontally(w1) lVirtualPageBlk.extend(lNextHalfPage) - lEdge = Block._findVerticalNeighborEdges(lVirtualPageBlk, CrossMirrorContinuousPageVerticalEdge) + lEdge = Block._findVerticalNeighborEdges_g1(lVirtualPageBlk, CrossMirrorContinuousPageVerticalEdge) #keep only those edge accross pages, and make them to link original blocks! lAllEdge = [CrossMirrorContinuousPageVerticalEdge(edge.A.getOrigBlock(), edge.B.getOrigBlock(), edge.length) \ @@ -143,7 +135,7 @@ def __init__(self, lNode = [], lEdge = []): Graph_MultiPageXml.__init__(self, lNode, lEdge) EdgeTransformerClassShifter.setDefaultEdgeClass([HorizontalEdge, VerticalEdge, CrossPageEdge, CrossMirrorContinuousPageVerticalEdge]) - def parseXmlFile(self, sFilename, iVerbose=0): + def parseDocFile(self, sFilename, iVerbose=0): """ Load that document as a CRF Graph. Also set the self.doc variable! @@ -156,7 +148,7 @@ def parseXmlFile(self, sFilename, iVerbose=0): #load the block of each page, keeping the list of blocks of previous page lPrevPageNode = None - for pnum, page, domNdPage in self._iter_Page_DomNode(self.doc): + for pnum, page, domNdPage in self._iter_Page_DocNode(self.doc): #now that we have the page, let's create the node for each type! 
lPageNode = list() setPageNdDomId = set() #the set of DOM id diff --git a/TranskribusDU/graph/Graph_Multi_SinglePageXml.py b/TranskribusDU/graph/Graph_Multi_SinglePageXml.py index b524bb6..f6da7c0 100644 --- a/TranskribusDU/graph/Graph_Multi_SinglePageXml.py +++ b/TranskribusDU/graph/Graph_Multi_SinglePageXml.py @@ -11,18 +11,7 @@ Copyright Xerox(C) 2017 H . Déjean - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. The READ project has received funding @@ -47,11 +36,15 @@ class Graph_MultiSinglePageXml(Graph_MultiPageXml): ''' Computing the graph for a MultiPageXml document - - USAGE: - - call parseFile to load the DOM and create the nodes and edges - - call detachFromDOM before freeing the DOM ''' + # --- NODE TYPES and LABELS + _lNodeType = [] #the list of node types for this class of graph + _bMultitype = False # equivalent to len(_lNodeType) > 1 + _dLabelByCls = None #dictionary across node types + _dClsByLabel = None #dictionary across node types + _nbLabelTot = 0 #total number of labels + sIN_FORMAT = "MultiSinglePageXML" # tell here which input format is expected + #Namespace, of PageXml, at least dNS = {"pc":PageXml.NS_PAGE_XML} @@ -77,9 +70,7 @@ def loadGraphs(cls for sFilename in lsFilename: if iVerbose: traceln("\t%s"%sFilename) lG= Graph_MultiSinglePageXml.getSinglePages(cGraphClass, sFilename, bNeighbourhood,bDetach,bLabelled, iVerbose) -# if bNeighbourhood: g.collectNeighbors() -# if bLabelled: g.parseDomLabels() -# if bDetach: g.detachFromDOM() + for g in lG: g._index() lGraph.extend(lG) return lGraph @@ -99,7 +90,7 @@ def getSinglePages(cls lGraph=[] doc = etree.parse(sFilename) - for pnum, page, domNdPage in cls._iter_Page_DomNode(doc): + for pnum, page, domNdPage in cls._iter_Page_DocNode(doc): g = cGraphClass() g.doc= doc @@ -133,8 +124,8 @@ def getSinglePages(cls if not g.isEmpty() and len(g.lEdge) > 0: if bNeighbourhood: g.collectNeighbors() - if bLabelled: g.parseDomLabels() - # if bDetach: g.detachFromDOM() + if bLabelled: g.parseDocLabels() + if bDetach: g.detachFromDoc() lGraph.append(g) if iVerbose: traceln("\t\t (%d nodes, %d edges)"%(len(g.lNode), len(g.lEdge)) ) @@ -143,7 +134,7 @@ def getSinglePages(cls # --------------------------------------------------------------------------------------------------------- @classmethod - def _iter_Page_DomNode(cls, doc): + def _iter_Page_DocNode(cls, doc): """ Parse a Multi-pageXml DOM, by page @@ -162,6 +153,6 @@ def _iter_Page_DomNode(cls, doc): page = Page(pnum, pagecnt, iPageWidth, iPageHeight, cls=None, domnode=ndPage, domid=ndPage.get("id")) yield (pnum, page, ndPage) - return + return diff --git a/TranskribusDU/graph/NodeType.py b/TranskribusDU/graph/NodeType.py index 3593307..b5e0282 100644 --- a/TranskribusDU/graph/NodeType.py +++ b/TranskribusDU/graph/NodeType.py @@ -10,18 +10,7 @@ Copyright Xerox(C) 2016 JL. 
Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. The READ project has received funding @@ -102,13 +91,13 @@ def setXpathExpr(self, o): """ set any Xpath related information to extract the nodes from an XML file """ - raise Exception("Method must be overridden") + raise Exception("Method must be overridden for XML input") def getXpathExpr(self): """ get any Xpath related information to extract the nodes from an XML file """ - raise Exception("Method must be overridden") + raise Exception("Method must be overridden for XML input") def getLabelNameList(self): """ @@ -117,7 +106,7 @@ def getLabelNameList(self): return self.lsLabel @classmethod - def parseDomNodeLabel(cls, domnode): + def parseDocNodeLabel(cls, graph_node): """ return the internal label of the DOM node raise a ValueError if the label is missing while bOther was not True, or if the label is neither a valid one nor an ignored one @@ -125,7 +114,7 @@ def parseDomNodeLabel(cls, domnode): raise Exception("Method must be overridden") @classmethod - def setDomNodeLabel(cls, node, sLabel): + def setDocNodeLabel(cls, graph_node, sLabel): """ Set the DOM node label in the format-dependent way """ diff --git a/TranskribusDU/graph/NodeType_DSXml.py b/TranskribusDU/graph/NodeType_DSXml.py index 4157b21..06cc5c7 100644 --- a/TranskribusDU/graph/NodeType_DSXml.py +++ b/TranskribusDU/graph/NodeType_DSXml.py @@ -5,18 +5,7 @@ Copyright Xerox(C) 2016 JL. Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. 
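The renamings above (parseDomNodeLabel to parseDocNodeLabel, setDomNodeLabel to setDocNodeLabel) also change the contract: the accessors now receive the graph node, which wraps the underlying DOM element as its .node attribute. A hypothetical minimal subclass, assuming the dXmlLabel2Label / dLabel2XmlLabel mappings that the existing node types maintain:

```python
from graph.NodeType import NodeType

class NodeType_AttrLabel(NodeType):
    """Hypothetical NodeType storing the label in a plain XML attribute."""
    sLabelAttr = "type"

    def parseDocNodeLabel(self, graph_node, defaultCls=None):
        # the graph node wraps the DOM element as .node
        sXmlLabel = graph_node.node.get(self.sLabelAttr)
        return self.dXmlLabel2Label.get(sXmlLabel, self.sDefaultLabel)

    def setDocNodeLabel(self, graph_node, sLabel):
        if sLabel != self.sDefaultLabel:
            graph_node.node.set(self.sLabelAttr, self.dLabel2XmlLabel[sLabel])
        return sLabel
```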
The READ project has received funding @@ -52,13 +41,13 @@ def getXpathExpr(self): """ return self.sxpNode - def parseDomNodeLabel(self, domnode, defaultCls=None): + def parseDocNodeLabel(self, graph_node, defaultCls=None): """ Parse and set the graph node label and return its class index raise a ValueError if the label is missing while bOther was not True, or if the label is neither a valid one nor an ignored one """ sLabel = self.sDefaultLabel - + domnode = graph_node.node sXmlLabel = None if domnode.get("title_oracle_best"): sXmlLabel = "title" @@ -80,12 +69,12 @@ def parseDomNodeLabel(self, domnode, defaultCls=None): return sLabel - def setDomNodeLabel(self, domnode, sLabel): + def setDocNodeLabel(self, graph_node, sLabel): """ Set the DOM node label in the format-dependent way """ if sLabel != self.sDefaultLabel: - domnode.set(sLabel, "yes") + graph_node.node.set(sLabel, "yes") return sLabel @@ -124,7 +113,7 @@ def _iter_GraphNode(self, doc, domNdPage, page): yield blk - raise StopIteration() + return def getAttributeInDepth(self, nd, attr): """ diff --git a/TranskribusDU/graph/NodeType_PageXml.py b/TranskribusDU/graph/NodeType_PageXml.py index d50a805..10b0bcd 100644 --- a/TranskribusDU/graph/NodeType_PageXml.py +++ b/TranskribusDU/graph/NodeType_PageXml.py @@ -5,18 +5,7 @@ Copyright Xerox(C) 2016 JL. Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. The READ project has received funding @@ -77,12 +66,13 @@ def getXpathExpr(self): """ return (self.sxpNode, self.sxpTextual) - def parseDomNodeLabel(self, domnode, defaultCls=None): + def parseDocNodeLabel(self, graph_node, defaultCls=None): """ Parse and set the graph node label and return its class index raise a ValueError if the label is missing while bOther was not True, or if the label is neither a valid one nor an ignored one """ sLabel = self.sDefaultLabel + domnode = graph_node.node try: try: sXmlLabel = PageXml.getCustomAttr(domnode, self.sCustAttr_STRUCTURE, self.sCustAttr2_TYPE) @@ -107,12 +97,12 @@ def parseDomNodeLabel(self, domnode, defaultCls=None): return sLabel - def setDomNodeLabel(self, domnode, sLabel): + def setDocNodeLabel(self, graph_node, sLabel): """ Set the DOM node label in the format-dependent way """ if sLabel != self.sDefaultLabel: - PageXml.setCustomAttr(domnode, self.sCustAttr_STRUCTURE, self.sCustAttr2_TYPE, self.dLabel2XmlLabel[sLabel]) + PageXml.setCustomAttr(graph_node.node, self.sCustAttr_STRUCTURE, self.sCustAttr2_TYPE, self.dLabel2XmlLabel[sLabel]) return sLabel @@ -175,7 +165,7 @@ def _iter_GraphNode(self, doc, domNdPage, page): yield blk - raise StopIteration() + return def _get_GraphNodeText(self, doc, domNdPage, ndBlock): """ @@ -188,7 +178,8 @@ def _get_GraphNodeText(self, doc, domNdPage, ndBlock): lNdText = ndBlock.xpath(self.sxpTextual, namespaces=self.dNS) if len(lNdText) != 1: if len(lNdText) <= 0: - raise ValueError("I found no useful TextEquiv below this node... 
\n%s"%etree.tostring(ndBlock))
+                # raise ValueError("I found no useful TextEquiv below this node... \n%s"%etree.tostring(ndBlock))
+                return None
            else:
                raise ValueError("I expected exactly one useful TextEquiv below this node. Got many... \n%s"%etree.tostring(ndBlock))
@@ -214,12 +205,13 @@ def setLabelAttribute(self, sAttrName="type"):
        """
        self.sLabelAttr = sAttrName

-    def parseDomNodeLabel(self, domnode, defaultCls=None):
+    def parseDocNodeLabel(self, graph_node, defaultCls=None):
        """
        Parse and set the graph node label and return its class index
        raise a ValueError if the label is missing while bOther was not True, or if the label is neither a valid one nor an ignored one
        """
        sLabel = self.sDefaultLabel
+        domnode = graph_node.node
        sXmlLabel = domnode.get(self.sLabelAttr)
        try:
            sLabel = self.dXmlLabel2Label[sXmlLabel]
@@ -241,12 +233,12 @@
        return sLabel

-    def setDomNodeLabel(self, domnode, sLabel):
+    def setDocNodeLabel(self, graph_node, sLabel):
        """
        Set the DOM node label in the format-dependent way
        """
        if sLabel != self.sDefaultLabel:
-            domnode.set(self.sLabelAttr, self.dLabel2XmlLabel[sLabel])
+            graph_node.node.set(self.sLabelAttr, self.dLabel2XmlLabel[sLabel])
        return sLabel

class NodeType_PageXml_type_woText(NodeType_PageXml_type):
@@ -299,13 +291,17 @@ def test_getset():
    """
+    class MyNode:
+        def __init__(self, nd): self.node = nd
+
    doc = etree.parse(BytesIO(sXml))
    nd = doc.getroot()
+    graph_node = MyNode(nd)
    obj = NodeType_PageXml("foo", ["page-number", "index"])
-    assert obj.parseDomNodeLabel(nd) == 'foo_page-number', obj.parseDomNodeLabel(nd)
-    assert obj.parseDomNodeLabel(nd, "toto") == 'foo_page-number'
-    assert obj.setDomNodeLabel(nd, "foo_index") == 'foo_index'
-    assert obj.parseDomNodeLabel(nd) == 'foo_index'
+    assert obj.parseDocNodeLabel(graph_node) == 'foo_page-number', obj.parseDocNodeLabel(graph_node)
+    assert obj.parseDocNodeLabel(graph_node) == 'foo_page-number'
+    assert obj.setDocNodeLabel(graph_node, "foo_index") == 'foo_index'
+    assert obj.parseDocNodeLabel(graph_node) == 'foo_index'

def test_getset2():
    from lxml import etree
@@ -316,12 +312,16 @@
    """
+    class MyNode:
+        def __init__(self, nd): self.node = nd
+
    doc = etree.parse(BytesIO(sXml))
    nd = doc.getroot()
+    graph_node = MyNode(nd)
    obj = NodeType_PageXml("foo", ["page-number", "index"], [""])
-    assert obj.parseDomNodeLabel(nd) == 'foo_OTHER'
-    assert obj.setDomNodeLabel(nd, "foo_index") == 'foo_index'
-    assert obj.parseDomNodeLabel(nd) == 'foo_index'
+    assert obj.parseDocNodeLabel(graph_node) == 'foo_OTHER'
+    assert obj.setDocNodeLabel(graph_node, "foo_index") == 'foo_index'
+    assert obj.parseDocNodeLabel(graph_node) == 'foo_index'

def test_getset3():
@@ -334,10 +334,14 @@
    """
+    class MyNode:
+        def __init__(self, nd): self.node = nd
+
    doc = etree.parse(BytesIO(sXml))
    nd = doc.getroot()
+    graph_node = MyNode(nd)
    obj = NodeType_PageXml("foo", ["page-number", "index"], [""], bOther=False)
    with pytest.raises(PageXmlException):
-        assert obj.parseDomNodeLabel(nd) == 'foo_OTHER'
-        assert obj.setDomNodeLabel(nd, "foo_index") == 'foo_index'
-        assert obj.parseDomNodeLabel(nd) == 'foo_index'
+        assert obj.parseDocNodeLabel(graph_node) == 'foo_OTHER'
+        assert obj.setDocNodeLabel(graph_node, "foo_index") == 'foo_index'
+        assert obj.parseDocNodeLabel(graph_node) == 'foo_index'
diff --git a/TranskribusDU/graph/NodeType_jsonOCR.py b/TranskribusDU/graph/NodeType_jsonOCR.py
new file mode 100644
index 0000000..ec7381d
--- /dev/null
+++
@@ -0,0 +1,142 @@
+import types
+import shapely.geometry as geom
+
+
+from common.trace import traceln
+from util.Polygon import Polygon
+
+from .NodeType import NodeType
+from .Block import Block
+from .Page import Page
+#from PIL import Image
+
+
+def defaultBBoxDeltaFun(w):
+    """
+    When we reduce the width or height of a bounding box, we use this function to compute the deltaX or deltaY,
+    which is applied on x1 and x2, or y1 and y2
+
+    For instance, for the horizontal axis
+        x1 = x1 + deltaFun(abs(x1-x2))
+        x2 = x2 - deltaFun(abs(x1-x2))
+    """
+    # historically, we were doing:
+    dx = max(w * 0.066, min(20, w / 3))
+    # for ABP table RV is doing: dx = max(w * 0.066, min(5, w/3)), so this function can be defined by the caller.
+    return dx
+
+
+
+class NodeType_jsonOCR(NodeType):
+    # where the labels can be found in the data
+    sCustAttr_STRUCTURE = "structure"
+    sCustAttr2_TYPE     = "type"
+
+
+    def __init__(self, sNodeTypeName, lsLabel, lsIgnoredLabel=None, bOther=True, BBoxDeltaFun=defaultBBoxDeltaFun):
+        NodeType.__init__(self, sNodeTypeName, lsLabel, lsIgnoredLabel, bOther)
+
+        self.BBoxDeltaFun = BBoxDeltaFun
+        if self.BBoxDeltaFun is not None and type(self.BBoxDeltaFun) != types.FunctionType:
+            raise Exception("Error: BBoxDeltaFun must be None or a function (or a lambda)")
+
+    def setXpathExpr(self, t_sxpNode_sxpTextual):
+        """
+        generalisation of XPATH to JSON format?
+        """
+        raise Exception("Not yet implemented")
+
+    def getXpathExpr(self):
+        """
+        generalisation of XPATH to JSON format?
+        """
+        raise Exception("Not yet implemented")
+
+    def parseDocNodeLabel(self, graph_node, defaultCls=None):
+        """
+        Parse and set the graph node label and return its class index
+        raise a ValueError if the label is missing while bOther was not True, or if the label is neither a valid one nor an ignored one
+        """
+        raise Exception("Not yet implemented")
+
+    def setDocNodeLabel(self, graph_node, sLabel):
+        """
+        Set the DOM node label in the format-dependent way
+        """
+#         print("setDocNodeLabel "
+#               , "%s '%s'" %(graph_node.getShape(), graph_node.text)
+#               , " ", sLabel)
+        pass
+        # raise Exception("Not yet implemented")
+
+    def setLabelAttribute(self, sAttrName="type"):
+        """
+        set the name of the Xml attribute that contains the label
+        """
+        self.sLabelAttr = sAttrName
+    # ---------------------------------------------------------------------------------------------------------
+
+#     def getPageWidthandHeight(self, filename):
+#
+#         img_dir = '/nfs/project/nmt/menus/data/GT500/'
+#         file = img_dir + filename.split('/')[-1][:-4] + 'jpg'
+#         im = Image.open(file)
+#         return im.size
+    # ---------------------------------------------------------------------------------------------------------
+
+    def getPointList(self, ndBlock):
+        coords = ndBlock['boundingBox']
+        x = coords[0]
+        y = coords[1]
+        w = coords[2]
+        h = coords[3]
+        return [(x,y), (x+w, y), (x+w, y+h), (x, y+h)]
+    # ----------------------------------------------------------------------------------------------------------
+
+    def _iter_GraphNode(self, doc, sFilename, page=None):
+        """
+        Get the json dict
+
+        iterator on the JSON dict, that returns nodes (of class Block)
+        """
+        # --- XPATH contexts
+
+        lNdBlock = doc['GlynchResults']
+        #page_w, page_h = self.getPageWidthandHeight(sFilename)
+        page = Page(1, 1, 1, 1)
+        for ndBlock in lNdBlock:
+            sText = ndBlock['label']
+            if sText is None:
+                sText = ""
+                traceln("Warning: no text in node")
+                # raise ValueError, "No text in node: %s"%ndBlock
+
+            # now we need to infer the
bounding box of that object
+            lXY = self.getPointList(ndBlock)  # the polygon
+            if lXY == []:
+                continue
+
+            plg = Polygon(lXY)
+            try:
+                x1, y1, x2, y2 = plg.fitRectangle()
+            except ZeroDivisionError:
+                x1, y1, x2, y2 = plg.getBoundingBox()
+            except ValueError:
+                x1, y1, x2, y2 = plg.getBoundingBox()
+
+            # we reduce a bit this rectangle, to avoid overlap
+            if not (self.BBoxDeltaFun is None):
+                w, h = x2 - x1, y2 - y1
+                dx = self.BBoxDeltaFun(w)
+                dy = self.BBoxDeltaFun(h)
+                x1, y1, x2, y2 = [int(round(v)) for v in [x1 + dx, y1 + dy, x2 - dx, y2 - dy]]
+
+            # TODO
+            orientation = 0  # no meaning for PageXml
+            classIndex = 0   # is computed later on
+            # and create a Block
+            blk = Block(page, (x1, y1, x2 - x1, y2 - y1), sText, orientation, classIndex, self, ndBlock, domid=None)
+            blk.setShape(geom.Polygon(lXY))
+            yield blk
+
+        return
diff --git a/TranskribusDU/graph/Page.py b/TranskribusDU/graph/Page.py
index b7de43d..a167799 100644
--- a/TranskribusDU/graph/Page.py
+++ b/TranskribusDU/graph/Page.py
@@ -32,7 +32,7 @@ def __init__(self, pnum, pagecnt, w, h, cls=None, domnode=None, domid=None):
         self.bEven = (pnum%2 == 0)
 
-    def detachFromDOM(self):
+    def detachFromDoc(self):
         """
         Erase any pointer to the DOM so that we can free it.
         """
diff --git a/TranskribusDU/graph/Transformer.py b/TranskribusDU/graph/Transformer.py
index d217ecc..379af5e 100644
--- a/TranskribusDU/graph/Transformer.py
+++ b/TranskribusDU/graph/Transformer.py
@@ -6,18 +6,7 @@
     Copyright Xerox(C) 2016 JL. Meunier
 
-    This program is free software: you can redistribute it and/or modify
-    it under the terms of the GNU General Public License as published by
-    the Free Software Foundation, either version 3 of the License, or
-    (at your option) any later version.
-
-    This program is distributed in the hope that it will be useful,
-    but WITHOUT ANY WARRANTY; without even the implied warranty of
-    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-    GNU General Public License for more details.
-
-    You should have received a copy of the GNU General Public License
-    along with this program.  If not, see .
+
 
     Developed for the EU project READ. The READ project has received funding
diff --git a/TranskribusDU/graph/Transformer_Generic.py b/TranskribusDU/graph/Transformer_Generic.py
new file mode 100644
index 0000000..d9cf431
--- /dev/null
+++ b/TranskribusDU/graph/Transformer_Generic.py
@@ -0,0 +1,362 @@
+# -*- coding: utf-8 -*-
+
+"""
+    Node and edge feature transformers to extract features from a graph
+
+    No link with DOM or JSON => named GENERIC
+
+
+    So, it should work equally well whatever the input format is (XML, JSON) since
+    it uses only the node and edge geometric attributes and text
+
+    Copyright Naver(C) 2019 JL. Meunier
+
+
+
+
+    Developed for the EU project READ. The READ project has received funding
+    from the European Union's Horizon 2020 research and innovation programme
+    under grant agreement No 674943.
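+
+    A minimal usage sketch (illustrative only; `lNode` is assumed to be a list
+    of parsed Block objects carrying the x1, y1, x2, y2 and text attributes
+    used by the transformers below):
+
+        from graph.Transformer_Generic import NodeTransformerXYWH
+        xywh = NodeTransformerXYWH()
+        A = xywh.transform(lNode)   # numpy array, one row of scaled geometry per block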
+ +""" + + + + +import numpy as np + +from .Transformer import Transformer +from .Edge import HorizontalEdge, VerticalEdge, SamePageEdge +import graph.Transformer_PageXml as Transformer_PageXml + +fALIGNMENT_COEF = 6.0 + +#------------------------------------------------------------------------------------------------------ +# HERE IS WHAT IS UNCHANGED, because ALWAYS COMPUTABEL FROM ANY GRAPH + +NodeTransformerText = Transformer_PageXml.NodeTransformerText +NodeTransformerTextEnclosed = Transformer_PageXml.NodeTransformerTextEnclosed +NodeTransformerNeighborText = Transformer_PageXml.NodeTransformerNeighborText +NodeTransformerTextLen = Transformer_PageXml.NodeTransformerTextLen +Node1ConstantFeature = Transformer_PageXml.Node1ConstantFeature + +EdgeTransformerByClassIndex = Transformer_PageXml.EdgeTransformerByClassIndex +EdgeTransformerSourceText = Transformer_PageXml.EdgeTransformerSourceText +EdgeTransformerTargetText = Transformer_PageXml.EdgeTransformerTargetText +EdgeTransformerClassShifter = Transformer_PageXml.EdgeTransformerClassShifter + + +#------------------------------------------------------------------------------------------------------ +class NodeTransformerXYWH(Transformer): + """ + In this version: + - we do not use the page width and height to normalise, but max(x2) and max(y2) + - width and heights also normalised by mean(w) and mean(h) + - we do not consider odd/even pages + + Added by Nitin + Updated by JL + we will get a list of block and need to send back what StandardScaler needs for in-place scaling, a numpy array!. + So we return a numpy array + """ + def transform(self, lNode): +# a = np.empty( ( len(lNode), 5 ) , dtype=np.float64) +# for i, blk in enumerate(lNode): a[i, :] = [blk.x1, blk.y2, blk.x2-blk.x1, blk.y2-blk.y1, blk.fontsize] #--- 2 3 4 5 6 + a = np.empty( ( len(lNode), 4+4+4 ) , dtype=np.float64) + + try: + max_x = max(o.x2 for o in lNode) + max_y = max(o.y2 for o in lNode) + mean_w = sum(abs(o.x1 - o.x2) for o in lNode) / len(lNode) + mean_h = sum(abs(o.y1 - o.y2) for o in lNode) / len(lNode) + except (ValueError, ZeroDivisionError): + max_x, max_y, mean_w, mean_h = None, None, None, None + + for i, blk in enumerate(lNode): + x1,y1,x2,y2 = blk.x1, blk.y1, blk.x2, blk.y2 + w = abs(x1-x2) / mean_w + h = abs(y1-y2) / mean_h + x1,x2 = x1/max_x, x2/max_x + y1,y2 = y1/max_y, y2/max_y + a[i, :] = [ x1, x1*x1 + , x2, x2*x2 + , y1, y1*y1 + , y2, y2*y2 + , w, w * w + , h, h * h + ] + return a + +#------------------------------------------------------------------------------------------------------ +class NodeTransformerNeighbors(Transformer): + """ + Characterising the neighborough + """ + def transform(self, lNode): +# a = np.empty( ( len(lNode), 5 ) , dtype=np.float64) +# for i, blk in enumerate(lNode): a[i, :] = [blk.x1, blk.y2, blk.x2-blk.x1, blk.y2-blk.y1, blk.fontsize] #--- 2 3 4 5 6 + a = np.empty( ( len(lNode), 2 + 2 ) , dtype=np.float64) + for i, blk in enumerate(lNode): + ax1, ay1 = blk.x1, blk.y1 + #number of horizontal/vertical/crosspage neighbors + a[i,0:2] = len(blk.lHNeighbor), len(blk.lVNeighbor) + #number of horizontal/vertical/crosspage neighbors occuring after this block + a[i,2:4] = (sum(1 for _b in blk.lHNeighbor if _b.x1 > ax1), + sum(1 for _b in blk.lVNeighbor if _b.y1 > ay1)) + return a + + +#------------------------------------------------------------------------------------------------------- +class NodeTransformerNeighborsAllText(Transformer): + """ + Collects all the text from the neighbors + On going ... 
+ """ + def transform(self, lNode): + txt_list=[] + #print('Node Text',lNode.text) + for _i,blk in enumerate(lNode): + txt_H = ' '.join(o.text for o in blk.lHNeighbor) + txt_V = ' '.join(o.text for o in blk.lVNeighbor) + txt_list.append(' '.join([txt_H, txt_V])) + + return txt_list + +#------------------------------------------------------------------------------------------------------ +class Node1HotTextFeatures(Transformer): + """ + we will get a list of block and return a one-hot encoding, directly + """ + def transform(self, lNode): + a = np.zeros( ( len(lNode), 6 ) , dtype=np.float64) + for i, blk in enumerate(lNode): + s = blk.text + a[i,0:7] = ( s.isalnum(), + s.isalpha(), + s.isdigit(), + s.islower(), + s.istitle(), + s.isupper()) + return a + + +class Node1ConstantFeature(Transformer): + """ + we generate one constant feature per node. (1.0) + """ + def transform(self, lNode): + return np.ones( ( len(lNode), 1 ) , dtype=np.float64) + + +#------------------------------------------------------------------------------------------------------ + +class EdgeBooleanAlignmentFeatures(EdgeTransformerClassShifter): + """ + we will get a list of edges and return a boolean array, directly + + We ignore the page information + + vertical-, horizontal- centered (at epsilon precision, epsilon typically being 5pt ?) + left-, top-, right-, bottom- justified (at epsilon precision) + """ + nbFEAT = 6 + + def transform(self, lEdge): + #DISC a = np.zeros( ( len(lEdge), 16 ) , dtype=np.float64) + a = - np.ones( ( len(lEdge), self._nbEdgeFeat ) , dtype=np.float64) + + try: + mean_h_A = sum(abs(o.A.y1 - o.A.y2) for o in lEdge) / len(lEdge) + mean_h_B = sum(abs(o.B.y1 - o.B.y2) for o in lEdge) / len(lEdge) + mean_h = (mean_h_A + mean_h_B) / 2 + except (ValueError, ZeroDivisionError): + mean_h = None + + # When to decide of an alignment + thH = mean_h / fALIGNMENT_COEF + + for i, edge in enumerate(lEdge): + z = self._dEdgeClassIndexShift[edge.__class__] + + A,B = edge.A, edge.B + + a[i,z:z+self.nbFEAT] = ( A.x1 + A.x2 - (B.x1 + B.x2) <= thH, # centering + A.y1 + A.y2 - (B.y1 + B.y2) <= thH, + abs(A.x1-B.x1) <= thH, #justified + abs(A.y1-B.y1) <= thH, + abs(A.x2-B.x2) <= thH, + abs(A.y2-B.y2) <= thH + ) + return a + + +#------------------------------------------------------------------------------------------------------ + + + +class EdgeNumericalSelector(EdgeTransformerClassShifter): + """ + getting rid of the hand-crafted thresholds + JLM Nov 2019: simpler and better (normalization must not change with direction for the 2 removed any direction features) + """ + nbFEAT = 6 + + def transform(self, lEdge): + #no font size a = np.zeros( ( len(lEdge), 5 ) , dtype=np.float64) +# a = np.zeros( ( len(lEdge), 7 ) , dtype=np.float64) + a = np.zeros( ( len(lEdge), self._nbEdgeFeat ) , dtype=np.float64) + + try: + mean_length_h = sum(o.length for o in lEdge if isinstance(o, HorizontalEdge)) / len(lEdge) + except ZeroDivisionError: + mean_length_h = None + try: + mean_length_v = sum(o.length for o in lEdge if not isinstance(o, HorizontalEdge)) / len(lEdge) + except ZeroDivisionError: + mean_length_v = None + + for i, edge in enumerate(lEdge): + z = self._dEdgeClassIndexShift[edge.__class__] + A,B = edge.A, edge.B + + #overlap + ovr = A.significantOverlap(B, 0) + try: + a[i, z+0] = ovr / (A.area() + B.area() - ovr) + except ZeroDivisionError: + pass + + # + na, nb = len(A.text), len(B.text) + lcs = lcs_length(A.text,na, B.text,nb) + try: + a[i, z+1] = float( lcs / (na+nb-lcs) ) + except ZeroDivisionError: + pass + + 
+            #new in READ: the length of a same-page edge
+            if isinstance(edge, SamePageEdge):
+                if isinstance(edge, HorizontalEdge):
+                    norm_length = edge.length / mean_length_h
+                    #                Horiz.       Vert.        Horiz.       Vert.        Horiz.                   Vert.
+                    a[i,z+2:z+8] = (edge.length, 0.0,         norm_length, 0.0,         norm_length*norm_length, 0.0                    )
+                else:
+                    norm_length = edge.length / mean_length_v
+                    a[i,z+2:z+8] = (0.0,         edge.length, 0.0,         norm_length, 0.0,                     norm_length*norm_length)
+
+        return a
+
+
+class EdgeNumericalSelector_noText(EdgeTransformerClassShifter):
+    nbFEAT = 7   # z+0 and the six length slots z+1 .. z+6 below
+
+    def transform(self, lEdge):
+        a = np.zeros( ( len(lEdge), self._nbEdgeFeat ) , dtype=np.float64)
+
+        try:
+            mean_length = sum(o.length for o in lEdge) / len(lEdge)
+        except ZeroDivisionError:
+            mean_length = None
+
+        for i, edge in enumerate(lEdge):
+            z = self._dEdgeClassIndexShift[edge.__class__]
+            A,B = edge.A, edge.B
+
+            #overlap
+            ovr = A.significantOverlap(B, 0)
+            try:
+                a[i, z+0] = ovr / (A.area() + B.area() - ovr)
+            except ZeroDivisionError:
+                pass
+
+            #new in READ: the length of a same-page edge
+            if isinstance(edge, SamePageEdge):
+                if isinstance(edge, VerticalEdge):
+                    norm_length = edge.length / mean_length
+                    #                Horiz.       Vert.        Horiz.       Vert.        Horiz.                   Vert.
+                    a[i,z+1:z+7] = (0.0,         edge.length, 0.0,         norm_length, 0.0,                     norm_length*norm_length)
+                else:
+                    norm_length = edge.length / mean_length
+                    a[i,z+1:z+7] = (edge.length, 0.0,         norm_length, 0.0,         norm_length*norm_length, 0.0                    )
+
+        return a
+
+
+#------------------------------------------------------------------------------------------------------
+
+class EdgeTypeFeature_HV(Transformer):
+    """
+    Only tells the type of edge: Horizontal or Vertical
+    """
+    def transform(self, lEdge):
+        a = np.zeros( (len(lEdge), 2), dtype=np.float64)
+        for i, edge in enumerate(lEdge):
+            #-- vertical / horizontal
+            if edge.__class__ == HorizontalEdge:
+                a[i,0] = 1.0
+            else:
+                assert edge.__class__ == VerticalEdge
+                a[i,1] = 1.0
+        return a
+
+
+
+
+# -----------------------------------------------------------------------------------------------------------------------------
+def _debug(lO, a):
+    for i,o in enumerate(lO):
+        print(o)
+        print(a[i])
+        print()
+
+def lcs_length(a,na, b,nb):
+    """
+    Compute the length of the longest common string. Very fast. JLM March 2016
+
+    NOTE: I did not compare against fastlcs...
+    """
+    #na, nb = len(a), len(b)
+    if nb < na: a, na, b, nb = b, nb, a, na
+    if na==0: return 0
+    na1 = na+1
+    curRow  = [0]*na1
+    prevRow = [0]*na1
+    range1a1 = range(1, na1)
+    for i in range(nb):
+        bi = b[i]
+        prevRow, curRow = curRow, prevRow
+        curRow[0] = 0
+        curRowj = 0
+        for j in range1a1:
+            if bi == a[j-1]:
+                curRowj = max(1+prevRow[j-1], prevRow[j], curRowj)
+            else:
+                curRowj = max(prevRow[j], curRowj)
+            curRow[j] = curRowj
+    return curRowj
+
+
+class NodeEdgeTransformer(Transformer):
+    """
+    we will get a list of list of edges ...
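+
+    For each block, the wrapped edge-list transformer is applied to
+    blk.edgeList and the per-edge feature rows are aggregated ('sum' or
+    'mean') into one row per node. A minimal sketch (t is any edge
+    transformer producing one feature row per edge):
+
+        net = NodeEdgeTransformer(t, agg_func='mean')
+        X = net.transform(lNode)   # shape: (len(lNode), nb of edge features)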
+ """ + def __init__(self,edge_list_transformer,agg_func='sum'): + self.agg_func=agg_func + self.edge_list_transformer=edge_list_transformer + + def transform(self,lNode): + x_all=[] + for _i, blk in enumerate(lNode): + x_edge_node = self.edge_list_transformer.transform(blk.edgeList) + if self.agg_func=='sum': + x_node=x_edge_node.sum(axis=0) + elif self.agg_func=='mean': + x_node=x_edge_node.mean(axis=0) + else: + raise ValueError('Invalid Argument',self.agg_func) + x_all.append(x_node) + return np.vstack(x_all) + + + + diff --git a/TranskribusDU/graph/Transformer_Logit.py b/TranskribusDU/graph/Transformer_Logit.py index e6dfb8b..8e147d3 100644 --- a/TranskribusDU/graph/Transformer_Logit.py +++ b/TranskribusDU/graph/Transformer_Logit.py @@ -5,18 +5,7 @@ Copyright Xerox(C) 2016 JL. Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. The READ project has received funding diff --git a/TranskribusDU/graph/Transformer_PageXml.py b/TranskribusDU/graph/Transformer_PageXml.py index 1b7d279..d4ffa1c 100644 --- a/TranskribusDU/graph/Transformer_PageXml.py +++ b/TranskribusDU/graph/Transformer_PageXml.py @@ -7,18 +7,7 @@ v2 March 2017 JLM - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. 
The READ project has received funding @@ -97,44 +86,58 @@ class NodeTransformerXYWH(Transformer): def transform(self, lNode): # a = np.empty( ( len(lNode), 5 ) , dtype=np.float64) # for i, blk in enumerate(lNode): a[i, :] = [blk.x1, blk.y2, blk.x2-blk.x1, blk.y2-blk.y1, blk.fontsize] #--- 2 3 4 5 6 - a = np.empty( ( len(lNode), 2+4+2+4 ) , dtype=np.float64) + a = np.empty( ( len(lNode), 2+4+4 ) , dtype=np.float64) for i, blk in enumerate(lNode): page = blk.page x1,y1,x2,y2 = blk.x1, blk.y1, blk.x2, blk.y2 w,h = float(page.w), float(page.h) - #Normalize by page with and height - xn1, yn1, xn2, yn2 = x1/w, y1/h, x2/w, y2/h + #Normalize by page with and height to range (-1, +1] + xn1, yn1, xn2, yn2 = 2*x1/w-1, 2*y1/h-1, 2*x2/w-1, 2*y2/h-1 #generate X-from-binding if page.bEven: - xb1, xb2 = w - x2 , w - x1 - xnb1, xnb2 = 1.0 - xn2 , 1.0 - xn1 + xnb1, xnb2 = -xn2 , -xn1 else: - xb1, xb2 = x1 , x2 - xnb1, xnb2 = xn1 , xn2 - a[i, :] = [xb1, xb2 , x1, y2, x2-x1, y2-y1 , xnb1, xnb2 , xn1, yn2, xn2-xn1, yn2-yn1] + xnb1, xnb2 = xn1 , xn2 + a[i, :] = [xnb1, xnb2 , xn1, yn1, xn2-xn1, yn2-yn1 , xn1*xn1, yn1*yn1, xn2*xn2, yn2*yn2] return a -class NodeTransformerXYWH_v2(Transformer): +#------------------------------------------------------------------------------------------------------ + +class NodeTransformerXYWH_NoPage(Transformer): """ + In this version: + - we do not use the page width and height to normalise, but max(x2) and max(y2) + - width and heights also normalised by mean(w) and mean(h) + - we do not consider odd/even pages + + Added by Nitin + Updated by JL we will get a list of block and need to send back what StandardScaler needs for in-place scaling, a numpy array!. - So we return a numpy array + So we return a numpy array """ def transform(self, lNode): # a = np.empty( ( len(lNode), 5 ) , dtype=np.float64) -# for i, blk in enumerate(lNode): a[i, :] = [blk.x1, blk.y2, blk.x2-blk.x1, blk.y2-blk.y1, blk.fontsize] #--- 2 3 4 5 6 - a = np.empty( ( len(lNode), 2+4+4 ) , dtype=np.float64) - for i, blk in enumerate(lNode): - page = blk.page +# for i, blk in enumerate(lNode): a[i, :] = [blk.x1, blk.y2, blk.x2-blk.x1, blk.y2-blk.y1, blk.fontsize] #--- 2 3 4 5 6 + a = np.empty( ( len(lNode), 4+4+4 ) , dtype=np.float64) + + try: + max_x = max(o.x2 for o in lNode) + max_y = max(o.y2 for o in lNode) + mean_w = sum(abs(o.x1 - o.x2) for o in lNode) / len(lNode) + mean_h = sum(abs(o.y1 - o.y2) for o in lNode) / len(lNode) + except (ValueError, ZeroDivisionError): + max_x, max_y, mean_w, mean_h = None, None, None, None + + for i, blk in enumerate(lNode): x1,y1,x2,y2 = blk.x1, blk.y1, blk.x2, blk.y2 - w,h = float(page.w), float(page.h) - #Normalize by page with and height to range (-1, +1] - xn1, yn1, xn2, yn2 = 2*x1/w-1, 2*y1/h-1, 2*x2/w-1, 2*y2/h-1 - #generate X-from-binding - if page.bEven: - xnb1, xnb2 = -xn2 , -xn1 - else: - xnb1, xnb2 = xn1 , xn2 - a[i, :] = [xnb1, xnb2 , xn1, yn1, xn2-xn1, yn2-yn1 , xn1*xn1, yn1*yn1, xn2*xn2, yn2*yn2] + w = abs(x1-x2) / mean_w + h = abs(y1-y2) / mean_h + x1,x2 = x1/max_x, x2/max_x + y1,y2 = y1/max_y, y2/max_y + a[i, :] = [ x1, x2, x1*x1, x2*x2 + , y1, y2, y1*y1, y2*y2 + , w , h , w *w , h *h + ] return a #------------------------------------------------------------------------------------------------------ @@ -176,26 +179,10 @@ def transform(self, lNode): txt_block.append(blk_neighbor.text) for blk_neighbor in blk.lVNeighbor: txt_block.append(blk_neighbor.text) - for blk_neighbor in blk.lVNeighbor: + for blk_neighbor in blk.lCPNeighbor: 
txt_block.append(blk_neighbor.text) - #txt_block +=" ".join([nd2.text for nd2 in blk.lHNeighbor]) - #txt_block +=" ".join([nd2.text for nd2 in blk.lVNeighbor]) - #txt_block +=" ".join([nd2.text for nd2 in blk.lCPNeighbor]) - - #print(txt_block) - ''' - for b in blk.lHNeighbor: - #why doing a if if _b.x1 > ax1) - txt_list.append(b.text) - for b in blk.lVNeighbor: - txt_list.append(b.text) - for b in blk.lCPNeighbor: - txt_list.append(b.text) - ''' txt_list.append(' '.join(txt_block)) - #print('TEXT List',txt_list) - print('LEN TEXT LIST',len(txt_list)) return txt_list @@ -459,34 +446,8 @@ def transform(self, lEdge): ) return a -class Edge1HotFeatures_noText(EdgeTransformerClassShifter): - """ - we will get a list of edges and return a boolean array, directly - - above/below, left/right, neither but on same page - same or consecutive pages - vertical-, horizontal- centered (at epsilon precision, epsilon typically being 5pt ?) - left-, top-, right-, bottom- justified (at epsilon precision) - TODO sequentiality of content - TODO crossing ruling-line - - noText = no textual feature - """ - - nbFEAT = 2 - def __init__(self, pageNumSequenciality, bMirrorPage=True): - EdgeTransformerClassShifter.__init__(self, bMirrorPage) - self.sqnc = pageNumSequenciality - - def transform(self, lEdge): - #a = np.zeros( ( len(lEdge), 3 + 17*3 ) , dtype=np.float64) - a = np.zeros( ( len(lEdge), self.nbFEAT), dtype=np.float64) - for i, edge in enumerate(lEdge): - #-- vertical / horizontal / virtual / cross-page / not-neighbor - a[i, :] = (1, self._dEdgeClassIndexShift[edge.__class__]) - return a #------------------------------------------------------------------------------------------------------ class EdgeBooleanFeatures(EdgeTransformerClassShifter): """ @@ -497,31 +458,6 @@ class EdgeBooleanFeatures(EdgeTransformerClassShifter): """ nbFEAT = 6 - def transform(self, lEdge): - #DISC a = np.zeros( ( len(lEdge), 16 ) , dtype=np.float64) - a = - np.ones( ( len(lEdge), self._nbEdgeFeat ) , dtype=np.float64) - for i, edge in enumerate(lEdge): - z = self._dEdgeClassIndexShift[edge.__class__] - - A,B = edge.A, edge.B - a[i,z:z+self.nbFEAT] = ( A.x1 + A.x2 - (B.x1 + B.x2) <= 2 * fEPSILON, # centering - A.y1 + A.y2 - (B.y1 + B.y2) <= 2 * fEPSILON, - abs(A.x1-B.x1) <= fEPSILON, #justified - abs(A.y1-B.y1) <= fEPSILON, - abs(A.x2-B.x2) <= fEPSILON, - abs(A.y2-B.y2) <= fEPSILON - ) - return a - -class EdgeBooleanFeatures_v2(EdgeTransformerClassShifter): - """ - we will get a list of edges and return a boolean array, directly - - vertical-, horizontal- centered (at epsilon precision, epsilon typically being 5pt ?) - left-, top-, right-, bottom- justified (at epsilon precision) - """ - nbFEAT = 6 - def transform(self, lEdge): #DISC a = np.zeros( ( len(lEdge), 16 ) , dtype=np.float64) a = - np.ones( ( len(lEdge), self._nbEdgeFeat ) , dtype=np.float64) @@ -541,17 +477,14 @@ def transform(self, lEdge): ) return a + #------------------------------------------------------------------------------------------------------ class EdgeNumericalSelector(EdgeTransformerClassShifter): """ - we will get a list of block and need to send back what StandardScaler needs for in-place scaling, a numpy array!. 
- - overlap size (ratio of intersection to union of surfaces) - max(overlap size, 5000) - identical content in [0, 1] as ratio of lcs to "union" - max( lcs, 25) + getting rid of the hand-crafted thresholds + JLM Nov 2019: simpler and better (normalization must not change with direction for the 2 removed any direction features) """ - nbFEAT = 11 + nbFEAT = 6 def transform(self, lEdge): #no font size a = np.zeros( ( len(lEdge), 5 ) , dtype=np.float64) @@ -567,129 +500,32 @@ def transform(self, lEdge): a[i, z+0] = ovr / (A.area() + B.area() - ovr) except ZeroDivisionError: pass - a[i, z+1] = min(ovr, 5000.0) # na, nb = len(A.text), len(B.text) lcs = lcs_length(A.text,na, B.text,nb) try: - a[i, z+2] = float( lcs / (na+nb-lcs) ) + a[i, z+1] = float( lcs / (na+nb-lcs) ) except ZeroDivisionError: pass - a[i, z+3:z+5] = min(lcs, 50.0), min(lcs, 100.0) - #a[i, z+4] = min(lcs, 100.0) #new in READ: the length of a same-page edge if isinstance(edge, SamePageEdge): if isinstance(edge, VerticalEdge): norm_length = edge.length / float(edge.A.page.h) norm_length2 = norm_length * norm_length - # Horiz. Vert. Any direction Horiz. Vert. Any direction - a[i,z+5:z+11] = (0.0 , norm_length , norm_length , 0.0 , norm_length2 , norm_length2) + # Horiz. Vert. Horiz. Vert. + a[i,z+2:z+6] = (0.0 , norm_length , 0.0 , norm_length2 ) else: norm_length = edge.length / float(edge.A.page.w) norm_length2 = norm_length * norm_length - a[i,z+5:z+11] = (norm_length, 0.0 , norm_length , norm_length2 , 0.0 , norm_length2) + a[i,z+2:z+6] = (norm_length, 0.0 , norm_length2 , 0.0 ) -# #fontsize -# a[i, z+5] = B.fontsize - A.fontsize -# a[i, z+6] = (B.fontsize+1) / (A.fontsize+1) - return a -class EdgeNumericalSelector_noText(EdgeTransformerClassShifter): - """ - we will get a list of block and need to send back what StandardScaler needs for in-place scaling, a numpy array!. - overlap size (ratio of intersection to union of surfaces) - max(overlap size, 5000) - identical content in [0, 1] as ratio of lcs to "union" - max( lcs, 25) - noText = no textual features - """ - nbFEAT = 8 - - def transform(self, lEdge): - #no font size a = np.zeros( ( len(lEdge), 5 ) , dtype=np.float64) -# a = np.zeros( ( len(lEdge), 7 ) , dtype=np.float64) - a = np.zeros( ( len(lEdge), self._nbEdgeFeat ) , dtype=np.float64) - for i, edge in enumerate(lEdge): - z = self._dEdgeClassIndexShift[edge.__class__] - A,B = edge.A, edge.B - - #overlap - ovr = A.significantOverlap(B, 0) - try: - a[i, z+0] = ovr / (A.area() + B.area() - ovr) - except ZeroDivisionError: - pass - a[i, z+1] = min(ovr, 5000.0) - - #new in READ: the length of a same-page edge - if isinstance(edge, SamePageEdge): - if isinstance(edge, VerticalEdge): - norm_length = edge.length / float(edge.A.page.h) - norm_length2 = norm_length * norm_length - # Horiz. Vert. Any direction Horiz. Vert. 
Any direction - a[i,z+2:z+8] = (0.0 , norm_length , norm_length , 0.0 , norm_length2 , norm_length2) - else: - norm_length = edge.length / float(edge.A.page.w) - norm_length2 = norm_length * norm_length - a[i,z+2:z+8] = (norm_length, 0.0 , norm_length , norm_length2 , 0.0 , norm_length2) - - return a - -class EdgeNumericalSelector_v2(EdgeTransformerClassShifter): - """ - getting rid of the hand-crafted thresholds - """ - nbFEAT =8 - - def transform(self, lEdge): - #no font size a = np.zeros( ( len(lEdge), 5 ) , dtype=np.float64) -# a = np.zeros( ( len(lEdge), 7 ) , dtype=np.float64) - a = np.zeros( ( len(lEdge), self._nbEdgeFeat ) , dtype=np.float64) - for i, edge in enumerate(lEdge): - z = self._dEdgeClassIndexShift[edge.__class__] - A,B = edge.A, edge.B - - #overlap - ovr = A.significantOverlap(B, 0) - try: - a[i, z+0] = ovr / (A.area() + B.area() - ovr) - except ZeroDivisionError: - pass - #a[i, z+1] = min(ovr, 5000.0) - - # - na, nb = len(A.text), len(B.text) - lcs = lcs_length(A.text,na, B.text,nb) - try: - a[i, z+1] = float( lcs / (na+nb-lcs) ) - except ZeroDivisionError: - pass - # a[i, z+3:z+5] = min(lcs, 50.0), min(lcs, 100.0) - - #new in READ: the length of a same-page edge - if isinstance(edge, SamePageEdge): - if isinstance(edge, VerticalEdge): - norm_length = edge.length / float(edge.A.page.h) - norm_length2 = norm_length * norm_length - # Horiz. Vert. Any direction Horiz. Vert. Any direction - a[i,z+2:z+8] = (0.0 , norm_length , norm_length , 0.0 , norm_length2 , norm_length2) - else: - norm_length = edge.length / float(edge.A.page.w) - norm_length2 = norm_length * norm_length - a[i,z+2:z+8] = (norm_length, 0.0 , norm_length , norm_length2 , 0.0 , norm_length2) - -# #fontsize -# a[i, z+5] = B.fontsize - A.fontsize -# a[i, z+6] = (B.fontsize+1) / (A.fontsize+1) - - return a - -class EdgeNumericalSelector_v2_noText(EdgeTransformerClassShifter): - nbFEAT = 7 +class EdgeNumericalSelector_noText(EdgeTransformerClassShifter): + nbFEAT = 5 def transform(self, lEdge): a = np.zeros( ( len(lEdge), self._nbEdgeFeat ) , dtype=np.float64) @@ -709,15 +545,16 @@ def transform(self, lEdge): if isinstance(edge, VerticalEdge): norm_length = edge.length / float(edge.A.page.h) norm_length2 = norm_length * norm_length - # Horiz. Vert. Any direction Horiz. Vert. Any direction - a[i,z+1:z+7] = (0.0 , norm_length , norm_length , 0.0 , norm_length2 , norm_length2) + # Horiz. Vert. Horiz. Vert. + a[i,z+1:z+5] = (0.0 , norm_length , 0.0 , norm_length2) else: norm_length = edge.length / float(edge.A.page.w) norm_length2 = norm_length * norm_length - a[i,z+1:z+7] = (norm_length, 0.0 , norm_length , norm_length2 , 0.0 , norm_length2) + a[i,z+1:z+5] = (norm_length, 0.0 , norm_length2 , 0.0) return a + #------------------------------------------------------------------------------------------------------ class EdgeTypeFeature_HV(Transformer): @@ -736,6 +573,7 @@ def transform(self, lEdge): return a + # ----------------------------------------------------------------------------------------------------------------------------- def _debug(lO, a): for i,o in enumerate(lO): diff --git a/TranskribusDU/graph/factorial/FactorialGraph.py b/TranskribusDU/graph/factorial/FactorialGraph.py index 51aa094..e66007b 100644 --- a/TranskribusDU/graph/factorial/FactorialGraph.py +++ b/TranskribusDU/graph/factorial/FactorialGraph.py @@ -5,18 +5,7 @@ Copyright Naver(C) 2018 JL. 
Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. The READ project has received funding diff --git a/TranskribusDU/graph/factorial/FactorialGraph_MultiContinuousPageXml.py b/TranskribusDU/graph/factorial/FactorialGraph_MultiContinuousPageXml.py index d6031fb..87b1e75 100644 --- a/TranskribusDU/graph/factorial/FactorialGraph_MultiContinuousPageXml.py +++ b/TranskribusDU/graph/factorial/FactorialGraph_MultiContinuousPageXml.py @@ -5,18 +5,7 @@ Copyright Xerox(C) 2016 JL. Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. The READ project has received funding @@ -53,7 +42,7 @@ def __init__(self, lNode = [], lEdge = []): EdgeTransformerClassShifter.setDefaultEdgeClass([HorizontalEdge, VerticalEdge, CrossPageEdge, CrossMirrorContinuousPageVerticalEdge]) - def parseXmlFile(self, sFilename, iVerbose=0): + def parseDocFile(self, sFilename, iVerbose=0): """ Load that document as a CRF Graph. Also set the self.doc variable! @@ -68,7 +57,7 @@ def parseXmlFile(self, sFilename, iVerbose=0): nodeType0 = self.getNodeTypeList()[0] #all nodes have same type - for pnum, page, domNdPage in self._iter_Page_DomNode(self.doc): + for pnum, page, domNdPage in self._iter_Page_DocNode(self.doc): #now that we have the page, let's create the node for each type! lPageNode = list() #setPageNdDomId = set() #the set of DOM id diff --git a/TranskribusDU/graph/factorial/FactorialGraph_MultiPageXml.py b/TranskribusDU/graph/factorial/FactorialGraph_MultiPageXml.py index 904b2d0..2e8e96c 100644 --- a/TranskribusDU/graph/factorial/FactorialGraph_MultiPageXml.py +++ b/TranskribusDU/graph/factorial/FactorialGraph_MultiPageXml.py @@ -5,18 +5,7 @@ Copyright Xerox(C) 2016 JL. Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. 
The READ project has received funding @@ -70,7 +59,7 @@ def __init__(self, lNode = [], lEdge = []): if nt.getXpathExpr() != nt0.getXpathExpr(): raise ValueError("FactorialCRF requires all NodeType to have same Xpath selection expressions.") # --- Labels ---------------------------------------------------------- - def parseDomLabels(self): + def parseDocLabels(self): """ Parse the label of the graph from the dataset, and set the node label return the set of observed class (set of integers in N+) @@ -81,7 +70,7 @@ def parseDomLabels(self): setSeensLabels = set() for nd in self.lNode: try: - lcls = [self._dClsByLabel[nodeType.parseDomNodeLabel(nd.node)] for nodeType in self.getNodeTypeList()] + lcls = [self._dClsByLabel[nodeType.parseDocNodeLabel(nd)] for nodeType in self.getNodeTypeList()] except KeyError: raise ValueError("Page %d, unknown label in %s (Known labels are %s)"%(nd.pnum, str(nd.node), self._dClsByLabel)) nd.cls = lcls @@ -89,7 +78,7 @@ def parseDomLabels(self): setSeensLabels.add(cls) return setSeensLabels - def setDomLabels(self, Y): + def setDocLabels(self, Y): """ Set the labels of the graph nodes from the Y matrix return the DOM @@ -102,11 +91,11 @@ def setDomLabels(self, Y): for nodeType in self.getNodeTypeList(): for i,nd in enumerate(self.lNode): sLabel = self._dLabelByCls[ Y[zeroType+i] ] - nodeType.setDomNodeLabel(nd.node, sLabel) + nodeType.setDocNodeLabel(nd, sLabel) zeroType += nbNode return self.doc - def parseXmlFile(self, sFilename, iVerbose=0): + def parseDocFile(self, sFilename, iVerbose=0): """ Load that document as a CRF Graph. Also set the self.doc variable! @@ -124,7 +113,7 @@ def parseXmlFile(self, sFilename, iVerbose=0): nodeType0 = self.getNodeTypeList()[0] #all nodes have same type - for pnum, page, domNdPage in self._iter_Page_DomNode(self.doc): + for pnum, page, domNdPage in self._iter_Page_DocNode(self.doc): #now that we have the page, let's create the node for each type! lPageNode = list() #setPageNdDomId = set() #the set of DOM id diff --git a/TranskribusDU/graph/factorial/FactorialGraph_MultiPageXml_Scaffold.py b/TranskribusDU/graph/factorial/FactorialGraph_MultiPageXml_Scaffold.py index 876723d..eb4a42b 100644 --- a/TranskribusDU/graph/factorial/FactorialGraph_MultiPageXml_Scaffold.py +++ b/TranskribusDU/graph/factorial/FactorialGraph_MultiPageXml_Scaffold.py @@ -5,18 +5,7 @@ Copyright Xerox(C) 2016 JL. Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. 
The READ project has received funding
diff --git a/TranskribusDU/graph/pkg_GraphBinaryConjugateSegmenter/GraphBinaryConjugateSegmenter.py b/TranskribusDU/graph/pkg_GraphBinaryConjugateSegmenter/GraphBinaryConjugateSegmenter.py
new file mode 100644
index 0000000..1f24b66
--- /dev/null
+++ b/TranskribusDU/graph/pkg_GraphBinaryConjugateSegmenter/GraphBinaryConjugateSegmenter.py
@@ -0,0 +1,263 @@
+# -*- coding: utf-8 -*-
+
+"""
+    Train, test, predict steps for a graph-based model using a binary conjugate
+    (two classes on the primal edges)
+
+    Structured machine learning, currently using graph-CRF or Edge Convolution Network
+
+    Copyright NAVER(C) 2019 JL. Meunier
+
+
+
+
+    Developed for the EU project READ. The READ project has received funding
+    from the European Union's Horizon 2020 research and innovation programme
+    under grant agreement No 674943.
+
+"""
+import numpy as np
+
+from graph.GraphConjugateSegmenter import GraphConjugateSegmenter
+from graph.Cluster import ClusterList, Cluster
+
+class GraphBinaryConjugateSegmenter(GraphConjugateSegmenter):
+    """
+    What is independent from the input format (XML or JSON, currently)
+    """
+    _balancedWeights = False   # Uniform does same or better, in general
+
+    lEdgeLabel  = ["continue", "break"]
+    nbEdgeLabel = 2
+
+    sOutputAttribute = "DU_cluster"
+
+    def __init__(self, sOuputXmlAttribute=None):
+        """
+        a graph-based segmenter model
+
+        the cluster index of each object will be stored in an XML attribute of the given name
+        """
+        super(GraphBinaryConjugateSegmenter, self).__init__()
+
+        self.dCluster = None
+        self.sClusteringAlgo = None
+
+        if sOuputXmlAttribute is not None:
+            GraphBinaryConjugateSegmenter.sOutputAttribute = sOuputXmlAttribute
+
+    def parseDocLabels(self):
+        """
+        Parse the label of the graph from the dataset, and set the node label
+        return the set of observed class (set of integers in N+)
+
+        Here, no check at all, because we just need to see if two labels are the same or not.
+        """
+        setSeensLabels = set()
+        for nd in self.lNode:
+            nodeType = nd.type
+            sLabel = nodeType.parseDocNodeLabel(nd)
+            try:
+                cls = self._dClsByLabel[sLabel]  #Here, if a node is not labelled, and no default label is set, then KeyError!!!
+            except KeyError:
+                cls = len(self._dClsByLabel)
+                self._dClsByLabel[sLabel] = cls
+            nd.cls = cls
+            setSeensLabels.add(cls)
+        return setSeensLabels
+
+    def computeEdgeLabels(self):
+        """
+        Given the loaded graph with labeled nodes, compute the edge labels.
+
+        This results in each edge having a .cls attribute.
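+
+        For instance, an edge whose two nodes share the same .cls gets
+        edge.cls = 0 ("continue"), otherwise edge.cls = 1 ("break"),
+        matching lEdgeLabel above.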
+
+        return the set of observed class (set of integers in N+)
+        """
+        setSeensLabels = set()
+        for edge in self.lEdge:
+            edge.cls = 0 if (edge.A.cls == edge.B.cls) else 1
+            setSeensLabels.add(edge.cls)
+        return setSeensLabels
+
+    # --- Clusters -------------------------------------------------------
+    def form_cluster(self, Y_proba, fThres=0.5, bAgglo=True):
+        """
+        Run a connected-component algorithm, optionally followed by agglomerative merging
+        Return a ClusterList; each Cluster is a list of node indices in lNode
+        """
+
+        if bAgglo:
+            lCluster = self.agglomerative_clustering(0.99,Y_proba)
+        else:
+            # need binary edge labels
+            Y = Y_proba.argmax(axis=1)
+            lCluster = self.connected_component(Y, fThres)
+
+        return lCluster
+
+    def connected_component(self, Y, fThres):
+        import sys
+        recursion_limit = sys.getrecursionlimit()
+        sys.setrecursionlimit(2*recursion_limit)
+
+        # create clusters of node based on edge binary labels
+        try:
+            def DFS(i):
+                visited[i] = 1
+                visited_index.append(i)
+                for j in range(nb_node):
+                    if visited[j] != 1 and ani[i][j]==1:
+                        DFS(j)   # DFS appends j itself; appending it here too would duplicate it
+                return visited_index
+
+            lCluster = ClusterList([], "CC")  # CC stands for Connected Component
+
+            nb_node = len(self.lNode)
+
+            # create an adjacency matrix
+            ani = np.zeros(shape=(nb_node, nb_node), dtype='int64')
+            for i, edge in enumerate(self.lEdge):
+                if Y[i] < fThres:
+                    # connected!
+                    iA, iB = edge.A._index, edge.B._index
+                    ani[iA,iB] = 1
+                    ani[iB,iA] = 1
+
+            visited = np.zeros(nb_node, dtype='int64')
+            for i in range(nb_node):
+                visited_index = []
+                if visited[i] == 0:
+                    lCluster.append(Cluster(DFS(i)))
+        finally:
+            sys.setrecursionlimit(recursion_limit)
+
+        return lCluster
+
+
+    def getEdges(self,lEdges,Y_proba):
+        """
+        return a dictionary of edges
+        """
+        #lLabels=['continue','break']
+        dEdges={0:{},1:{} }
+        Y = Y_proba.argmax(axis=1)
+        for i, edge in enumerate(lEdges):
+            #type(edge) in [HorizontalEdge, VerticalEdge]:
+            #cls = Y[i]
+            dEdges[ Y[i] ] [(edge.A,edge.B)]= Y_proba[i, Y[i]]
+        return dEdges
+
+    def distance(self,c1,c2,relmat):
+        """
+        compute the "distance" between an element and a cluster
+        distance = nbOk, nbBad edges
+        return distance
+        """
+        iBad = 0
+        iOK = 0
+        for p in c1:
+            for pp in c2:
+                try:
+                    if relmat[0][(p,pp)] >= 0.5:
+                        iOK += relmat[0][(p,pp)]
+                    else:iBad += relmat[0][(p,pp)]
+                except KeyError:#no edge
+                    pass
+                try:
+                    if relmat[1][(p,pp)] >= 0.5:
+                        iBad += relmat[1][(p,pp)]
+                    # else:iOK += 1 # possible?
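+                # Note on the bookkeeping above: relmat[0] holds the
+                # probabilities of edges predicted "continue", relmat[1]
+                # those predicted "break", as built by getEdges(); a
+                # confident continue edge raises iOK, a confident break
+                # edge raises iBad.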
+                except KeyError:#no edge
+                    pass
+        return iOK,iBad
+
+    def mergeCluster(self,lc,relmat):
+        """
+        for each cluster: compute score with all other clusters
+        need to differentiate between H and V ??29/08/2019
+        """
+
+        lDist = {}
+        for i,c in enumerate(lc):
+            for j,c2 in enumerate(lc[i+1:]):
+                dist = self.distance(c,c2,relmat)
+                if dist != (0,0):
+                    # print(c,c2,dist)
+                    if dist[0] > dist[1]:
+                        lDist[(i,i+j+1)] = dist[0] - dist[1]
+        # sort
+        # merge if dist
+        sorted_x = sorted(lDist.items(), key=lambda v:v[1],reverse=True)
+        ltbdel=[]
+        lSeen=[]
+        for p,score in sorted_x:
+            a=p[0];b=p[1]
+            if lc[b] not in ltbdel:
+                if lc[a] in lSeen or lc[b] in lSeen:
+                    pass
+                else:
+                    lSeen.append(lc[a])
+                    lSeen.append(lc[b])
+                    lc[a].extend(lc[b][:])
+                    ltbdel.append(lc[b])
+        [lc.remove(x) for x in ltbdel]
+
+
+        return lc, ltbdel !=[]
+
+
+    def assessCluster(self,c,relmat):
+        """
+        coherent score
+        """
+        iBad = 0
+        iOK = 0
+        for i,p in enumerate(c):
+            for pp in c[i+1:]:
+                try:
+                    if relmat[0][(p,pp)] >= 0.5:   # 0 = "continue" (getEdges keys by predicted class index)
+                        iOK += 1   #relmat['continue'][(p.getAttribute('id'),pp.getAttribute('id'))]
+                    else:iBad += 1 #relmat['continue'][(p.getAttribute('id'),pp.getAttribute('id'))]
+                except KeyError:#no edge
+                    pass
+                try:
+                    if relmat[1][(p,pp)] >= 0.5:   # 1 = "break"
+                        iBad += 1  #relmat['break'][(p.getAttribute('id'),pp.getAttribute('id'))]
+                    # else:iOK += 1
+                except KeyError:#no edge
+                    pass
+        return iOK,iBad
+
+    def clusterPlus(self,lCluster,dEdges):
+        """
+        merge cluster as long as new clusters are created
+        """
+        bGo=True
+        while bGo:
+            lCluster,bGo = self.mergeCluster(lCluster,dEdges)
+
+        lCluster.sAlgo = 'agglo'
+
+        return lCluster
+
+    def agglomerative_clustering(self,fTH,Y_proba):
+        """
+        fTH : threshold used for initial connected components run
+        Y_proba : edge prediction
+        Algo: perform a cc with fTH
+        merge clusters which share coherent set of edges (iteratively)
+
+        return: set of clusters
+        """
+        Y = Y_proba.argmax(axis=1)
+        lClusterCC = self.connected_component(Y,fTH)   # renamed: must not shadow the ClusterList class
+
+        dEdges = self.getEdges(self.lEdge,Y_proba)
+        lCluster = self.clusterPlus(lClusterCC,dEdges)
+
+        return lCluster
+
+
+
diff --git a/TranskribusDU/graph/pkg_GraphBinaryConjugateSegmenter/GraphBinaryConjugateSegmenter_DOM.py b/TranskribusDU/graph/pkg_GraphBinaryConjugateSegmenter/GraphBinaryConjugateSegmenter_DOM.py
new file mode 100644
index 0000000..747d56d
--- /dev/null
+++ b/TranskribusDU/graph/pkg_GraphBinaryConjugateSegmenter/GraphBinaryConjugateSegmenter_DOM.py
@@ -0,0 +1,117 @@
+# -*- coding: utf-8 -*-
+
+"""
+    Train, test, predict steps for a graph-based model using a binary conjugate
+    (two classes on the primal edges)
+
+    Structured machine learning, currently using graph-CRF or Edge Convolution Network
+
+    Copyright NAVER(C) 2019 JL. Meunier
+
+
+
+
+    Developed for the EU project READ. The READ project has received funding
+    from the European Union's Horizon 2020 research and innovation programme
+    under grant agreement No 674943.
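+
+    A rough sketch of the intended flow (illustrative; g is a parsed graph of
+    this class and Y_proba the per-edge probabilities of a trained model):
+
+        # training side: g.parseDocLabels(); g.computeEdgeLabels()
+        lCluster = g.form_cluster(Y_proba)   # connected components (+ agglomerative merging)
+        g.addClusterToDoc(lCluster)          # fills the DU_cluster attribute of each node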
+ +""" +import lxml.etree as etree + +from xml_formats.PageXml import PageXml +from util.Shape import ShapeLoader + +from graph.Graph_DOM import Graph_DOM +from .GraphBinaryConjugateSegmenter import GraphBinaryConjugateSegmenter + + + +class GraphBinaryConjugateSegmenter_DOM(GraphBinaryConjugateSegmenter, Graph_DOM): + # --- NODE TYPES and LABELS + _lNodeType = [] #the list of node types for this class of graph + _bMultitype = False # equivalent to len(_lNodeType) > 1 + _dLabelByCls = None #dictionary across node types + _dClsByLabel = None #dictionary across node types + _nbLabelTot = 0 #total number of labels + + def __init__(self, lNode = [], lEdge = [], sOuputXmlAttribute=None): + GraphBinaryConjugateSegmenter.__init__(self, sOuputXmlAttribute=sOuputXmlAttribute) + Graph_DOM.__init__(self, lNode, lEdge) + + def addClusterToDoc(self, lCluster): + """ + DOM version + """ + for num, lNodeIdx in enumerate(lCluster): + for ndIdx in lNodeIdx: + node = self.lNode[ndIdx] + node.node.set(self.sOutputAttribute, "%d"%num) + + self.addClusterToDom(lCluster, sAlgo=lCluster.sAlgo) + return + + def addEdgeToDoc(self, Y_proba): + """ + To display the graph conveniently we add new Edge elements + + # for y_p, x_u, in zip(lY_pred, [X]): + # edges = x_u[1][:int(len(x_u[1])/2)] + # for i, (p,ie) in enumerate(zip(y_p, edges)): + # print(p, g.lNode[ie[0]].text,g.lNode[ie[1]].text, g.lEdge[i]) + """ + if self.lNode: + ndPage = self.lNode[0].page.node + ndPage.append(etree.Comment("\nEdges labeled by the conjugate graph\n")) + Y = Y_proba.argmax(axis=1) + for i, edge in enumerate(self.lEdge): + A, B = edge.A ,edge.B #shape.centroid, edge.B.shape.centroid + ndEdge = PageXml.createPageXmlNode("Edge") + try: + cls = Y[i] + ndEdge.set("label", self.lEdgeLabel[cls]) + ndEdge.set("proba", "%.3f" % Y_proba[i, cls]) + except IndexError: + # case of a conjugate graph without edge, so the edges + # of the original graph cannot be labelled + pass + ndEdge.set("src", edge.A.node.get("id")) + ndEdge.set("tgt", edge.B.node.get("id")) + ndEdge.set("type", edge.__class__.__name__) + ndPage.append(ndEdge) + ndEdge.tail = "\n" + PageXml.setPoints(ndEdge, [(A.x1, A.y1), (B.x1, B.y1)]) + + return + + def addClusterToDom(self, lCluster, bMoveContent=False, sAlgo="", pageNode=None): + """ + Add Cluster elements to the Page DOM node + """ + lNdCluster = [] + for name, lnidx in enumerate(lCluster): + #self.analysedCluster() + if pageNode is None: + for idx in lnidx: + pageNode = self.lNode[idx].page.node + break + pageNode.append(etree.Comment("\nClusters created by the conjugate graph\n")) + + ndCluster = PageXml.createPageXmlNode('Cluster') + ndCluster.set("name", str(name)) + ndCluster.set("algo", sAlgo) + # add the space separated list of node ids + ndCluster.set("content", " ".join(self.lNode[_i].node.get("id") for _i in lnidx)) + coords = PageXml.createPageXmlNode('Coords') + ndCluster.append(coords) + spoints = ShapeLoader.minimum_rotated_rectangle([self.lNode[_i].node for _i in lnidx]) + coords.set('points',spoints) + pageNode.append(ndCluster) + ndCluster.tail = "\n" + + if bMoveContent: + # move the DOM node of the content to the cluster + for _i in lnidx: + ndCluster.append(self.lNode[_i].node) + lNdCluster.append(ndCluster) + + return lNdCluster diff --git a/TranskribusDU/graph/pkg_GraphBinaryConjugateSegmenter/GraphBinaryConjugateSegmenter_JsonOCR.py b/TranskribusDU/graph/pkg_GraphBinaryConjugateSegmenter/GraphBinaryConjugateSegmenter_JsonOCR.py new file mode 100644 index 0000000..882927d --- /dev/null +++ 
b/TranskribusDU/graph/pkg_GraphBinaryConjugateSegmenter/GraphBinaryConjugateSegmenter_JsonOCR.py
@@ -0,0 +1,59 @@
+# -*- coding: utf-8 -*-
+
+"""
+    Train, test, predict steps for a graph-based model using a binary conjugate
+    (two classes on the primal edges)
+
+    Structured machine learning, currently using graph-CRF or Edge Convolution Network
+
+    Copyright NAVER(C) 2019 JL. Meunier
+
+
+
+
+    Developed for the EU project READ. The READ project has received funding
+    from the European Union's Horizon 2020 research and innovation programme
+    under grant agreement No 674943.
+
+"""
+from common.trace import traceln
+
+from graph.Graph_JsonOCR import Graph_JsonOCR
+from .GraphBinaryConjugateSegmenter import GraphBinaryConjugateSegmenter
+
+
+class GraphBinaryConjugateSegmenter_jsonOCR(GraphBinaryConjugateSegmenter, Graph_JsonOCR):
+    # --- NODE TYPES and LABELS
+    _lNodeType   = []    #the list of node types for this class of graph
+    _bMultitype  = False # equivalent to len(_lNodeType) > 1
+    _dLabelByCls = None  #dictionary across node types
+    _dClsByLabel = None  #dictionary across node types
+    _nbLabelTot  = 0     #total number of labels
+
+    bWARN_TODO_addClusterToDoc = True
+    bWARN_TODO_addEdgeToDoc    = True
+
+    def __init__(self, lNode = [], lEdge = [], sOuputXmlAttribute=None):
+        GraphBinaryConjugateSegmenter.__init__(self, sOuputXmlAttribute=sOuputXmlAttribute)
+        Graph_JsonOCR.__init__(self, lNode, lEdge)
+
+    def addClusterToDoc(self, dCluster):
+        """
+        JSON OCR version
+        """
+        # TODO
+        if self.bWARN_TODO_addClusterToDoc:
+            traceln(self, " addClusterToDoc ", "not implemented")
+            self.bWARN_TODO_addClusterToDoc = False
+        return
+
+
+    def addEdgeToDoc(self, Y_proba):
+        """
+        JSON OCR version
+        """
+        # TODO
+        if self.bWARN_TODO_addEdgeToDoc:
+            traceln(self, " addEdgeToDoc ", "not implemented")
+            self.bWARN_TODO_addEdgeToDoc = False
+        return
diff --git a/TranskribusDU/graph/pkg_GraphBinaryConjugateSegmenter/MultiSinglePageXml.py b/TranskribusDU/graph/pkg_GraphBinaryConjugateSegmenter/MultiSinglePageXml.py
new file mode 100644
index 0000000..a32e622
--- /dev/null
+++ b/TranskribusDU/graph/pkg_GraphBinaryConjugateSegmenter/MultiSinglePageXml.py
@@ -0,0 +1,28 @@
+# -*- coding: utf-8 -*-
+
+"""
+    Multi single PageXml graph in conjugate mode
+
+    Copyright NAVER(C) 2019 JL. Meunier
+"""
+
+from .GraphBinaryConjugateSegmenter_DOM import GraphBinaryConjugateSegmenter_DOM
+from graph.Graph_Multi_SinglePageXml import Graph_MultiSinglePageXml
+
+
+class MultiSinglePageXml(
+          GraphBinaryConjugateSegmenter_DOM
+        , Graph_MultiSinglePageXml):
+    """
+    Multi single PageXml graph in conjugate mode
+    """
+    # --- NODE TYPES and LABELS
+    _lNodeType   = []    #the list of node types for this class of graph
+    _bMultitype  = False # equivalent to len(_lNodeType) > 1
+    _dLabelByCls = None  #dictionary across node types
+    _dClsByLabel = None  #dictionary across node types
+    _nbLabelTot  = 0     #total number of labels
+
+    def __init__(self):
+        super(MultiSinglePageXml, self).__init__()
+
diff --git a/TranskribusDU/graph/pkg_GraphBinaryConjugateSegmenter/MultiSinglePageXml_Separator.py b/TranskribusDU/graph/pkg_GraphBinaryConjugateSegmenter/MultiSinglePageXml_Separator.py
new file mode 100644
index 0000000..9a6deb7
--- /dev/null
+++ b/TranskribusDU/graph/pkg_GraphBinaryConjugateSegmenter/MultiSinglePageXml_Separator.py
@@ -0,0 +1,23 @@
+# -*- coding: utf-8 -*-
+
+"""
+    Multi single PageXml graph in conjugate mode, exploiting SeparatorRegion
+    as additional edge features
+
+    Copyright NAVER(C) 2019
+
+    2019-08-20 JL.
Meunier +""" + +from .PageXmlSeparatorRegion import PageXmlSeparatorRegion +from .MultiSinglePageXml import MultiSinglePageXml + + +class MultiSinglePageXml_Separator(PageXmlSeparatorRegion, MultiSinglePageXml): + """ + Multi single PageXml graph in conjugate mode, exploting SeparatorRegion + as additional edge features + """ + def __init__(self): + super(MultiSinglePageXml_Separator, self).__init__() + diff --git a/TranskribusDU/graph/pkg_GraphBinaryConjugateSegmenter/PageXmlSeparatorRegion.py b/TranskribusDU/graph/pkg_GraphBinaryConjugateSegmenter/PageXmlSeparatorRegion.py new file mode 100644 index 0000000..347885a --- /dev/null +++ b/TranskribusDU/graph/pkg_GraphBinaryConjugateSegmenter/PageXmlSeparatorRegion.py @@ -0,0 +1,164 @@ +# -*- coding: utf-8 -*- + +""" + A class to load the SeparatorRegion of a PageXml to add features to the + edges of a graph conjugate used for segmentation. + + It specialises the _index method to add specific attributes to the edges + , so that the specific feature transformers can be used. + + Copyright NAVER(C) 2019 + + 2019-08-20 JL. Meunier +""" + +import numpy as np + +import shapely.geometry as geom +from shapely.prepared import prep +from rtree import index +from sklearn.pipeline import Pipeline +from sklearn.preprocessing.data import QuantileTransformer + +from common.trace import traceln +from util.Shape import ShapeLoader +from xml_formats.PageXml import PageXml + +from .GraphBinaryConjugateSegmenter_DOM import GraphBinaryConjugateSegmenter_DOM +from graph.Transformer import Transformer + + +class PageXmlSeparatorRegion(GraphBinaryConjugateSegmenter_DOM): + """ + Extension of a segmenter conjugate graph to exploit graphical separator + as additional edge features + """ + bVerbose = True + + def __init__(self): + super(PageXmlSeparatorRegion, self).__init__() + + def _index(self): + """ + This method is called before computing the Xs + We call it and right after, we compute the intersection of edge with SeparatorRegions + Then, feature extraction can reflect the crossing of edges and separators + """ + bFirstCall = super(PageXmlSeparatorRegion, self)._index() + + if bFirstCall: + # indexing was required + # , so first call + # , so we need to make the computation of edges crossing separators! 
+            self.addSeparatorFeature()
+
+        return bFirstCall   # preserve the parent's return value for callers
+
+    def addSeparatorFeature(self):
+        """
+        We load the graphical separators
+        Compute a set of shapely objects
+        In turn, for each edge, we compute the intersection with all separators
+
+        The edge features will be:
+        - boolean: at least crossing one separator
+        - number of crossing points
+        - span length of the crossing points
+        - average length of the crossed separators
+        - average distance between two crossings
+        """
+
+        # graphical separators
+        dNS = {"pc":PageXml.NS_PAGE_XML}
+        someNode = self.lNode[0]
+        ndPage = someNode.node.xpath("ancestor::pc:Page", namespaces=dNS)[0]
+        lNdSep = ndPage.xpath(".//pc:SeparatorRegion", namespaces=dNS)
+        loSep = [ShapeLoader.node_to_LineString(_nd) for _nd in lNdSep]
+
+        if self.bVerbose: traceln(" %d graphical separators"%len(loSep))
+
+        # make an indexed rtree
+        idx = index.Index()
+        for i, oSep in enumerate(loSep):
+            idx.insert(i, oSep.bounds)
+
+        # take each edge in turn and list the separators it crosses
+        nCrossing = 0
+        for edge in self.lEdge:
+            # bottom-left corner to bottom-left corner
+            oEdge = geom.LineString([(edge.A.x1, edge.A.y1), (edge.B.x1, edge.B.y1)])
+            prepO = prep(oEdge)
+            lCrossingPoints = []
+            fSepTotalLen = 0
+            for i in idx.intersection(oEdge.bounds):
+                # check each candidate in turn
+                oSep = loSep[i]
+                if prepO.intersects(oSep):
+                    fSepTotalLen += oSep.length
+                    oPt = oEdge.intersection(oSep)
+                    if type(oPt) != geom.Point:
+                        traceln('Intersection is not a point: skipping it')
+                    else:
+                        lCrossingPoints.append(oPt)
+
+            if lCrossingPoints:
+                nCrossing += 1
+                edge.bCrossingSep = True
+                edge.sep_NbCrossing = len(lCrossingPoints)
+                minx, miny, maxx, maxy = geom.MultiPoint(lCrossingPoints).bounds
+                edge.sep_SpanLen = abs(minx-maxx) + abs(miny-maxy)
+                edge.sep_AvgSpanSgmt = edge.sep_SpanLen / len(lCrossingPoints)
+                edge.sep_AvgSepLen = fSepTotalLen / len(lCrossingPoints)
+            else:
+                edge.bCrossingSep = False
+                edge.sep_NbCrossing = 0
+                edge.sep_SpanLen = 0
+                edge.sep_AvgSpanSgmt = 0
+                edge.sep_AvgSepLen = 0
+
+            #traceln((edge.A.domid, edge.B.domid, edge.bCrossingSep, edge.sep_NbCrossing, edge.sep_SpanLen, edge.sep_AvgSpanSgmt, edge.sep_AvgSepLen))
+
+
+        if self.bVerbose:
+            traceln(" %d (/ %d) edges crossing at least one graphical separator"%(nCrossing, len(self.lEdge)))
+
+
+class Separator_boolean(Transformer):
+    """
+    a boolean encoding indicating if the edge crosses a separator
+    """
+    def transform(self, lO):
+        nb = len(lO)
+        a = np.zeros((nb, 1), dtype=np.float64)
+        for i, o in enumerate(lO):
+            if o.bCrossingSep: a[i,0] = 1
+        return a
+
+    def __str__(self):
+        return "- Separator_boolean %s (#1)" % (self.__class__)
+
+
+class Separator_num(Pipeline):
+    """
+    Numeric separator-crossing features of an edge, quantiled
+    """
+    nQUANTILE = 16
+
+    class Selector(Transformer):
+        """
+        Characterising the separator crossings of an edge by their number, span and lengths
+        """
+        def transform(self, lO):
+            nb = len(lO)
+            a = np.zeros((nb, 4), dtype=np.float64)
+            for i, o in enumerate(lO):
+                a[i,:] = (o.sep_NbCrossing, o.sep_SpanLen, o.sep_AvgSpanSgmt, o.sep_AvgSepLen)
+            return a
+
+    def __init__(self, nQuantile=None):
+        self.nQuantile = Separator_num.nQUANTILE if nQuantile is None else nQuantile
+        Pipeline.__init__(self, [ ('geometry' , Separator_num.Selector())
+                                , ('quantiled', QuantileTransformer(n_quantiles=self.nQuantile, copy=False)) #use in-place scaling
+                                ])
+
+    def __str__(self):
+        return "- Separator_num %s (#4)" % (self.__class__)
+
diff --git a/TranskribusDU/graph/pkg_GraphBinaryConjugateSegmenter/__init__.py
b/TranskribusDU/graph/pkg_GraphBinaryConjugateSegmenter/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/TranskribusDU/graph/tests/test_Graph_MultiPageXml.py b/TranskribusDU/graph/tests/test_Graph_MultiPageXml.py index a65fc1a..2065152 100644 --- a/TranskribusDU/graph/tests/test_Graph_MultiPageXml.py +++ b/TranskribusDU/graph/tests/test_Graph_MultiPageXml.py @@ -48,7 +48,7 @@ def test_RectangleFitting(): #load the block of each page, keeping the list of blocks of previous page lPrevPageNode = None - for pnum, page, domNdPage in obj._iter_Page_DomNode(doc): + for pnum, page, domNdPage in obj._iter_Page_DocNode(doc): #now that we have the page, let's create the node for each type! lPageNode = list() setPageNdDomId = set() #the set of DOM id diff --git a/TranskribusDU/tasks/ABP_LA.py b/TranskribusDU/tasks/ABP_LA.py index 24c5882..7a4f6fe 100644 --- a/TranskribusDU/tasks/ABP_LA.py +++ b/TranskribusDU/tasks/ABP_LA.py @@ -12,18 +12,7 @@ copyright Xerox 2017 READ project - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. The READ project has received funding diff --git a/TranskribusDU/tasks/DU_CRF_Task.py b/TranskribusDU/tasks/DU_CRF_Task.py index 5684a01..b73026b 100755 --- a/TranskribusDU/tasks/DU_CRF_Task.py +++ b/TranskribusDU/tasks/DU_CRF_Task.py @@ -5,18 +5,7 @@ Copyright Xerox(C) 2016, 2017 JL. Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. The READ project has received funding diff --git a/TranskribusDU/tasks/DU_ECN_Task.py b/TranskribusDU/tasks/DU_ECN_Task.py index 45c57fc..5344d2a 100644 --- a/TranskribusDU/tasks/DU_ECN_Task.py +++ b/TranskribusDU/tasks/DU_ECN_Task.py @@ -5,18 +5,7 @@ Copyright NAVER(C) 2018, 2019 Hervé Déjean, Jean-Luc Meunier, Animesh Prasad - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. 
The READ project has received funding @@ -79,7 +68,7 @@ def updateStandardOptionsParser(cls, parser): parser.add_option("--ecn" , dest='bECN' , action="store_true" , default=False, help="use ECN Models") parser.add_option("--ecn_config" , dest='ecn_json_config' , action="store", type="string" - , help="The Config files for the ECN Model") + , help="The config file for the ECN Model") parser.add_option("--baseline" , dest='bBaseline' , action="store_true" , default=False, help="report baseline method") @@ -95,11 +84,8 @@ def getStandardLearnerConfig(self, options): djson = json.loads(f.read()) if "ecn_learner_config" in djson: dLearnerConfig=djson["ecn_learner_config"] - elif "ecn_ensemble" in djson: - dLearnerConfig = djson else: raise Exception("Invalid config JSON file") - else: dLearnerConfig = { "name" :"default_8Lay1Conv", @@ -110,16 +96,154 @@ def getStandardLearnerConfig(self, options): "mu" : 0.0001, "nb_iter" : 1200, "nconv_edge" : 1, - "node_indim" : -1, + # until AUg 29, 2019 "node_indim" : -1, + "node_indim" : 64, + "num_layers" : 8, + "ratio_train_val" : 0.1, + "patience" : 50, + "activation_name" :"relu", + "stack_convolutions" : False + } + + if options.max_iter: + traceln(" - max_iter=%d" % options.max_iter) + dLearnerConfig["nb_iter"] = options.max_iter + + return dLearnerConfig + + +class DU_Ensemble_ECN_Task(DU_Task): + """ + DU learner based on Ensemble ECN + """ + VERSION = "ECN_v19" + + version = None # dynamically computed + + def __init__(self, sModelName, sModelDir + , sComment = None + , cFeatureDefinition = None + , dFeatureConfig = {} + , cModelClass = None + ): + super().__init__(sModelName, sModelDir, sComment, cFeatureDefinition, dFeatureConfig) + + global DU_Model_ECN + DU_Model_ECN = DU_Task.DYNAMIC_IMPORT('.DU_Model_ECN', 'gcn').DU_Ensemble_ECN + + self.cModelClass = DU_Model_ECN if cModelClass == None else cModelClass + assert issubclass(self.cModelClass, graph.GraphModel.GraphModel), "Your model class must inherit from graph.GraphModel.GraphModel" + + @classmethod + def getVersion(cls): + cls.version = "-".join([DU_Task.getVersion(), str(cls.VERSION)]) + return cls.version + + @classmethod + def updateStandardOptionsParser(cls, parser): + usage = """ + --ecn_ensemble Enable Edge Convolutional Network learning + --ecn_ensemble_config Path to the JSON configuration file (required!) for ECN learning + """ + #FOR GCN + parser.add_option("--ecn_ensemble" , dest='bECNEnsemble' + , action="store_true" + , default=False, help="use Ensemble ECN Models") + parser.add_option("--ecn_ensemble_config" , dest='ecn_ensemble_json_config' + , action="store", type="string" + , help="The config file for the Ensemble ECN Model") + return usage, parser + + def getStandardLearnerConfig(self, options): + """ + Once the command line has been parsed, you can get the standard learner + configuration dictionary from here. 
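+        The JSON configuration file must contain an "ecn_ensemble" entry whose
+        value lists the configuration of each member model (see the default
+        below, which ensembles four ECN models: two 'relu' and two 'tanh').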
+ """ + if options.ecn_ensemble_json_config: + with open(options.ecn_ensemble_json_config) as f: + djson = json.loads(f.read()) + if "ecn_ensemble" in djson: + dLearnerConfig = djson + else: + raise Exception("Invalid config JSON file for ensemble ECN model.") + else: + dLearnerConfig = { + "_comment":"1 relu and 1 tanh models, twice", + "ratio_train_val": 0.2, + "ecn_ensemble": [ + { + "type": "ecn", + "name" :"default_8Lay1Conv_A", + "dropout_rate_edge" : 0.2, + "dropout_rate_edge_feat": 0.2, + "dropout_rate_node" : 0.2, + "lr" : 0.0001, + "mu" : 0.0001, + "nb_iter" : 1200, + "nconv_edge" : 1, + "node_indim" : 64, "num_layers" : 8, "ratio_train_val" : 0.1, "patience" : 50, "activation_name" :"relu", "stack_convolutions" : False + }, + { + "type": "ecn", + "name" :"default_8Lay1Conv_A", + "dropout_rate_edge" : 0.2, + "dropout_rate_edge_feat": 0.2, + "dropout_rate_node" : 0.2, + "lr" : 0.0001, + "mu" : 0.0001, + "nb_iter" : 1200, + "nconv_edge" : 1, + "node_indim" : 64, + "num_layers" : 8, + "ratio_train_val" : 0.1, + "patience" : 50, + "activation_name" :"tanh", + "stack_convolutions" : False + }, + { + "type": "ecn", + "name" :"default_8Lay1Conv_B", + "dropout_rate_edge" : 0.2, + "dropout_rate_edge_feat": 0.2, + "dropout_rate_node" : 0.2, + "lr" : 0.0001, + "mu" : 0.0001, + "nb_iter" : 1200, + "nconv_edge" : 1, + "node_indim" : 64, + "num_layers" : 8, + "ratio_train_val" : 0.1, + "patience" : 50, + "activation_name" :"relu", + "stack_convolutions" : False + }, + { + "type": "ecn", + "name" :"default_8Lay1Conv_B", + "dropout_rate_edge" : 0.2, + "dropout_rate_edge_feat": 0.2, + "dropout_rate_node" : 0.2, + "lr" : 0.0001, + "mu" : 0.0001, + "nb_iter" : 1200, + "nconv_edge" : 1, + "node_indim" : 64, + "num_layers" : 8, + "ratio_train_val" : 0.1, + "patience" : 50, + "activation_name" :"tanh", + "stack_convolutions" : False } + ] + } if options.max_iter: traceln(" - max_iter=%d" % options.max_iter) dLearnerConfig["nb_iter"] = options.max_iter - return dLearnerConfig \ No newline at end of file + return dLearnerConfig diff --git a/TranskribusDU/tasks/DU_FactorialCRF_Task.py b/TranskribusDU/tasks/DU_FactorialCRF_Task.py index c303dd9..a1141f9 100644 --- a/TranskribusDU/tasks/DU_FactorialCRF_Task.py +++ b/TranskribusDU/tasks/DU_FactorialCRF_Task.py @@ -5,18 +5,7 @@ Copyright Xerox(C) 2016, 2017 JL. Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. 
The READ project has received funding
diff --git a/TranskribusDU/tasks/DU_GAT_Task.py b/TranskribusDU/tasks/DU_GAT_Task.py
index 0a62337..94bb992 100644
--- a/TranskribusDU/tasks/DU_GAT_Task.py
+++ b/TranskribusDU/tasks/DU_GAT_Task.py
@@ -5,18 +5,7 @@
     Copyright NAVER(C) 2018, 2019 Hervé Déjean, Jean-Luc Meunier, Animesh Prasad
 
-    This program is free software: you can redistribute it and/or modify
-    it under the terms of the GNU General Public License as published by
-    the Free Software Foundation, either version 3 of the License, or
-    (at your option) any later version.
-    This program is distributed in the hope that it will be useful,
-    but WITHOUT ANY WARRANTY; without even the implied warranty of
-    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-    GNU General Public License for more details.
-
-    You should have received a copy of the GNU General Public License
-    along with this program. If not, see .
 
     Developed for the EU project READ.
     The READ project has received funding
diff --git a/TranskribusDU/tasks/DU_Table/DU_ABPTable.py b/TranskribusDU/tasks/DU_Table/DU_ABPTable.py
new file mode 100644
index 0000000..12bb358
--- /dev/null
+++ b/TranskribusDU/tasks/DU_Table/DU_ABPTable.py
@@ -0,0 +1,484 @@
+# -*- coding: utf-8 -*-
+
+"""
+    Example DU task for ABP Table
+
+    Copyright Xerox(C) 2017 H. Déjean
+
+    Developed for the EU project READ. The READ project has received funding
+    from the European Union's Horizon 2020 research and innovation programme
+    under grant agreement No 674943.
+"""
+
+import sys, os
+
+try:  #to ease the use without proper Python installation
+    import TranskribusDU_version
+except ImportError:
+    sys.path.append( os.path.dirname(os.path.dirname( os.path.abspath(sys.argv[0]) )) )
+    import TranskribusDU_version
+
+from common.trace import traceln
+from tasks import _checkFindColDir, _exit
+
+from graph.Graph_Multi_SinglePageXml import Graph_MultiSinglePageXml
+from graph.NodeType_PageXml import NodeType_PageXml_type_woText
+from tasks.DU_CRF_Task import DU_CRF_Task
+
+
+from crf.FeatureDefinition_PageXml_std_noText import FeatureDefinition_PageXml_StandardOnes_noText
+# the _v3 feature definition is used by the ECN model defined below
+from crf.FeatureDefinition_PageXml_std_noText_v3 import FeatureDefinition_PageXml_StandardOnes_noText_v3
+import json
+
+
+class DU_ABPTable(DU_CRF_Task):
+    """
+    We will do a CRF model for a DU task, with the labels below
+    """
+    sXmlFilenamePattern = "*.mpxml"
+
+    #sLabeledXmlFilenamePattern = "*.a_mpxml"
+    sLabeledXmlFilenamePattern = "*.mpxml"
+
+    sLabeledXmlFilenameEXT = ".mpxml"
+
+
+    #=== CONFIGURATION ====================================================================
+    @classmethod
+    def getConfiguredGraphClass(cls):
+        """
+        In this class method, we must return a configured graph class
+        """
+        lLabels = ['RB', 'RI', 'RE', 'RS','RO']
+
+        lIgnoredLabels = None
+
+        """
+        if you play with a toy collection, which does not have all expected classes, you can reduce those.
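+        For instance, lActuallySeen = [0, 4] would keep only 'RB' and 'RO'
+        and put 'RI', 'RE', 'RS' in lIgnoredLabels.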
+ """ + + lActuallySeen = None + if lActuallySeen: + print( "REDUCING THE CLASSES TO THOSE SEEN IN TRAINING") + lIgnoredLabels = [lLabels[i] for i in range(len(lLabels)) if i not in lActuallySeen] + lLabels = [lLabels[i] for i in lActuallySeen ] + print( len(lLabels) , lLabels) + print( len(lIgnoredLabels) , lIgnoredLabels) + + #DEFINING THE CLASS OF GRAPH WE USE + DU_GRAPH = Graph_MultiSinglePageXml + nt = NodeType_PageXml_type_woText("abp" #some short prefix because labels below are prefixed with it + , lLabels + , lIgnoredLabels + , False #no label means OTHER + , BBoxDeltaFun=lambda v: max(v * 0.066, min(5, v/3)) #we reduce overlap in this way + ) + # ntA = NodeType_PageXml_type_woText("abp" #some short prefix because labels below are prefixed with it + # , lLabels + # , lIgnoredLabels + # , False #no label means OTHER + # ) + + nt.setXpathExpr( (".//pc:TextLine" #how to find the nodes + , "./pc:TextEquiv") #how to get their text + ) + + # ntA.setXpathExpr( (".//pc:TextLine | .//pc:TextRegion" #how to find the nodes + # , "./pc:TextEquiv") #how to get their text + # ) + DU_GRAPH.addNodeType(nt) + + return DU_GRAPH + + def __init__(self, sModelName, sModelDir, sComment=None, C=None, tol=None, njobs=None, max_iter=None, inference_cache=None): + + if sComment is None: sComment = sModelName + DU_CRF_Task.__init__(self + , sModelName, sModelDir + , dFeatureConfig = { } + , dLearnerConfig = { + 'C' : .1 if C is None else C + , 'njobs' : 8 if njobs is None else njobs + , 'inference_cache' : 50 if inference_cache is None else inference_cache + #, 'tol' : .1 + , 'tol' : .05 if tol is None else tol + , 'save_every' : 50 #save every 50 iterations,for warm start + , 'max_iter' : 1000 if max_iter is None else max_iter + } + , sComment= sComment + #,cFeatureDefinition=FeatureDefinition_PageXml_StandardOnes_noText + ,cFeatureDefinition=FeatureDefinition_PageXml_StandardOnes_noText + ) + + #self.setNbClass(3) #so that we check if all classes are represented in the training set + + if options.bBaseline: + self.bsln_mdl = self.addBaseline_LogisticRegression() #use a LR model trained by GridSearch as baseline + #=== END OF CONFIGURATION ============================================================= + + + def predict(self, lsColDir): + """ + Return the list of produced files + """ + self.sXmlFilenamePattern = "*.mpxml" + return DU_CRF_Task.predict(self, lsColDir) + + def runForExternalMLMethod(self, lsColDir, storeX, applyY, bRevertEdges=False): + """ + Return the list of produced files + """ + self.sXmlFilenamePattern = "*.mpxml" + return DU_CRF_Task.runForExternalMLMethod(self, lsColDir, storeX, applyY, bRevertEdges) + + + +try: + from tasks.DU_ECN_Task import DU_ECN_Task + class DU_ABPTable_ECN(DU_ECN_Task): + """ + ECN Models + """ + + sMetadata_Creator = "NLE Document Understanding ECN" + sXmlFilenamePattern = "*.mpxml" + + # sLabeledXmlFilenamePattern = "*.a_mpxml" + sLabeledXmlFilenamePattern = "*.mpxml" + + sLabeledXmlFilenameEXT = ".mpxml" + + dLearnerConfig = {'nb_iter': 50, + 'lr': 0.001, + 'num_layers': 3, + 'nconv_edge': 10, + 'stack_convolutions': True, + 'node_indim': -1, + 'mu': 0.0, + 'dropout_rate_edge': 0.0, + 'dropout_rate_edge_feat': 0.0, + 'dropout_rate_node': 0.0, + 'ratio_train_val': 0.15, + #'activation': tf.nn.tanh, Problem I can not serialize function HERE + } + # === CONFIGURATION ==================================================================== + @classmethod + def getConfiguredGraphClass(cls): + """ + In this class method, we must return a configured graph class + """ + 
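+        # NB: the row labels follow a BIESO scheme: Begin / Inside / End /
+        # Single textline of a table row, plus 'RO' for text outside any row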
lLabels = ['RB', 'RI', 'RE', 'RS', 'RO'] + + lIgnoredLabels = None + + """ + if you play with a toy collection, which does not have all expected classes, you can reduce those. + """ + + lActuallySeen = None + if lActuallySeen: + print("REDUCING THE CLASSES TO THOSE SEEN IN TRAINING") + lIgnoredLabels = [lLabels[i] for i in range(len(lLabels)) if i not in lActuallySeen] + lLabels = [lLabels[i] for i in lActuallySeen] + print(len(lLabels), lLabels) + print(len(lIgnoredLabels), lIgnoredLabels) + + # DEFINING THE CLASS OF GRAPH WE USE + DU_GRAPH = Graph_MultiSinglePageXml + nt = NodeType_PageXml_type_woText("abp" # some short prefix because labels below are prefixed with it + , lLabels + , lIgnoredLabels + , False # no label means OTHER + , BBoxDeltaFun=lambda v: max(v * 0.066, min(5, v / 3)) + # we reduce overlap in this way + ) + nt.setXpathExpr((".//pc:TextLine" # how to find the nodes + , "./pc:TextEquiv") # how to get their text + ) + DU_GRAPH.addNodeType(nt) + + return DU_GRAPH + + def __init__(self, sModelName, sModelDir, sComment=None,dLearnerConfigArg=None): + if sComment is None: sComment = sModelName + DU_ECN_Task.__init__(self + , sModelName, sModelDir + , dFeatureConfig={} + , dLearnerConfig= dLearnerConfigArg if dLearnerConfigArg is not None else self.dLearnerConfig + , sComment= sComment + , cFeatureDefinition=FeatureDefinition_PageXml_StandardOnes_noText_v3 + ) + + if options.bBaseline: + self.bsln_mdl = self.addBaseline_LogisticRegression() # use a LR model trained by GridSearch as baseline + + # === END OF CONFIGURATION ============================================================= + def predict(self, lsColDir): + """ + Return the list of produced files + """ + self.sXmlFilenamePattern = "*.mpxml" + return DU_ECN_Task.predict(self, lsColDir) +except ImportError: + print('Could not Load ECN Model, Is TensorFlow installed ?') + + +try: + from tasks.DU_ECN_Task import DU_ECN_Task + from gcn.DU_Model_ECN import DU_Model_GAT + class DU_ABPTable_GAT(DU_ECN_Task): + """ + ECN Models + """ + + sMetadata_Creator = "NLE Document Understanding GAT" + + + sXmlFilenamePattern = "*.mpxml" + + # sLabeledXmlFilenamePattern = "*.a_mpxml" + sLabeledXmlFilenamePattern = "*.mpxml" + + sLabeledXmlFilenameEXT = ".mpxml" + + + dLearnerConfigOriginalGAT ={ + 'nb_iter': 500, + 'lr': 0.001, + 'num_layers': 2,#2 Train Acc is lower 5 overfit both reach 81% accuracy on Fold-1 + 'nb_attention': 5, + 'stack_convolutions': True, + # 'node_indim': 50 , worked well 0.82 + 'node_indim': -1, + 'dropout_rate_node': 0.0, + 'dropout_rate_attention': 0.0, + 'ratio_train_val': 0.15, + "activation_name": 'tanh', + "patience": 50, + "mu": 0.00001, + "original_model" : True + + } + + dLearnerConfigNewGAT = {'nb_iter': 500, + 'lr': 0.001, + 'num_layers': 5, + 'nb_attention': 5, + 'stack_convolutions': True, + 'node_indim': -1, + 'dropout_rate_node': 0.0, + 'dropout_rate_attention' : 0.0, + 'ratio_train_val': 0.15, + "activation_name": 'tanh', + "patience":50, + "original_model": False, + "attn_type":0 + } + dLearnerConfig = dLearnerConfigNewGAT + #dLearnerConfig = dLearnerConfigOriginalGAT + # === CONFIGURATION ==================================================================== + @classmethod + def getConfiguredGraphClass(cls): + """ + In this class method, we must return a configured graph class + """ + lLabels = ['RB', 'RI', 'RE', 'RS', 'RO'] + + lIgnoredLabels = None + + """ + if you play with a toy collection, which does not have all expected classes, you can reduce those. 
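+        (The lLabels / lActuallySeen mechanism works as in DU_ABPTable above.)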
+ """ + + lActuallySeen = None + if lActuallySeen: + print("REDUCING THE CLASSES TO THOSE SEEN IN TRAINING") + lIgnoredLabels = [lLabels[i] for i in range(len(lLabels)) if i not in lActuallySeen] + lLabels = [lLabels[i] for i in lActuallySeen] + print(len(lLabels), lLabels) + print(len(lIgnoredLabels), lIgnoredLabels) + + # DEFINING THE CLASS OF GRAPH WE USE + DU_GRAPH = Graph_MultiSinglePageXml + nt = NodeType_PageXml_type_woText("abp" # some short prefix because labels below are prefixed with it + , lLabels + , lIgnoredLabels + , False # no label means OTHER + , BBoxDeltaFun=lambda v: max(v * 0.066, min(5, v / 3)) + # we reduce overlap in this way + ) + nt.setXpathExpr((".//pc:TextLine" # how to find the nodes + , "./pc:TextEquiv") # how to get their text + ) + DU_GRAPH.addNodeType(nt) + + return DU_GRAPH + + def __init__(self, sModelName, sModelDir, sComment=None,dLearnerConfigArg=None): + if sComment is None : sComment= sModelName + DU_ECN_Task.__init__(self + , sModelName, sModelDir + , dFeatureConfig={} + , dLearnerConfig= dLearnerConfigArg if dLearnerConfigArg is not None else self.dLearnerConfig + , sComment=sComment + , cFeatureDefinition=FeatureDefinition_PageXml_StandardOnes_noText + , cModelClass=DU_Model_GAT + ) + + if options.bBaseline: + self.bsln_mdl = self.addBaseline_LogisticRegression() # use a LR model trained by GridSearch as baseline + + # === END OF CONFIGURATION ============================================================= + def predict(self, lsColDir): + """ + Return the list of produced files + """ + self.sXmlFilenamePattern = "*.mpxml" + return DU_ECN_Task.predict(self, lsColDir) +except ImportError: + print('Could not Load GAT Model','Is tensorflow installed ?') + + + +# ---------------------------------------------------------------------------- + +def main(sModelDir, sModelName, options): + if options.use_ecn: + if options.ecn_json_config is not None and options.ecn_json_config is not []: + f = open(options.ecn_json_config[0]) + djson=json.loads(f.read()) + dLearnerConfig=djson["ecn_learner_config"] + f.close() + doer = DU_ABPTable_ECN(sModelName, sModelDir,dLearnerConfigArg=dLearnerConfig) + + + + else: + doer = DU_ABPTable_ECN(sModelName, sModelDir) + elif options.use_gat: + if options.gat_json_config is not None and options.gat_json_config is not []: + + f = open(options.gat_json_config[0]) + djson=json.loads(f.read()) + dLearnerConfig=djson["gat_learner_config"] + f.close() + doer = DU_ABPTable_GAT(sModelName, sModelDir,dLearnerConfigArg=dLearnerConfig) + + else: + doer = DU_ABPTable_GAT(sModelName, sModelDir) + + else: + doer = DU_ABPTable(sModelName, sModelDir, + C = options.crf_C, + tol = options.crf_tol, + njobs = options.crf_njobs, + max_iter = options.max_iter, + inference_cache = options.crf_inference_cache) + + if options.rm: + doer.rm() + return + + lTrn, lTst, lRun, lFold = [_checkFindColDir(lsDir) for lsDir in [options.lTrn, options.lTst, options.lRun, options.lFold]] + + traceln("- classes: ", doer.getGraphClass().getLabelNameList()) + + ## use. 
a_mpxml files
+    doer.sXmlFilenamePattern = doer.sLabeledXmlFilenamePattern
+
+
+    if options.iFoldInitNum or options.iFoldRunNum or options.bFoldFinish:
+        if options.iFoldInitNum:
+            """
+            initialization of a cross-validation
+            """
+            splitter, ts_trn, lFilename_trn = doer._nfold_Init(lFold, options.iFoldInitNum, test_size=None, random_state=None, bStoreOnDisk=True)
+        elif options.iFoldRunNum:
+            """
+            Run one fold
+            """
+            oReport = doer._nfold_RunFoldFromDisk(options.iFoldRunNum, options.warm, options.pkl)
+            traceln(oReport)
+        elif options.bFoldFinish:
+            tstReport = doer._nfold_Finish()
+            traceln(tstReport)
+        else:
+            assert False, "Internal error"
+        #no more processing!!
+        exit(0)
+        #-------------------
+
+    if lFold:
+        loTstRpt = doer.nfold_Eval(lFold, 3, .25, None, options.pkl)
+        import graph.GraphModel
+        sReportPickleFilename = os.path.join(sModelDir, sModelName + "__report.txt")
+        traceln("Results are in %s"%sReportPickleFilename)
+        graph.GraphModel.GraphModel.gzip_cPickle_dump(sReportPickleFilename, loTstRpt)
+    elif lTrn:
+        doer.train_save_test(lTrn, lTst, options.warm, options.pkl)
+        try: traceln("Baseline best estimator: %s"%doer.bsln_mdl.best_params_)   #for GridSearch
+        except: pass
+        traceln(" --- CRF Model ---")
+        traceln(doer.getModel().getModelInfo())
+    elif lTst:
+        doer.load()
+        tstReport = doer.test(lTst)
+        traceln(tstReport)
+        if options.bDetailedReport:
+            traceln(tstReport.getDetailledReport())
+            import graph.GraphModel
+            for test in lTst:
+                sReportPickleFilename = os.path.join('..', test, sModelName + "__report.pkl")
+                traceln('Report dumped into %s'%sReportPickleFilename)
+                graph.GraphModel.GraphModel.gzip_cPickle_dump(sReportPickleFilename, tstReport)
+
+    if lRun:
+        if options.storeX or options.applyY:
+            try: doer.load()
+            except: pass    #we only need the transformer
+            lsOutputFilename = doer.runForExternalMLMethod(lRun, options.storeX, options.applyY, options.bRevertEdges)
+        else:
+            doer.load()
+            lsOutputFilename = doer.predict(lRun)
+
+        traceln("Done, see in:\n  %s"%lsOutputFilename)
+
+
+# ----------------------------------------------------------------------------
+if __name__ == "__main__":
+
+    version = "v.01"
+    usage, description, parser = DU_CRF_Task.getBasicTrnTstRunOptionParser(sys.argv[0], version)
+#    parser.add_option("--annotate", dest='bAnnotate', action="store_true", default=False, help="Annotate the textlines with BIES labels")
+
+    #FOR GCN
+    parser.add_option("--revertEdges", dest='bRevertEdges', action="store_true", help="Revert the direction of the edges")
+    parser.add_option("--detail", dest='bDetailedReport', action="store_true", default=False, help="Display detailed reporting (score per document)")
+    parser.add_option("--baseline", dest='bBaseline', action="store_true", default=False, help="report baseline method")
+    parser.add_option("--ecn", dest='use_ecn', action="store_true", default=False, help="whether to use ECN models")
+    parser.add_option("--ecn_config", dest='ecn_json_config', action="append", type="string", help="The config files for the ECN model")
+    parser.add_option("--gat", dest='use_gat', action="store_true", default=False, help="whether to use GAT models")
+    parser.add_option("--gat_config", dest='gat_json_config', action="append", type="string",
+                      help="The config files for the GAT model")
+    parser.add_option("--ecn_wedge", dest='ecn_wedge', action="store_true", default=False, help="save Wedge parameters")
+    # ---
+    #parse the command line
+    (options, args) = parser.parse_args()
+
+    # ---
+    try:
+        sModelDir, sModelName = args
+    except Exception as
e: + traceln("Specify a model folder and a model name!") + _exit(usage, 1, e) + + main(sModelDir, sModelName, options) diff --git a/TranskribusDU/tasks/DU_Table/DU_ABPTableCutAnnotator.py b/TranskribusDU/tasks/DU_Table/DU_ABPTableCutAnnotator.py new file mode 100644 index 0000000..4e68444 --- /dev/null +++ b/TranskribusDU/tasks/DU_Table/DU_ABPTableCutAnnotator.py @@ -0,0 +1,1075 @@ +# -*- coding: utf-8 -*- + +""" + Find cuts of a page and annotate them based on the table separators + + Copyright Naver Labs Europe 2018 + JL Meunier + + + + + Developed for the EU project READ. The READ project has received funding + from the European Union's Horizon 2020 research and innovation programme + under grant agreement No 674943. + +""" + + + + +import sys, os +from optparse import OptionParser +import operator +from collections import defaultdict + +from lxml import etree +import numpy as np +import shapely.geometry as geom +import shapely.affinity + +try: #to ease the use without proper Python installation + import TranskribusDU_version +except ImportError: + sys.path.append( os.path.dirname(os.path.dirname( os.path.abspath(sys.argv[0]) )) ) + import TranskribusDU_version + +from common.trace import traceln + +from xml_formats.PageXml import MultiPageXml, PageXml + +from util.Polygon import Polygon +from util.Shape import ShapeLoader, PolygonPartition + +from tasks.DU_Table.DU_ABPTableSkewed_CutAnnotator import _isBaselineNotO, _isBaselineInTable,\ + computePRF +from tasks.DU_Table.DU_ABPTableRCAnnotation import computeMaxRowSpan +from util.partitionEvaluation import evalPartitions +from util.jaccard import jaccard_distance + +class CutAnnotator: + """ + Cutting the page horizontally + """ + fRATIO = 0.66 + + def __init__(self): + pass + + def get_separator_YX_from_DOM(self, root, fMinPageCoverage): + """ + get the x and y of the GT table separators + return lists of y, for horizontal and of x for vertical separators, per page + return [(y_list, x_list), ...] 
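+        e.g. [([(100, 102), (500, 502)], [(60, 61)]), ...] for a page with
+        two horizontal separators and one vertical separator
+        """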
+ """ + ltlYlX = [] + for ndPage in MultiPageXml.getChildByName(root, 'Page'): + w, h = int(ndPage.get("imageWidth")), int(ndPage.get("imageHeight")) + + lYi, lXi = [], [] + + l = MultiPageXml.getChildByName(ndPage,'TableRegion') + if len(l) != 1: + if l: + traceln("** warning ** %d TableRegion instead of expected 1" % len(l)) + else: + traceln("** warning ** no TableRegion, expected 1") + if l: + for ndTR in l: + #enumerate the table separators + for ndSep in MultiPageXml.getChildByName(ndTR,'SeparatorRegion'): + sPoints=MultiPageXml.getChildByName(ndSep,'Coords')[0].get('points') + [(x1,y1),(x2,y2)] = Polygon.parsePoints(sPoints).lXY + + dx, dy = abs(x2-x1), abs(y2-y1) + if dx > dy: + #horizontal table line + if dx > (fMinPageCoverage*w): + #ym = (y1+y2)/2.0 # 2.0 to support python2 + lYi.append((y1,y2)) + else: + if dy > (fMinPageCoverage*h): + #xm = (x1+x2)/2.0 + lXi.append((x1,x2)) + ltlYlX.append( (lYi, lXi) ) + + return ltlYlX + + def getHisto(self, lNd, w, _fMinHorizProjection, h, _fMinVertiProjection + , fRatio=1.0 + , fMinHLen=None): + """ + + return two Numpy array reflecting the histogram of projections of objects + first array along Y axis (horizontal projection), 2nd along X axis + (vertical projection) + + when fMinHLen is given , we do not scale horizontally text shorter than fMinHLen + """ + + hy = np.zeros((h,), np.float) + hx = np.zeros((w,), np.float) + + for nd in lNd: + sPoints=MultiPageXml.getChildByName(nd,'Coords')[0].get('points') + try: + x1,y1,x2,y2 = Polygon.parsePoints(sPoints).fitRectangle() + + if fMinHLen is None or abs(x2-x1) > fMinHLen: + _x1, _x2 = self.scale(x1, x2, fRatio) + else: + _x1, _x2 = x1, x2 + _y1, _y2 = self.scale(y1, y2, fRatio) + hy[_y1:_y2+1] += float(x2 - x1) / w + hx[_x1:_x2+1] += float(y2 - y1) / h + except ZeroDivisionError: + pass + except ValueError: + pass + + return hy, hx + + @classmethod + def scale(cls, a, b, fRatio): + """ + a,b are integers + apply a scaling factor to the segment + make sure its length remains non-zero + return 2 integers + """ + if fRatio == 1.0: return (a,b) # the code below does it, but no need... + + l = b - a # signed length + ll = int(round(l * fRatio)) # new signed length + + dl2 = (l - ll) / 2.0 + ll2a = int(round(dl2)) + ll2b = (l - ll) - ll2a + + return a + ll2a, b - ll2b + + # labels... + def _getLabel(self, i,j, liGT): + """ + i,j are the index of teh start and end of interval of zeros + liGT is a list of pair of pixel coordinates + + an interval of zeros is positive if it contains either end of the + separator or its middle. 
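+        For instance, with liGT=[(10, 20)], the interval (i, j)=(5, 12) gets
+        label 'S' since it contains the separator start iGT=10, while the
+        interval (22, 30) gets label 'O'.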
+ """ + for iGT, jGT in liGT: + mGT = (iGT+jGT) // 2 + if i <= iGT and iGT <= j: + return "S" + elif i <= jGT and jGT <= j: + return "S" + elif i <= mGT and mGT <= j: + return "S" + return "O" + + def getCentreOfZeroAreas(self, h, liGT=None): + """ + liGT is the groundtruth indices + return a list of center of areas contains consecutive 0s + """ + lij = [] #list of area indices + + i0 = None # index of start of a 0 area + imax = h.shape[0] + i = 0 + while i < imax: + if i0 is None: # we were in a non-zero area + if h[i] <= 0: i0 = i # start of an area of 0s + else: # we were in a zero area + if h[i] > 0: + # end of area of 0s + lij.append((i0, i-1)) + i0 = None + i += 1 + if not i0 is None: + lij.append((i0, imax-1)) + + + if liGT is None: + liLbl = [None] * len(lij) + else: + liLbl = [self._getLabel(i,j,liGT) for (i,j) in lij] + + #take middle + li = [ (j + i) // 2 for (i,j) in lij ] + + return li, liLbl + + def getLowestOfZeroAreas(self, h, liGT=None): + """ + liGT is the groundtruth indices + return a list of lowest points of areas contains consecutive 0s + """ + lijm = [] #list of area indices + + i0 = None # index of start of a 0 area + imax = h.shape[0] + i = 0 + minV, minI = None, None + while i < imax: + if i0 is None: # we were in a non-zero area + if h[i] <= 0: + i0 = i # start of an area of 0s + minV, minI = h[i0], i0 + else: # we were in a zero area + if h[i] > 0: + # end of area of 0s + lijm.append((i0, i-1, minI)) + i0 = None + else: + if h[i] <= minV: # take rightmost + minV, minI = h[i], i + i += 1 + if not i0 is None: + minV, minI = h[i0], i0 + i = i0 + 1 + while i < imax: + if h[i] < minV: # tale leftmost + minV, minI = h[i], i + i += 1 + lijm.append((i0, imax-1, minI)) + + + if liGT is None: + liLbl = [None] * len(lijm) + else: + liLbl = [self._getLabel(i,j,liGT) for (i,j,_m) in lijm] + + #take middle + li = [ m for (_i,_j, m) in lijm ] + + return li, liLbl + + + + def add_cut_to_DOM(self, root, + fMinHorizProjection=0.05, + fMinVertiProjection=0.05, + ltlYlX=[] + , fRatio = 1.0 + , fMinHLen = None): + """ + for each page, compute the histogram of projection of text on Y then X + axis. + From this histogram, find cuts. + fMinProjection determines the threholds as a percentage of width (resp + height) of page. Any bin lower than it is considered as zero. + Map cuts to table separators to annotate them + Dynamically tune the threshold for cutting so as to reflect most separators + as a cut. + Tag them if ltlYlX is given + + ltlYlX is a list of (ltY1Y2, ltX1X2) per page. + ltY1Y2 is the list of (Y1, Y2) of horizontal separators, + ltX1X2 is the list of (X1, X2) of vertical separators. 
+ + Modify the XML DOM by adding a separator cut, annotated if GT given + """ + domid = 0 #to add unique separator id + llX, llY = [], [] + for iPage, ndPage in enumerate(MultiPageXml.getChildByName(root, 'Page')): + try: + lYi, lXi = ltlYlX[iPage] + #except TypeError: + except: + lYi, lXi = [], [] + + w, h = int(ndPage.get("imageWidth")), int(ndPage.get("imageHeight")) + + #Histogram of projections + lndTexLine = MultiPageXml.getChildByName(ndPage, 'TextLine') + aYHisto, aXHisto = self.getHisto(lndTexLine, + w, fMinHorizProjection, + h, fMinVertiProjection + , fRatio + , fMinHLen=fMinHLen) + + aYHisto = aYHisto - fMinHorizProjection + aXHisto = aXHisto - fMinVertiProjection + + #find the centre of each area of 0s and its label + lY, lYLbl = self.getCentreOfZeroAreas(aYHisto, lYi) + # lX, lXLbl = self.getCentreOfZeroAreas(aXHisto, lXi) + lX, lXLbl = self.getLowestOfZeroAreas(aXHisto, lXi) + + traceln(lY) + traceln(lX) + + traceln(" - %d horizontal cuts" % len(lY)) + traceln(" - %d vertical cuts" % len(lX)) + + #ndTR = MultiPageXml.getChildByName(ndPage,'TableRegion')[0] + + # horizontal grid lines + for y, ylbl in zip(lY, lYLbl): + domid += 1 + self.addPageXmlSeparator(ndPage, ylbl, 0, y, w, y, domid) + + # Vertical grid lines + for x, xlbl in zip(lX, lXLbl): + domid += 1 + self.addPageXmlSeparator(ndPage, xlbl, x, 0, x, h, domid) + + llX.append(lX) + llY.append(lY) + + return (llY, llX) + + @classmethod + def addPageXmlSeparator(cls, nd, sLabel, x1, y1, x2, y2, domid): + ndSep = MultiPageXml.createPageXmlNode("CutSeparator") + if not sLabel is None: + # propagate the groundtruth info we have + ndSep.set("type", sLabel) + if abs(x2-x1) > abs(y2-y1): + ndSep.set("orient", "0") + else: + ndSep.set("orient", "90") + ndSep.set("id", "s_%d"%domid) + nd.append(ndSep) + ndCoord = MultiPageXml.createPageXmlNode("Coords") + MultiPageXml.setPoints(ndCoord, [(x1, y1), (x2, y2)]) + ndSep.append(ndCoord) + return ndSep + + def remove_cuts_from_dom(self, root): + """ + clean the DOM from any existing cut + return the number of removed cut lines + """ + lnd = MultiPageXml.getChildByName(root,'CutSeparator') + n = len(lnd) + for nd in lnd: + nd.getparent().remove(nd) + #check... + lnd = MultiPageXml.getChildByName(root,'CutSeparator') + assert len(lnd) == 0 + return n + + def loadPageCol(self, ndPage, fRatio + , shaper_fun=ShapeLoader.node_to_Point + , funIndex=lambda x: x._du_index): + """ + load the page, looking for Baseline + can filter by DU_row + return a list of shapely objects + , a dict of sorted list of objects, by column + + GT BUG: some Baseline are assigned to the wrong Cell + => we also fix this here.... 
+ + """ + loBaseline = [] # list of Baseline shapes + i = 0 + + dsetTableByCol = defaultdict(set) # sets of object ids, by col + dsetTableDataByCol = defaultdict(set) # sets of object ids, by col + dO = {} + + dNodeSeen = {} + # first associate a unique id to each baseline and list them + lshapeCell = [] + lOrphanBaselineShape = [] + + lCells = MultiPageXml.getChildByName(ndPage, "TableCell") + maxHeaderRowSpan = computeMaxRowSpan(lCells) + traceln(" - maxHeaderRowSpan=", maxHeaderRowSpan) + for ndCell in lCells: + row, col = int(ndCell.get("row")), int(ndCell.get("col")) + rowSpan = int(ndCell.get("rowSpan")) + plg = ShapeLoader.node_to_Polygon(ndCell) + #ymin, ymax of polygon + lx = [_x for _x, _y in plg.exterior.coords] + xmin, xmax = min(lx), max(lx) + plg._row = row + plg._col = col + plg._xmin, plg._xmax = xmin, xmax + lshapeCell.append(plg) + + for nd in MultiPageXml.getChildByName(ndCell, "Baseline"): + nd.set("du_index", "%d" % i) + ndParent = nd.getparent() + dNodeSeen[ndParent.get('id')] = True + + # Baseline as a shapely object + try: + o = shaper_fun(nd) #make a LineString + except Exception as e: + traceln("ERROR: id=", nd.getparent().get("id")) + raise e + # scale the objects, as done when cutting!! + # useless currently since we make a Point... + o = shapely.affinity.scale(o, xfact=fRatio, yfact=fRatio) + + o._du_index = i + o._du_nd = nd + o._dom_id = nd.getparent().get("id") + loBaseline.append(o) + + # is this object in the correct cell??? + # We must use the centroid of the text box, otherwise a baseline + # may be assigned to the next row + # NOOO x = ShapeLoader.node_to_Polygon(ndParent).centroid.x + # we must look for the leftest coordinate + # NO CHECK FOR COLUMNS + + dsetTableByCol[col].add(funIndex(o)) + + if (row+rowSpan) > maxHeaderRowSpan: + dsetTableDataByCol[col].add(funIndex(o)) + + i += 1 + +# if lOrphanBaselineShape: +# traceln(" *** error: %d Baseline in incorrect row - fixing this..." % len(lOrphanBaselineShape)) +# for o in lOrphanBaselineShape: +# bestrow, bestdeltacol = 0, 9999 +# try: +# y = o.y +# except: +# y = o.centroid.y +# for plg in lshapeCell: +# if plg._ymin <= y and y <= plg._ymax: +# # sounds good +# deltacol = abs(o._bad_cell._col - plg._col) +# if deltacol == 0: +# # same column, ok it is that one +# bestrow = plg._row +# break +# else: +# if bestdeltacol > deltacol: +# bestdeltacol = deltacol +# bestrow = plg._row +# traceln("\t id=%s misplaced in row=%s instead of row=%s" %( +# o._du_nd.getparent().get("id") +# , o._bad_cell._row +# , bestrow)) +# dsetTableByCol[bestrow].add(o._du_index) +# del o._bad_cell + + # and (UGLY) process all Baseline outside any TableCell... + + for nd in MultiPageXml.getChildByName(ndPage, "Baseline"): + try: + dNodeSeen[nd.getparent().get('id')] + except: + #OLD "GOOD" CODE HERE + nd.set("du_index", "%d" % i) + + # Baseline as a shapely object + o = shaper_fun(nd) #make a LineString + + # scale the objects, as done when cutting!! + o = shapely.affinity.scale(o, xfact=fRatio) + + o._du_index = i + o._du_nd = nd + o._dom_id = nd.getparent().get("id") + loBaseline.append(o) + + i += 1 + + return loBaseline, dsetTableByCol, dsetTableDataByCol, maxHeaderRowSpan + + +class NoSeparatorException(Exception): + pass + +class BaselineCutAnnotator(CutAnnotator): + """ + Much simpler approach: + - a block is defined by its baseline. 
+ - the baseline of each block defines a possible cut + - a parameter defines if the corresponding block is above or below the cut + - so a cut defines a partition of the page block + + We use the table annotation to determine the baseline that is the on top + or bottom of each table line (or column) + """ + + bSIO = False # by default, we use SO as labels + #iModulo = 1 + + def __init__(self, bCutIsBeforeText=True): + CutAnnotator.__init__(self) + self.bCutIsBeforeText = bCutIsBeforeText + + #self._fModulo = float(self.iModulo) + + @classmethod + def setLabelScheme_SIO(cls): + cls.bSIO = True + return True + +# def setModulo(self, iModulo): +# self.iModulo = iModulo +# self._fModulo = float(self.iModulo) + +# def moduloSnap(self, x, y): +# """ +# return the same coordinate modulo the current modulo +# """ +# return (int(round(x / self.fModulo)) * self.iModulo, +# int(round(y / self.fModulo)) * self.iModulo) + + @classmethod + def getDomBaselineXY(cls, domNode): + """ + find the baseline descendant node and return its "central" point + """ + try: + ndBaseline = MultiPageXml.getChildByName(domNode,'Baseline')[0] + except IndexError as e: + traceln("WARNING: No Baseline child in ", domNode.get('id')) + raise e + x, y = cls.getPolylineAverageXY(ndBaseline) + # modulo should be done only after the GT assigns labels. + return (x, y) + + @classmethod + def getPolylineAverageXY(cls, ndPolyline): + """ + weighted average X and average Y of a polyline + the weight indicate how long each segment at a given X, or Y, was. + """ + sPoints=ndPolyline.get('points') + lXY = Polygon.parsePoints(sPoints).lXY + + # list of X and Y values and respective weights + lXYWxWy = [((x1+x2)/2.0, abs(y2-y1), # for how long at this X? + (y1+y2)/2.0, abs(x2-x1)) \ + for (x1,y1), (x2, y2) in zip(lXY, lXY[1:])] + fWeightedSumX = sum(x*wx for x, wx, _, _ in lXYWxWy) + fWeightedSumY = sum(y*wy for _, _, y, wy in lXYWxWy) + fSumWeightX = sum( wx for _, wx , _, _ in lXYWxWy) + fSumWeightY = sum( wy for _, _ , _, wy in lXYWxWy) + + Xavg = int(round(fWeightedSumX/fSumWeightX)) if fSumWeightX > 0 else 0 + Yavg = int(round(fWeightedSumY/fSumWeightY)) if fSumWeightY > 0 else 0 + +# Xavg, Yavg = self.moduloSnap(Xavg, Yavg) + + return (Xavg, Yavg) + + def _getLabelFromSeparator(self, ltXY, tlYlX, w, h): + """ + ltXY is the list of (X, Y) of the "central" point of each baseline + tlYlX are the coordinates of the GT separators + ltY1Y2 is the list of (Y1, Y2) of horizontal separators, + ltX1X2 is the list of (X1, X2) of vertical separators. 
+ w, h are the page width and height + + if self.bCutIsBeforeText is True, we look for the highest baseline below + or on each separator (which is possibly not horizontal) + + if self.bCutIsBeforeText is False, we look for the lowest baseline above + or on each separator (which is possibly not horizontal) + + #TODO + Same idea for vertical separators ( ***** NOT DONE ***** ) + + return lX, lY, lXLbl, lYLbl + """ + ltY1Y2, ltX1X2 = tlYlX + + #rough horizontal and vertical bounds + try: + ymin = operator.add(*min(ltY1Y2)) / 2.0 # ~~ (miny1+miny2)/2.0 + ymax = operator.add(*max(ltY1Y2)) / 2.0 + xmin = operator.add(*min(ltX1X2)) / 2.0 + xmax = operator.add(*max(ltX1X2)) / 2.0 + except ValueError: + raise NoSeparatorException("No groundtruth") + + # find best baseline for each table separator + setBestY = set() + for (y1, y2) in ltY1Y2: + bestY = 999999 if self.bCutIsBeforeText else -1 + bFound = False + for x, y in ltXY: + if x < xmin or xmax < x: # text outside table, ignore it + continue + #y of separator at x + ysep = int(round(y1 + float(y2-y1) * x / w)) + if self.bCutIsBeforeText: + if ysep <= y and y < bestY and y < ymax: + #separator is above and baseline is above all others + bestY, bFound = y, True + else: + if ysep >= y and y > bestY and y > ymin: + bestY, bFound = y, True + if bFound: + setBestY.add(bestY) + + setBestX = set() + for (x1, x2) in ltX1X2: + bestX = 999999 if self.bCutIsBeforeText else -1 + bFound = False + for x, y in ltXY: + if y < ymin or ymax < y: # text outside table, ignore it + continue + #x of separator at Y + xsep = int(round(x1 + float(x2-x1) * x / h)) + if self.bCutIsBeforeText: + if xsep <= x and x < bestX and x < xmax: + #separator is above and baseline is above all others + bestX, bFound = x, True + else: + if xsep >= x and x > bestX and x > xmin: + bestX, bFound = x, True + if bFound: + setBestX.add(bestX) + + # zero or one cut given a position + lY = list(set(y for _, y in ltXY)) # zero or 1 cut per Y + lY.sort() + lX = list(set(x for x, _ in ltXY)) # zero or 1 cut per X + lX.sort() + + if self.bSIO: + # O*, S, (S|I)*, O* + if setBestY: + lYLbl = [ ("S" if y in setBestY \ + else ("I" if ymin <= y and y <= ymax else "O")) \ + for y in lY] + else: + lYLbl = ["O"] * len(lY) # should never happen... + if setBestX: + lXLbl = [ ("S" if x in setBestX \ + else ("I" if xmin <= x and x <= xmax else "O")) \ + for x in lX] + else: + lXLbl = ["O"] * len(lX) # should never happen... + else: + # annotate the best baseline-based separator + lYLbl = [ ("S" if y in setBestY else "O") for y in lY] + lXLbl = [ ("S" if x in setBestX else "O") for x in lX] + + return lY, lYLbl, lX, lXLbl + + +# def _getLabelFromCells(self, ltXY, lCells): +# """ +# +# NOT FINISHED +# +# SOME spans are ignored, some not +# +# This is done when making the straight separator, based on their length. +# +# ltXY is the list of (X, Y) of the "central" point of each baseline +# lCells is the list of cells of the table +# +# For Y labels (horizontal cuts): +# - if self.bCutIsBeforeText is True, we look for the highest baseline of +# each table line. +# - if self.bCutIsBeforeText is False, we look for the lowest baseline of +# each table line. 
+#
+#        same idea for X labels (vertical cuts)
+#
+#        returns the list of Y labels, the list of X labels
+#        """
+#
+#        lYLbl, lXLbl = [], []
+#
+#        traceln("DIRTY: ignore rowspan above 5")
+#        lCells = list(filter(lambda x: int(x.get('rowSpan')) < 5, lCells))
+#        dBestByRow = collections.defaultdict(lambda _: None)  # row->best_Y
+#        dBestByCol = collections.defaultdict(lambda _: None)  # col->best_X
+#
+#        dRowSep_lSgmt = collections.defaultdict(list)
+#        dColSep_lSgmt = collections.defaultdict(list)
+#        for cell in lCells:
+#            row, col, rowSpan, colSpan = [int(cell.get(sProp)) for sProp \
+#                                          in ["row", "col", "rowSpan", "colSpan"] ]
+#            coord = cell.xpath("./a:%s" % ("Coords"),namespaces={"a":MultiPageXml.NS_PAGE_XML})[0]
+#            sPoints = coord.get('points')
+#            plgn = Polygon.parsePoints(sPoints)
+#            lT, lR, lB, lL = plgn.partitionSegmentTopRightBottomLeft()
+#
+#            #now the top segments contribute to row separator of index: row
+#            dRowSep_lSgmt[row].extend(lT)
+#            #now the bottom segments contribute to row separator of index: row+rowSpan
+#            dRowSep_lSgmt[row+rowSpan].extend(lB)
+#
+#            dColSep_lSgmt[col].extend(lL)
+#            dColSep_lSgmt[col+colSpan].extend(lR)
+
+    def add_cut_to_DOM(self, root, ltlYlX=[]):
+        """
+        for each page:
+        - sort the blocks by their baseline average y
+        - the sorted list of Ys defines the cuts.
+
+        Tag them if ltlYlX is given
+        ltlYlX is a list of (ltY1Y2, ltX1X2) per page.
+            ltY1Y2 is the list of (Y1, Y2) of horizontal separators,
+            ltX1X2 is the list of (X1, X2) of vertical separators.
+
+        Modify the XML DOM by adding a separator cut, annotated if GT given
+        """
+        domid = 0  #to add unique separator id
+
+        ltlYCutXCut = []
+        for iPage, ndPage in enumerate(MultiPageXml.getChildByName(root, 'Page')):
+            w, h = int(ndPage.get("imageWidth")), int(ndPage.get("imageHeight"))
+
+            # list of Ys of baselines, and indexing of block by Y
+            #list of (X,Y)
+            ltXY = []
+            lndTexLine = MultiPageXml.getChildByName(ndPage, 'TextLine')
+            for ndBlock in lndTexLine:
+                try:
+                    ltXY.append(self.getDomBaselineXY(ndBlock))
+                except:
+                    pass
+
+            # Groundtruth if any
+            #lCells= MultiPageXml.getChildByName(ndPage, 'TableCell')
+
+            # let's collect the segments forming the separators
+            try:
+                lY, lYLbl, lX, lXLbl = self._getLabelFromSeparator(ltXY,
+                                                       ltlYlX[iPage], w, h)
+            except NoSeparatorException:
+                lX = list(set(x for x, _ in ltXY))  # zero or 1 cut per X
+                lY = list(set(y for _, y in ltXY))  # zero or 1 cut per Y
+                lX.sort()  # to have a nice XML
+                lY.sort()
+                lXLbl = [None] * len(lX)
+                lYLbl = [None] * len(lY)
+
+            ndTR = MultiPageXml.getChildByName(root,'TableRegion')[0]
+
+            # horizontal grid lines
+            for y, ylbl in zip(lY, lYLbl):
+                domid += 1
+                self.addPageXmlSeparator(ndTR, ylbl, 0, y, w, y, domid)
+            traceln(" - added %d horizontal cuts" % len(lY))
+
+            # vertical grid lines
+            for x, xlbl in zip(lX, lXLbl):
+                domid += 1
+                self.addPageXmlSeparator(ndTR, xlbl, x, 0, x, h, domid)
+            traceln(" - added %d vertical cuts" % len(lX))
+
+            ltlYCutXCut.append( ([y for _,y in ltXY],
+                                 [x for x,_ in ltXY]))
+
+        return ltlYCutXCut
+
+
+# ------------------------------------------------------------------
+def main(sFilename, sOutFilename, fMinHorizProjection=0.05, fMinVertiProjection=0.05
+         , bBaselineFirst=False
+         , bBaselineLast=False
+         , bSIO=False):
+
+    print("- cutting: %s --> %s"%(sFilename, sOutFilename))
+
+    # Some grid lines will be O or I simply because they are too short.
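+    # (a GT separator crossing less than fMinPageCoverage of the page width
+    # or height is ignored when mapping separators to cuts)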
+ fMinPageCoverage = 0.5 # minimum proportion of the page crossed by a grid line + # we want to ignore col- and row- spans + + #for the pretty printer to format better... + parser = etree.XMLParser(remove_blank_text=True) + doc = etree.parse(sFilename, parser) + root=doc.getroot() + + if bBaselineFirst: + doer = BaselineCutAnnotator(bCutIsBeforeText=True) + if bSIO: doer.setLabelScheme_SIO() + elif bBaselineLast: + doer = BaselineCutAnnotator(bCutIsBeforeText=False) + if bSIO: doer.setLabelScheme_SIO() + else: + doer = CutAnnotator() + + print("doer=%s"%doer) + + #map the groundtruth table separators to our grid, per page (1 in tABP) + ltlYlX = doer.get_separator_YX_from_DOM(root, fMinPageCoverage) + + # Find cuts and map them to GT + # + if bBaselineFirst or bBaselineLast: + doer.add_cut_to_DOM(root, ltlYlX=ltlYlX) + else: + doer.add_cut_to_DOM(root, ltlYlX=ltlYlX, + fMinHorizProjection=fMinHorizProjection, + fMinVertiProjection=fMinVertiProjection,) + + #l_DU_row_Y, l_DU_row_GT = doer.predict(root) + + doc.write(sOutFilename, encoding='utf-8',pretty_print=True,xml_declaration=True) + print('Annotated cut separators added into %s'%sOutFilename) + +global_maxHeaderRowSpan = None +def _isBaselineInTableData(nd): + """ + a Baseline in a TableRegion belongs to a TableCell element + """ + global global_maxHeaderRowSpan + v = nd.getparent().getparent().get("row") + if v is None: + return False + else: + return int(v) >= global_maxHeaderRowSpan + + +def get_col_partition(doer, sxpCut, dNS + , sFilename, lFilterFun + , fRatio + , bVerbose=False + , funIndex=lambda x: x._du_index + ): + """ + return the GT partition in columns, as well as 1 partition per filter function + """ + global global_maxHeaderRowSpan + + if bVerbose: traceln("- loading %s"%sFilename) + parser = etree.XMLParser() + doc = etree.parse(sFilename, parser) + root=doc.getroot() + + llsetRun = [] + + pnum = 0 + lndPage = MultiPageXml.getChildByName(root, 'Page') + assert len(lndPage) == 1, "NOT SUPPORTED: file has many pages - soorry" + for ndPage in lndPage: + pnum += 1 + if bVerbose: traceln(" - page %s - loading table GT" % pnum) + + loBaseline, dsetTableByCol, dsetTableDataByCol, global_maxHeaderRowSpan = doer.loadPageCol(ndPage, fRatio + , funIndex=funIndex) + + if bVerbose: traceln(" - found %d objects on page" % (len(loBaseline))) + + # make a dictionary of cumulative sets, and the set of all objects + lTableColK = sorted(dsetTableByCol.keys()) + lTableDataColK = sorted(dsetTableDataByCol.keys()) + if bVerbose: + traceln(" - found %d cols" % (len(lTableColK))) + traceln(" - found %d objects in the table" % (sum(len(v) for v in dsetTableByCol.values()))) + traceln(" - found %d objects in the table data" % (sum(len(v) for v in dsetTableDataByCol.values()))) + lNdCut = ndPage.xpath(sxpCut, namespaces=dNS) + if bVerbose: + traceln(" - found %d cuts" % (len(lNdCut))) + else: + traceln("- loaded %40s " % sFilename + , " %6d cols %6d 'S' cuts" % ( len(lTableColK) + , len(lNdCut)) + , " %6d objects %6d table objects" % ( + len(loBaseline) + , sum(len(v) for v in dsetTableByCol.values()) + ) + ) + loCut = [] + for ndCut in lNdCut: + #now we need to infer the bounding box of that object + (x1, y1), (x2, y2) = PageXml.getPointList(ndCut) #the polygon + # Create the shapely shape + loCut.append(geom.LineString([(x1, y1), (x2, y2)])) + + w,h = float(ndPage.get("imageWidth")), float(ndPage.get("imageHeight")) +# # Add a fictive cut at top of page +# loCut.append(geom.LineString([(0, 0), (w, 0)])) +# # Add a fictive cut at end of page +# 
loCut.append(geom.LineString([(0, h), (w, h)])) + + # order it by line centroid x + loCut.sort(key=lambda o: o.centroid.x) + + # dcumset is the GT!! + lsetGT = [dsetTableByCol[k] for k in lTableColK] # list of set of du_index + lsetDataGT = [dsetTableDataByCol[k] for k in lTableDataColK] + + # NOW, look at predictions + for filterFun in lFilterFun: + loBaselineInTable = [o for o in loBaseline if filterFun(o._du_nd)] + if bVerbose: traceln(" - %d objects on page predicted in table (%d out)" % ( + len(loBaselineInTable) + , len(loBaseline) - len(loBaselineInTable))) + + # Now create the list of partitions created by the Cuts + lsetRun = [] + partition = PolygonPartition(loBaselineInTable) + if True: # or bCutOnLeft: + #cut if above the text that led to its creation + setAllPrevIds = set([]) # cumulative set of what was already taken + for oCut in loCut: + lo = partition.getObjectOnRightOfLine(oCut) + setIds = set(funIndex(o) for o in lo) + #print(oCut.centroid.x, setIds) + if setAllPrevIds: + prevColIds = setAllPrevIds.difference(setIds) # content of previous row + if prevColIds: + #an empty set is denoting alternative cuts leading to same partition + lsetRun.append(prevColIds) + setAllPrevIds = setIds + else: + assert False, "look at this code..." +# #cut if below the text that led to its creation +# cumSetIds = set([]) # cumulative set +# for oCut in loCut: +# lo = partition.getObjectAboveLine(oCut) +# setIds = set(o._du_index for o in lo) +# rowIds = setIds.difference(cumSetIds) # only last row! +# if rowIds: +# #an empty set is denoting alternative cuts leading to same partition +# lsetRun.append(rowIds) +# cumSetIds = setIds +# _debugPartition("run", lsetRun) +# _debugPartition("ref", lsetGT) + llsetRun.append(lsetRun) + return lsetGT, lsetDataGT, llsetRun + + +def op_eval_col(lsFilename, fSimil, fRatio, bVerbose=False): + """ + We load the XML + - get the CutSeparator elements + - get the text objects (geometry=Baseline) + - + """ + global global_maxHeaderRowSpan + nOk, nErr, nMiss = 0, 0, 0 + + if fSimil is None: + #lfSimil = [ i / 100 for i in range(75, 101, 5)] + lfSimil = [ i / 100 for i in range(70, 101, 10)] + else: + lfSimil = [fSimil] + + # we use only BIO + separators + dOkErrMissOnlyCol = { fSimil:(0,0,0) for fSimil in lfSimil } + dOkErrMissOnlyCol.update({'name':'OnlyCol' + , 'FilterFun':_isBaselineNotO}) + # we use the TableRegion + separators + dOkErrMissTableCol = { fSimil:(0,0,0) for fSimil in lfSimil } + dOkErrMissTableCol.update({'name':'TableCol' + , 'FilterFun':_isBaselineInTable}) + + # we use the TableRegion excluding the header + separators + dOkErrMissTableDataCol = { fSimil:(0,0,0) for fSimil in lfSimil } + dOkErrMissTableDataCol.update({'name':'TableDataCol' + , 'FilterFun':_isBaselineInTableData}) + + ldOkErrMiss = [dOkErrMissOnlyCol, dOkErrMissTableCol, dOkErrMissTableDataCol] + + lFilterFun = [d['FilterFun'] for d in ldOkErrMiss] + + # sxpCut = './/pc:CutSeparator[@orient="0" and @DU_type="S"]' #how to find the cuts + sxpCut = './/pc:CutSeparator[@orient="90"]' #how to find the cuts + dNS = {"pc":PageXml.NS_PAGE_XML} + + doer = CutAnnotator() + + traceln(" - Cut selector = ", sxpCut) + + # load objects: Baseline and Cuts + for n, sFilename in enumerate(lsFilename): + lsetGT, lsetDataGT, llsetRun = get_col_partition(doer, sxpCut, dNS + , sFilename, lFilterFun + , fRatio + , bVerbose=False + , funIndex=lambda x: x._du_index # simpler to view +# , funIndex=lambda x: x._dom_id # more precise + ) + pnum = 1 # only support single-page file... 
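+        # for each filtering variant, compare the partition induced by the
+        # cuts with the GT column partition at each similarity threshold,
+        # using Jaccard-based matching (evalPartitions)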
+ for dOkErrMiss, lsetRun in zip(ldOkErrMiss, llsetRun): + if dOkErrMiss['name'] == "TableDataCol": + # we need to filter also the GT to discard the header from the column + _lsetGT = lsetDataGT + else: + _lsetGT = lsetGT + if bVerbose: + traceln("----- RUN ----- ") + for s in lsetRun: traceln("run ", sorted(s)) + traceln("----- REF ----- ") + for s in _lsetGT: traceln("ref ", sorted(s)) + for fSimil in lfSimil: + nOk, nErr, nMiss = dOkErrMiss[fSimil] + _nOk, _nErr, _nMiss, _lFound, _lErr, _lMissed = evalPartitions(lsetRun, _lsetGT, fSimil, jaccard_distance) + nOk += _nOk + nErr += _nErr + nMiss += _nMiss + if bVerbose or fSimil == 1.0: + _fP, _fR, _fF = computePRF(_nOk, _nErr, _nMiss) + traceln("%4d %8s simil:%.2f P %5.1f R %5.1f F1 %5.1f ok=%6d err=%6d miss=%6d %s page=%d" %( + n+1, dOkErrMiss['name'], fSimil + , _fP, _fR, _fF + , _nOk, _nErr, _nMiss + , os.path.basename(sFilename), pnum)) + dOkErrMiss[fSimil] = (nOk, nErr, nMiss) + + for dOkErrMiss in [dOkErrMissOnlyCol, dOkErrMissTableCol, dOkErrMissTableDataCol]: + traceln() + name = dOkErrMiss['name'] + for fSimil in lfSimil: + nOk, nErr, nMiss = dOkErrMiss[fSimil] + fP, fR, fF = computePRF(nOk, nErr, nMiss) + traceln("ALL %8s simil:%.2f P %5.1f R %5.1f F1 %5.1f " % (name, fSimil, fP, fR, fF ) + , " " + ,"ok=%d err=%d miss=%d" %(nOk, nErr, nMiss)) + return (nOk, nErr, nMiss) + +def test_scale(): + + assert (1,3) == CutAnnotator.scale(1, 3, 1.0) + assert (3,1) == CutAnnotator.scale(3, 1, 1.0) + + def symcheck(a, b, r, aa, bb): + assert (aa, bb) == CutAnnotator.scale(a, b, r), (a, b, r, aa, bb) + assert (bb, aa) == CutAnnotator.scale(b, a, r), (b, a, r, bb, aa) + symcheck(1, 2, 1.0, 1, 2) + symcheck(1, 1, 1.0, 1, 1) + symcheck(1, 10, 1.0, 1, 10) + + assert (2,7) == CutAnnotator.scale(0 , 10, 0.5) + assert (8,3) == CutAnnotator.scale(10, 0 , 0.5) + + assert (-2,-7) == CutAnnotator.scale(-0 , -10, 0.5) + assert (-8,-3) == CutAnnotator.scale(-10, -0 , 0.5) + + assert (1,1) == CutAnnotator.scale(1, 1, 0.33) + +# ------------------------------------------------------------------ +if __name__ == "__main__": + usage = "" + parser = OptionParser(usage=usage, version="0.1") + parser.add_option("--baseline_first", dest='bBaselineFirst', action="store_true", help="Cut based on first baeline of row or column") + parser.add_option("--SIO" , dest='bSIO' , action="store_true", help="SIO labels") + + # --- + #parse the command line + (options, args) = parser.parse_args() + + #load mpxml + sFilename = args[0] + try: + sOutFilename = args[1] + except: + sp, sf = os.path.split(sFilename) + sOutFilename = os.path.join(sp, "cut-" + sf) + try: + fMinH = float(args[2]) + except: + fMinH = None + if fMinH is None: + main(sFilename, sOutFilename, bBaselineFirst=options.bBaselineFirst, bSIO=options.bSIO) + else: + fMinV = float(args[4]) # specify none or both + main(sFilename, sOutFilename, fMinH, fMinV, bBaselineFirst=options.bBaselineFirst, bSIO=options.bSIO) + + + diff --git a/TranskribusDU/tasks/DU_Table/DU_ABPTableRCAnnotation.py b/TranskribusDU/tasks/DU_Table/DU_ABPTableRCAnnotation.py new file mode 100644 index 0000000..495fb5e --- /dev/null +++ b/TranskribusDU/tasks/DU_Table/DU_ABPTableRCAnnotation.py @@ -0,0 +1,292 @@ +# -*- coding: utf-8 -*- + +""" + Annotate textlines for Table understanding (finding rows and columns) + + It tags the table header, vs data, vs other stuff. + + It ignore the binding cells (hack: rowspan >= 5 means binding...) 
+ It then reads the cell borders, and does a linear interpolation by row to produce + the horizontal graphical lines of the table. + It adds a TableHLine elements in an element TableGraphicalLine of TableRegion. + + Copyright Naver Labs Europe 2017 + H. Déjean + JL Meunier + + + + + Developed for the EU project READ. The READ project has received funding + from the European Union's Horizon 2020 research and innovation programme + under grant agreement No 674943. + +""" + + + + +import sys, os, math +from lxml import etree + +import numpy as np + +try: #to ease the use without proper Python installation + import TranskribusDU_version +except ImportError: + sys.path.append( os.path.dirname(os.path.dirname( os.path.abspath(sys.argv[0]) )) ) + import TranskribusDU_version + +from xml_formats.PageXml import MultiPageXml +from util.Polygon import Polygon +from common.trace import traceln + +from tasks.DU_Table.DU_Table_CellBorder import getCellsSeparators + +lLabelsBIESO_R = ['B', 'I', 'E', 'S', 'O'] #O? +lLabelsSM_C = ['M', 'S', 'O'] # single cell, multicells +#lLabels_OI = ['O','I'] # inside/outside a table +#lLabels_SPAN = ['rspan','cspan','nospan'] +lLabels_HEADER = ['D','CH', 'O'] + + +sDURow = "DU_row" +sDUCol = 'DU_col' +sDUHeader = 'DU_header' + +class TableAnnotationException(Exception): + pass + + +def tag_DU_row_col_header(root, lCells, maxRowSpan): + """ + Tag the XML nodes corresponding to those cells + Modify the XML DOM + """ + for cell in lCells: + + lText = MultiPageXml.getChildByName(cell,'TextLine') + + # HEADER WISE: D CH O + if int(cell.get('row')) < maxRowSpan: + [x.set(sDUHeader,lLabels_HEADER[1]) for x in lText] + else: + [x.set(sDUHeader,lLabels_HEADER[0]) for x in lText] + + # ROW WISE: B I E S O + if len(lText) == 0: + pass + if len(lText) == 1: + lText[0].set(sDURow,lLabelsBIESO_R[3]) + elif len(lText) > 1: + # lText.sort(key=lambda x:float(x.prop('y'))) + lText[0].set(sDURow,lLabelsBIESO_R[0]) + [x.set(sDURow,lLabelsBIESO_R[1]) for x in lText[1:-1]] + lText[-1].set(sDURow,lLabelsBIESO_R[2]) + # MultiPageXml.setCustomAttr(lText[0],"table","rtype",lLabelsBIESO_R[0]) + # MultiPageXml.setCustomAttr(lText[-1],"table","rtype",lLabelsBIESO_R[2]) + # [MultiPageXml.setCustomAttr(x,"table","rtype",lLabelsBIESO_R[1]) for x in lText[1:-1]] + + #COLUM WISE: M S O + lCoords = cell.xpath("./a:%s" % ("Coords"),namespaces={"a":MultiPageXml.NS_PAGE_XML}) + coord= lCoords[0] + sPoints=coord.get('points') + plgn = Polygon.parsePoints(sPoints) + (cx,cy,cx2,cy2) = plgn.getBoundingBox() + + for txt in lText: + lCoords = txt.xpath("./a:%s" % ("Coords"),namespaces={"a":MultiPageXml.NS_PAGE_XML}) + coord= lCoords[0] + sPoints=coord.get('points') + lsPair = sPoints.split(' ') + lXY = list() + for sPair in lsPair: + try: + (sx,sy) = sPair.split(',') + lXY.append( (int(sx), int(sy)) ) + except ValueError: + traceln("WARNING: invalid coord in TextLine id=%s IGNORED"%txt.get("id")) + ## HOW to define a CM element!!!! 
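+            # (illustrative sketch of the rule applied just below: with cell
+            #  right edge cx2=100 and a text line spanning x1=40..x2=150, the
+            #  overshoot 150-100=50 exceeds 0.75*(100-40)=45, so the line is
+            #  labelled M, i.e. it likely spans multiple cells)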
+ if lXY: + (x1,y1,x2,y2) = Polygon(lXY).getBoundingBox() + if x2> cx2 and (x2 - cx2) > 0.75 * (cx2 - x1): + txt.set(sDUCol,lLabelsSM_C[0]) + else: + txt.set(sDUCol,lLabelsSM_C[1]) + else: + txt.set(sDUCol,lLabelsSM_C[-1]) + + # textline outside table + lRegions= MultiPageXml.getChildByName(root,'TextRegion') + for region in lRegions: + lText = MultiPageXml.getChildByName(region,'TextLine') + [x.set(sDURow,lLabelsBIESO_R[-1]) for x in lText] + [x.set(sDUCol,lLabelsSM_C[-1]) for x in lText] + [x.set(sDUHeader,lLabels_HEADER[-1]) for x in lText] + + return + +def removeSeparator(root): + lnd = MultiPageXml.getChildByName(root, 'SeparatorRegion') + n = len(lnd) + for nd in lnd: + nd.getparent().remove(nd) + return n + +def addSeparator(root, lCells): + """ + Add separator that correspond to cell boundaries + modify the XML DOM + """ + dRow, dCol = getCellsSeparators(lCells) + + try: + ndTR = MultiPageXml.getChildByName(root,'TableRegion')[0] + except IndexError: + raise TableAnnotationException("No TableRegion!!! ") + + lRow = sorted(dRow.keys()) + lB = [] + for row in lRow: + (x1, y1), (x2, y2) = dRow[row] + b = math.degrees(math.atan((y2-y1) / (x2-x1))) + lB.append(b) + + ndSep = MultiPageXml.createPageXmlNode("SeparatorRegion") + ndSep.set("orient", "horizontal angle=%.2f" % b) + ndSep.set("row", "%d" % row) + ndTR.append(ndSep) + ndCoord = MultiPageXml.createPageXmlNode("Coords") + MultiPageXml.setPoints(ndCoord, [(x1, y1), (x2, y2)]) + ndSep.append(ndCoord) + sStat = "\tHORIZONTAL: Average=%.1f° stdev=%.2f° min=%.1f° max=%.1f°" % ( + np.average(lB), np.std(lB), min(lB), max(lB) + ) + ndTR.append(etree.Comment(sStat)) + traceln(sStat) + + lCol = sorted(dCol.keys()) + lB = [] + for col in lCol: + (x1, y1), (x2, y2) = dCol[col] + b = 90 -math.degrees(math.atan((x2-x1) / (y2 - y1))) + lB.append(b) + ndSep = MultiPageXml.createPageXmlNode("SeparatorRegion") + ndSep.set("orient", "vertical %.2f" % b) + ndSep.set("col", "%d" % col) + ndTR.append(ndSep) + ndCoord = MultiPageXml.createPageXmlNode("Coords") + MultiPageXml.setPoints(ndCoord, [(x1, y1), (x2, y2)]) + ndSep.append(ndCoord) + sStat = "\tVERTICAL : Average=%.1f° stdev=%.2f° min=%.1f° max=%.1f°" % ( + np.average(lB), np.std(lB), min(lB), max(lB) + ) + ndTR.append(etree.Comment(sStat)) + traceln(sStat) + + return + + +def computeMaxRowSpan(lCells): + """ + compute maxRowSpan for Row 0 + ignore cells for which rowspan = #row + """ + nbRows = max(int(x.get('row')) for x in lCells) + try: + return max(int(x.get('rowSpan')) for x in filter(lambda x: x.get('row') == "0" and x.get('rowSpan') != str(nbRows+1), lCells)) + except ValueError : + return 1 + +# ------------------------------------------------------------------ +def main(lsFilename, lsOutFilename): + #for the pretty printer to format better... + parser = etree.XMLParser(remove_blank_text=True) + for sFilename, sOutFilename in zip(lsFilename, lsOutFilename): + doc = etree.parse(sFilename, parser) + root = doc.getroot() + + lCells= MultiPageXml.getChildByName(root,'TableCell') + if not lCells: + traceln("ERROR: no TableCell - SKIPPING THIS FILE!!!") + continue + + # default: O for all cells: all cells must have all tags! + for cell in lCells: + lText = MultiPageXml.getChildByName(cell,'TextLine') + [x.set(sDURow,lLabelsBIESO_R[-1]) for x in lText] + [x.set(sDUCol,lLabelsSM_C[-1]) for x in lText] + [x.set(sDUHeader,lLabels_HEADER[-1]) for x in lText] + + + if False: + # Oct' 2018 RV and JL decided that we keep the binding TextLine (if any!) + # ignore "binding" cells + # dirty... 
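+        # (sketch of the heuristic below: with maxrow=12 the threshold is
+        #  max(5, 12*0.8) = 9.6, so only cells spanning 10 rows or more would
+        #  be treated as binding cells and discarded)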
+ # lCells = list(filter(lambda x: int(x.get('rowSpan')) < 5, lCells)) + # less dirty + maxrow = max(int(x.get('row')) for x in lCells) + binding_rowspan = max(5, maxrow * 0.8) + traceln(" - max row = %d => considering rowspan > %d as binding cells" + % (maxrow, binding_rowspan)) + lValidCell, lBindingCell = [], [] + for ndCell in lCells: + if int(ndCell.get('rowSpan')) < binding_rowspan: + lValidCell.append(ndCell) + else: + lBindingCell.append(ndCell) + nDiscarded = len(lBindingCell) + if nDiscarded > 1: traceln("**************** WARNING ****************") + traceln(" - %d cells discarded as binding cells" % nDiscarded) + for ndCell in lBindingCell: + ndCell.set("type", "table-binding") + lCells = lValidCell + + # FOR COLUMN HEADER: get max(cell[0,i].span) + maxRowSpan = computeMaxRowSpan(lCells) + + tag_DU_row_col_header(root, lCells, maxRowSpan) + + try: + removeSeparator(root) + addSeparator(root, lCells) + doc.write(sOutFilename, encoding='utf-8',pretty_print=True,xml_declaration=True) + traceln('annotation done for %s --> %s' % (sFilename, sOutFilename)) + except TableAnnotationException: + traceln("No Table region in file ", sFilename, " IGNORED!!") + + del doc + + +if __name__ == "__main__": + try: + if len(sys.argv) == 3: + # COMPATIBILITY MODE + #load mpxml + sFilename = sys.argv[1] + sOutFilename = sys.argv[2] + lsFilename = [sFilename] + lsOutFilename = [sOutFilename] + else: + #we expect a folder + sInput = sys.argv[1] + if os.path.isdir(sInput): + lsFilename = [os.path.join(sInput, "col", s) for s in os.listdir(os.path.join(sInput, "col")) if s.endswith(".mpxml") ] + if not lsFilename: + lsFilename = [os.path.join(sInput, "col", s) for s in os.listdir(os.path.join(sInput, "col")) if s.endswith(".pxml") ] + lsFilename.sort() + lsOutFilename = [ os.path.dirname(s) + os.sep + "c_" + os.path.basename(s) for s in lsFilename] + else: + traceln("%s is not a folder"%sys.argv[1]) + raise IndexError() + except IndexError: + traceln("Usage: %s ( input-file output-file | folder )" % sys.argv[0]) + exit(1) + + traceln(lsFilename) + traceln("%d files to be processed" % len(lsFilename)) + traceln(lsOutFilename) + + main(lsFilename, lsOutFilename) + + diff --git a/TranskribusDU/tasks/DU_Table/DU_ABPTableSkewed_CutAnnotator.py b/TranskribusDU/tasks/DU_Table/DU_ABPTableSkewed_CutAnnotator.py new file mode 100644 index 0000000..bd90739 --- /dev/null +++ b/TranskribusDU/tasks/DU_Table/DU_ABPTableSkewed_CutAnnotator.py @@ -0,0 +1,1074 @@ +# -*- coding: utf-8 -*- + +""" + Find cuts of a page along different slopes + and annotate them based on the table row content (which defines a partition) + + Copyright Naver Labs Europe 2018 + JL Meunier + + + + + Developed for the EU project READ. The READ project has received funding + from the European Union's Horizon 2020 research and innovation programme + under grant agreement No 674943. 
+ +""" +import sys, os +from optparse import OptionParser +import math +from collections import defaultdict, Counter + +from lxml import etree +import shapely.geometry as geom +import shapely.ops + +try: #to ease the use without proper Python installation + import TranskribusDU_version +except ImportError: + sys.path.append( os.path.dirname(os.path.dirname( os.path.abspath(sys.argv[0]) )) ) + import TranskribusDU_version + +from common.trace import traceln + +from xml_formats.PageXml import MultiPageXml , PageXml +from util.Shape import ShapeLoader, PolygonPartition +from util.partitionEvaluation import evalPartitions +from util.jaccard import jaccard_distance +from util.Polygon import Polygon + + +class SkewedCutAnnotator: + """ + Finding Skewed cuts and projecting GT to them + + Approach: + - a block is defined by its baseline. + - we look for sloped separator crossing the page + - each cut defines a partition + - we build a dictionary part-> correspondign cuts + - we select for eac key the cut with most frequent slope + + + We use the table annotation to determine the GT tag of the cuts + """ + + # slope rad deg + # 1% 0.009999667 0.572938698 + # 2% 0.019997334 1.145762838 + # 3% 0.029991005 1.718358002 + # 4% 0.039978687 2.290610043 + # 5% 0.049958396 2.862405226 + + # store angles as radians, so convert from degrees + lfANGLE = [math.radians(x) for x in [-2, -1, 0, +1, +2]] + lfANGLE = [math.radians(x) for x in [0]] + # lfANGLE = [math.radians(x) for x in (_x/10 for _x in range(-20, +21, 5))] + + #lfANGLE = [math.radians(90+x) for x in [0]] + + gt_n = 0 # how many valid Cut found? (valid = reflecting a row) + gt_nOk = 0 # how many GT table rows covered by a cut? + gt_nMiss = 0 # how many GT table rows not reflected by a cut? + nCut = 0 + def __init__(self, bCutAbove, lAngle=lfANGLE): + traceln("** SkewedCutAnnotator bCutAbove=%s Angles (°): %s" %(bCutAbove, [math.degrees(v) for v in lAngle])) + self.bCutAbove = bCutAbove # do we generate a cut line above or below each object? + self.lAngle = lAngle + + # --- GT statistics + @classmethod + def gtStatReset(cls): + cls.gt_n = 0 # how many valid Cut found? (valid = reflecting a row) + cls.gt_nOk = 0 # how many GT table rows covered by a cut? + cls.gt_nMiss = 0 # how many GT table rows not reflected by a cut? + + @classmethod + def gtStatAdd(cls, n, nOk, nMiss, nCut): + cls.gt_n += n + cls.gt_nOk += nOk + cls.gt_nMiss += nMiss + cls.nCut += nCut + + @classmethod + def gtStatReport(cls, t_n_nOk_nMiss=None): + try: + # to force displayign certain values + n, nOk, nMiss = t_n_nOk_nMiss + label = " >" + nCut = None + except: + n, nOk, nMiss = cls.gt_n, cls.gt_nOk, cls.gt_nMiss + nCut = cls.nCut + label = "summary" + nTotGT = nOk + nMiss + 0.00001 + traceln("GT: %s %7d cut reflecting a GT table row (%.2f%%)" % (label, n, 100 * n / nTotGT)) + traceln("GT: %s %7d GT table row reflected by a cut (%.2f%%)" % (label, nOk , 100*nOk / nTotGT)) + traceln("GT: %s %7d GT table row not reflected by a cut (%.2f%%)" % (label, nMiss, 100*nMiss / nTotGT)) + if not(nCut is None): + traceln("GT: %s %7d cuts in total (%.2f%%)" % (label, nCut, nCut/nTotGT*100)) + + +# # def loadPage(self, ndPage, shaper_fun=ShapeLoader.node_to_LineString): +# def loadPage_v1(self, ndPage, shaper_fun=ShapeLoader.node_to_Point): +# """ +# load the page, looking for Baseline +# can filter by DU_row +# return a list of shapely objects +# , a dict of sorted list of objects, by row +# +# GT BUG: some Baseline are assigned to the wrong Cell +# => we also fix this here.... 
+# +# """ +# loBaseline = [] # list of Baseline shapes +# +# dsetTableByRow = defaultdict(set) # sets of object ids, by row +# +# # first associate a unique id to each baseline and list them +# for i, nd in enumerate(MultiPageXml.getChildByName(ndPage, "Baseline")): +# nd.set("du_index", "%d" % i) +# # -> TextLine -> TableCell (possibly) +# ndPrnt = nd.getparent() +# row_lbl = ndPrnt.get("DU_row") +# row = ndPrnt.getparent().get("row") +# # row can be None +# +# # Baseline as a shapely object +# o = shaper_fun(nd) #make a LineString +# o._du_index = i +# o._du_row = row # can be None +# o._du_DU_row = row_lbl # can be None +# o._du_nd = nd +# loBaseline.append(o) +# +# if not row is None: +# dsetTableByRow[int(row)].add(i) +# +# return loBaseline, dsetTableByRow + + def loadPage(self, ndPage + , shaper_fun=ShapeLoader.node_to_Point + , funIndex=lambda x: x._du_index + , bIgnoreHeader=False + ): + """ + load the page, looking for Baseline + can filter by DU_row + return a list of shapely objects + , a dict of sorted list of objects, by row + + GT BUG: some Baseline are assigned to the wrong Cell + => we also fix this here.... + + """ + loBaseline = [] # list of Baseline shapes + i = 0 + + dsetTableByRow = defaultdict(set) # sets of object ids, by row + + dNodeSeen = {} + # first associate a unique id to each baseline and list them + lshapeCell = [] + lOrphanBaselineShape = [] + for ndCell in MultiPageXml.getChildByName(ndPage, "TableCell"): + row, col = ndCell.get("row"), ndCell.get("col") + plg = ShapeLoader.node_to_Polygon(ndCell) + #ymin, ymax of polygon + ly = [_y for _x, _y in plg.exterior.coords] + ymin, ymax = min(ly), max(ly) + plg._row = int(row) + plg._col = int(col) + plg._ymin, plg._ymax = ymin, ymax + + i0 = i + for nd in MultiPageXml.getChildByName(ndCell, "Baseline"): + nd.set("du_index", "%d" % i) + ndParent = nd.getparent() + dNodeSeen[ndParent.get('id')] = True + if bIgnoreHeader and ndParent.get("DU_header") == "CH": + continue + row_lbl = ndParent.get("DU_row") + + # Baseline as a shapely object + try: + o = shaper_fun(nd) #make a LineString + except Exception as e: + traceln("ERROR: id=", nd.getparent().get("id")) + raise e + o._du_index = i + o._du_DU_row = row_lbl # can be None + o._du_nd = nd + o._dom_id = nd.getparent().get("id") + loBaseline.append(o) + + # is this object in the correct cell??? + # We must use the centroid of the text box, otherwise a baseline + # may be assigned to the next row + #y = o.centroid.y # NOO!! + y = ShapeLoader.node_to_Polygon(ndParent).centroid.y + # if ymin <= y and y <= ymax: + # we allow the content of a cell to overlap the cell lower border + if ymin <= y: + dsetTableByRow[int(row)].add(funIndex(o)) + else: + # this is an orphan! + o._bad_cell = plg + lOrphanBaselineShape.append(o) + + i += 1 + + if bIgnoreHeader and i0 == i: + continue # empty cells, certainly due to headers, ignore it. + + lshapeCell.append(plg) + # end for + + if lOrphanBaselineShape: + traceln(" *** error: %d Baseline in incorrect row - fixing this..." 
% len(lOrphanBaselineShape)) + for o in lOrphanBaselineShape: + bestrow, bestdeltacol = 0, 9999 + try: + y = o.y + except: + y = o.centroid.y + for plg in lshapeCell: + if plg._ymin <= y and y <= plg._ymax: + # sounds good + deltacol = abs(o._bad_cell._col - plg._col) + if deltacol == 0: + # same column, ok it is that one + bestrow = plg._row + break + else: + if bestdeltacol > deltacol: + bestdeltacol = deltacol + bestrow = plg._row + traceln("\t id=%s misplaced in row=%s instead of row=%s" %( + o._du_nd.getparent().get("id") + , o._bad_cell._row + , bestrow)) + dsetTableByRow[bestrow].add(funIndex(o)) + del o._bad_cell + + # and (UGLY) process all Baseline outside any TableCell... + + for nd in MultiPageXml.getChildByName(ndPage, "Baseline"): + try: + dNodeSeen[nd.getparent().get('id')] + except: + #OLD "GOOD" CODE HERE + nd.set("du_index", "%d" % i) + # -> TextLine -> TableCell (possibly) + ndPrnt = nd.getparent() + row_lbl = ndPrnt.get("DU_row") + + # Baseline as a shapely object + o = shaper_fun(nd) #make a LineString + o._du_index = i + o._du_row = None # Must be None + o._du_DU_row = row_lbl # can be None + o._du_nd = nd + o._dom_id = nd.getparent().get("id") + + loBaseline.append(o) + + i += 1 + + return loBaseline, dsetTableByRow + + @classmethod + def makeCumulativeTableByRow(cls, dsetTableByRow, bDownward=False): + """ + get a dictionary row-index -> set of row objects + make a cumulative dictionary row-index -> frozenset of object from row 0 to row K + if bDonward is False, cumul is done from K to 0 + return (sorted list of keys, cumulative dictionary, set of all objects) + """ + dcumset = defaultdict(set) + cumset = set() # cumulative set of all index (of table objects) + lTableRowK = sorted(dsetTableByRow.keys(), reverse=bDownward) + for k in lTableRowK: + cumset.update(dsetTableByRow[k]) + dcumset[k] = frozenset(cumset) + if bDownward: lTableRowK.reverse() + return lTableRowK, dcumset, cumset + + def findHCut(self, ndPage, loBaseline, dsetTableByRow, fCutHeight=25, iVerbose=0): + """ + find "horizontal" cuts that define a unique partition of the page text + return a list of LineString + """ + traceln(" - cut are made %s the text baseline centroid" % ("above" if self.bCutAbove else "below")) + + # GT: row -> set of object index + bGT = len(dsetTableByRow) > 0 + + traceln(" - Minimal cut height=", fCutHeight) + w = int(ndPage.get("imageWidth")) + + dlCut_by_Partition = defaultdict(list) # dict partition -> list of cut lines + partition = PolygonPartition(loBaseline) + + _partitionFun = partition.getObjectBelowLineByIds if self.bCutAbove else partition.getObjectAboveLineByIds + + # Now consider in turn each baseline as the pivot for creating a separator + # below it + for oBaseline in loBaseline: + + # for each candidate "skewed" cuts + for angle in self.lAngle: + oCut = self.getTangentLineStringAtAngle(angle, oBaseline, w) + + if partition.isValidRibbonCut(oCut, -fCutHeight if self.bCutAbove else fCutHeight): + #ok, store this candidate cut and associated partition! 
+ tIds = _partitionFun(oCut) + dlCut_by_Partition[tIds].append(oCut) + oCut._du_support = oBaseline + oCut._du_angle = angle + oCut._du_label = "O" # just put "O" for wxvisu to show things + if bGT: oCut.__du_tIds = tIds # temporarily + + lloCutByPartition = list(dlCut_by_Partition.values()) + cntCut = sum(len(v) for v in lloCutByPartition) + traceln(" - found %d \"horizontal\" cuts" % cntCut) + + # keep one cut per partition + cntByAngle = Counter(o._du_angle for lo in lloCutByPartition for o in lo) + if True: + # preferring the closest to the average angle + try: + avgAngle = sum(k*v for k,v in cntByAngle.items()) / sum(cntByAngle.values()) + except ZeroDivisionError: + avgAngle = 0 + lambdaScore = lambda o: - abs(o._du_angle - avgAngle) + else: + # preferring the most frequent angle + lambdaScore = lambda o: cntByAngle[o._du_angle] + cntCountByAngleDeg = {math.degrees(k):v for k,v in cntByAngle.items()} + lDeg = sorted((cntCountByAngleDeg.keys()) + , reverse=True, key=lambda o: cntCountByAngleDeg[o]) + traceln(" - Observed skew angles: " + "|".join([ + " %.2f : %d "%(d, cntCountByAngleDeg[d]) for d in lDeg + ])) + + # Code below is correct but we need to do some more things for having better features + # loCut = sorted((max(lo, key=lambda o: cntByAngle[o._du_angle]) + # for lo in lloCutByPartition) + # , key=lambda o: o.centroid.y) + loNewCut = [] + for _loCut in lloCutByPartition: + # most frequent angle given the partition + oCutBest = max(_loCut, key=lambdaScore) + # create _du_set_support containing the set of node that lead to the same partition + # set of nodes that generated this particular partition + oCutBest._du_set_support = set(_o._du_support._du_index for _o in _loCut) + # frequency of the cut's angle over the page + oCutBest._du_angle_freq = cntByAngle[oCutBest._du_angle] / cntCut + # cumulative frequency of all cuts that are represented by the chosen one + oCutBest._du_angle_cumfreq = sum(cntByAngle[_o._du_angle] for _o in _loCut) / cntCut + loNewCut.append(oCutBest) + loCut = sorted(loNewCut, key=lambda o: o.centroid.y) + + traceln(" - kept %d \"horizontal\" unique cuts" % len(loCut)) + if loCut: + traceln(" - average count of support nodes per cut : %.3f" % + (sum(len(_o._du_set_support) for _o in loCut) / len(loCut))) + + if bGT: + traceln(" - loading GT Cell information") + # make a dictionary of cumulative sets, and the set of all objects + lTableRowK, dcumset, cumset = self.makeCumulativeTableByRow(dsetTableByRow, self.bCutAbove) + # to tag at best the last cuts determining a valid partition... 
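+            # (illustrative sketch: with rows 0:{a}, 1:{b,c}, 2:{d} and
+            #  bDownward=False, makeCumulativeTableByRow yields dcumset
+            #  0->{a}, 1->{a,b,c}, 2->{a,b,c,d}; a cut is then a valid row
+            #  boundary when the set of table objects it delimits equals one
+            #  of these cumulative sets)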
+ bestLastlSep = [] + bestLastLen = 99999999 + traceln("\tfound %d objects in table" % len(cumset)) + + dGTCoverage = { k:0 for k in lTableRowK } + # OK, let's tag with S I O based on the partition created by each cut + for oCut in loCut: + # build the set of index of text in the table, + # and above the cut + setIdx = cumset.intersection(set(oCut.__du_tIds)) + +# if oCut.centroid.y == 1177: +# print(" oCut ", list(oCut.coords)) +# print(sorted(list(setIdx))) +# print(setIdx.difference(dcumset[4])) +# print(dcumset[4].difference(setIdx)) +# print(list(loBaseline[83].coords)) +# print(loBaseline[83]._du_nd.getparent().get("id")) +# lkjljl + + # print(oCut._du_index, "nb table object above=", len(setIdx)) + #if setIdx in dcumset.values(): # a valid partition (compatible with the table) + bNotFound = True + for k in lTableRowK: + if setIdx == dcumset[k]: # a valid partition (compatible with the table) + bNotFound = False + dGTCoverage[k] += 1 + # ok, that partition was found + + if setIdx == cumset: + # is it the last separator of the table, or above the last?? + # is it the last separator of the table, or below the last?? + if len(tIds) <= bestLastLen: # better end of table, because less O + if len(tIds) == bestLastLen: # same, in fact + bestLastlSep.append(oCut) + else: + bestLastlSep = [oCut] + bestLastLen = len(setIdx) + label = "O" # we fix some of them at the end + else: + # ok this is a valid table partition + label = "S" + if bNotFound: + if len(setIdx) > 0: + label = "I" # some table elements above, but not all + else: + label = "O" + oCut._du_label = label + + del oCut.__du_tIds + for oCut in bestLastlSep: oCut._du_label = "S" + + c = Counter(oCut._du_label for oCut in loCut) + lk = sorted(c.keys()) + traceln("GT: > CUT Label count: ", " ".join("%s:%d"%(k, c[k]) for k in lk)) + + n = sum(dGTCoverage.values()) + nOk = len([k for k,v in dGTCoverage.items() if v > 0]) + nMiss = len([k for k,v in dGTCoverage.items() if v == 0]) + if nMiss > 0: + for k,v in dGTCoverage.items(): + if v == 0: traceln("missed k=%d"%k) + self.gtStatReport((n, nOk, nMiss)) + self.gtStatAdd(n, nOk, nMiss, len(loCut)) + self.gtStatReport() + + return loCut + + def getTangentLineStringAtAngle(self, a, o, w): + """ + Find the line with given angle (less than pi/2 in absolute value) that is immediately below the object + (angle in radians) + return a Line + """ + return geom.LineString( self.getTangentAtAngle(a, o, w) ) + + def getTangentAtAngle(self, a, o, w): + """ + Find the line with given angle (less than pi/2 in absolute value) that is immediately below the object + (angle in radians) + return a Line + """ + EPSILON = 1 + + minx, miny, maxx, maxy = o.bounds + + # first a line with this slope at some distance from object + xo = (minx + maxx) // 2 + if self.bCutAbove: + yo = miny - (maxx - minx) - 100 + else: + yo = maxy + (maxx - minx) + 100 + y0, yw = self._getTangentAlongXCoord(a, xo, yo, w) + oLine = geom.LineString([(0,y0), (w, yw)]) + + #nearest points + pt1, _pt2 = shapely.ops.nearest_points(o, oLine) + x,y = pt1.x, pt1.y + y0, yw = self._getTangentAlongXCoord(a, x, y, w) + + if self.bCutAbove: + return (0, math.floor(y0-EPSILON)), (w, math.floor(yw-EPSILON)) + else: + return (0, math.ceil(y0+EPSILON)), (w, math.ceil(yw+EPSILON)) + + def _getTangentAlongXCoord(self, a, x, y, w): + """ + intersection of the line, with angle a, at x=0 and x=w + return y0, yw + """ + if abs(a) <= 0.001: + # math.radians(0.1) -> 0.0017453292519943296 + # this is horizontal! 
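+            # (worked example for the sloped case of the else branch below:
+            #  a=0.01 rad at x=1000, y=500, w=2000 gives t ~ 0.01,
+            #  y0 ~ 500 - 1000*0.01 = 490 and yw ~ 500 + 1000*0.01 = 510)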
+ y0 = y + yw = y + else: + t = math.tan(a) + y0 = y - x * t + yw = y + (w - x) * t + return y0, yw + + def remove_cuts_from_dom(self, root): + """ + clean the DOM from any existing cut + return the number of removed cut lines + """ + lnd = MultiPageXml.getChildByName(root,'CutSeparator') + n = len(lnd) + for nd in lnd: + nd.getparent().remove(nd) + return n + + def add_Hcut_to_Page(self, ndPage, loCut, domid=0): + """ + Add the cut to the page as a CutSeparator + """ + for oCut in loCut: + domid += 1 + self.addPageXmlSeparator(ndPage, oCut, domid) + + return domid + + @classmethod + def addPageXmlSeparator(cls, ndPage, oCut, domid): + ndSep = MultiPageXml.createPageXmlNode("CutSeparator") + # propagate the groundtruth info we have + ndSep.set("DU_type" , oCut._du_label) + ndSep.set("orient" , "0") + ndSep.set("DU_angle" , "%.1f"%math.degrees(oCut._du_angle)) + ndSep.set("DU_angle_freq" , "%.3f"%oCut._du_angle_freq) + ndSep.set("DU_angle_cumul_freq" , "%.3f"%oCut._du_angle_cumfreq) + ndSep.set("DU_set_support" , "%s" %oCut._du_set_support) + ndSep.set("id" , "cs_%d" % domid) + ndPage.append(ndSep) + ndCoord = MultiPageXml.createPageXmlNode("Coords") + MultiPageXml.setPoints(ndCoord, oCut.coords) + ndSep.append(ndCoord) + return ndSep + + +class NoSeparatorException(Exception): + pass + + + +# ------------------------------------------------------------------ +def test__getTangentCoord(capsys): + + b1 = geom.Polygon([(1,2), (2,2), (2,3), (1,3)]) + + doer = SkewedCutAnnotator(bCutAbove=True) + + def printAngle(a, oLine): + [(xa,ya), (xb,yb)] = list(oLine.coords) + aa = math.atan((yb-ya) / (xb-xa)) + print("asked %.2f° got %.2f° (diff=%.4f°)" % (math.degrees(a), math.degrees(aa), math.degrees(aa-a))) + + with capsys.disabled(): + p11 = geom.Point((1, 1)) + oLine = doer.getTangentLineStringAtAngle(0, p11, 100) + assert list(oLine.coords) == [(0, 2.0), (100, 2.0)] + + oLine = doer.getTangentLineStringAtAngle(-0.1*math.pi/2, p11, 100) + assert oLine.distance(p11) > 0 + + oLine = doer.getTangentLineStringAtAngle(0.2*math.pi/2, p11, 100) + assert oLine.distance(p11) > 0 + + + oLine = doer.getTangentLineStringAtAngle(0, b1, 10) + assert list(oLine.coords) == [(0, 4), (10, 4)] + + a = -0.1*math.pi/2 + oLine = doer.getTangentLineStringAtAngle(a, b1, 5000) #typical page width + print() + printAngle(a, oLine) + #print(oLine) + assert oLine.distance(b1) > 0 + + a = 0.2*math.pi/2 + oLine = doer.getTangentLineStringAtAngle(a, b1, 5000) + printAngle(a, oLine) + #print(oLine) + assert oLine.distance(b1) > 0 + +def test__getTangentCoord_cut_above(capsys): + + b1 = geom.Polygon([(1,2), (2,2), (2,3), (1,3)]) + + doer = SkewedCutAnnotator(bCutAbove=True) + + def printAngle(a, oLine): + [(xa,ya), (xb,yb)] = list(oLine.coords) + aa = math.atan((yb-ya) / (xb-xa)) + print("asked %.2f° got %.2f° (diff=%.4f°)" % (math.degrees(a), math.degrees(aa), math.degrees(aa-a))) + + with capsys.disabled(): + p11 = geom.Point((1, 1)) + oLine = doer.getTangentLineStringAtAngle(0, p11, 100) + assert list(oLine.coords) == [(0, 0.0), (100, 0.0)] + + oLine = doer.getTangentLineStringAtAngle(-0.1*math.pi/2, p11, 100) + assert oLine.distance(p11) > 0 + + oLine = doer.getTangentLineStringAtAngle(0.2*math.pi/2, p11, 100) + assert oLine.distance(p11) > 0 + + + oLine = doer.getTangentLineStringAtAngle(0, b1, 10) + assert list(oLine.coords) == [(0, 1), (10, 1)] + + a = -0.1*math.pi/2 + oLine = doer.getTangentLineStringAtAngle(a, b1, 5000) #typical page width + print() + printAngle(a, oLine) + #print(oLine) + assert oLine.distance(b1) > 0 + 
+
+        a = 0.2*math.pi/2
+        oLine = doer.getTangentLineStringAtAngle(a, b1, 5000)
+        printAngle(a, oLine)
+        #print(oLine)
+        assert oLine.distance(b1) > 0
+
+
+# ------------------------------------------------------------------
+def op_cut(sFilename, sOutFilename, lDegAngle, bCutAbove, fMinHorizProjection=0.05, fCutHeight=25):
+    # for the pretty printer to format better...
+    parser = etree.XMLParser(remove_blank_text=True)
+    doc = etree.parse(sFilename, parser)
+    root = doc.getroot()
+
+    doer = SkewedCutAnnotator(bCutAbove, lAngle=[math.radians(x) for x in lDegAngle])
+
+    pnum = 0
+    domid = 0
+    for ndPage in MultiPageXml.getChildByName(root, 'Page'):
+        pnum += 1
+        traceln(" --- page %s - constructing separator candidates" % pnum)
+
+        # load the page objects and the GT partition (defined by the table) if any
+        loBaseline, dsetTableByRow = doer.loadPage(ndPage)
+        traceln(" - found %d objects on page" % (len(loBaseline)))
+
+        # find almost-horizontal cuts and tag them if GT is available
+        loHCut = doer.findHCut(ndPage, loBaseline, dsetTableByRow, fCutHeight)
+
+        # create DOM nodes reflecting the cuts
+        # first clean (just in case!)
+        n = doer.remove_cuts_from_dom(ndPage)
+        if n > 0:
+            traceln(" - removed %d pre-existing cut lines" % n)
+
+        # if GT, then we have labelled cut lines in DOM
+        domid = doer.add_Hcut_to_Page(ndPage, loHCut, domid)
+
+    doc.write(sOutFilename, encoding='utf-8', pretty_print=True, xml_declaration=True)
+    print('Annotated cut separators added to %s' % sOutFilename)
+
+
+
+def computePRF(nOk, nErr, nMiss):
+    eps = 0.00001
+    fP = 100 * nOk / (nOk + nErr  + eps)
+    fR = 100 * nOk / (nOk + nMiss + eps)
+    fF = 2 * fP * fR / (fP + fR + eps)
+    return fP, fR, fF
+
+def _debugPartition(s, lset):
+    traceln("---- ", s)
+    for s in lset:
+        traceln(s)
+
+def _isBaselineNotO(nd):
+    """
+    filter Baseline tagged as 'O', or not tagged at all
+    """
+    v = nd.getparent().get("DU_row")
+    return v is None or v not in ["O"]

+def _isBaselineInTable(nd):
+    """
+    a Baseline in a TableRegion belongs to a TableCell element
+    """
+    v = nd.getparent().getparent().get("row")
+    return not(v is None)
+
+def get_row_partition(doer, sxpCut, dNS
+                      , sFilename, lFilterFun
+                      , bCutAbove=True
+                      , bVerbose=False
+                      , funIndex=lambda x: x._du_index
+                      , bIgnoreHeader=False
+                      ):
+    """
+    return the GT partition in rows, as well as 1 partition per filter function
+    """
+    # load objects: Baseline and Cuts
+    if bVerbose: traceln("- loading %s" % sFilename)
+    parser = etree.XMLParser()
+    doc = etree.parse(sFilename, parser)
+    root = doc.getroot()
+
+    llsetRun = []
+
+    pnum = 0
+    lndPage = MultiPageXml.getChildByName(root, 'Page')
+    assert len(lndPage) == 1, "NOT SUPPORTED: file has many pages - sorry"
+    for ndPage in lndPage:
+        pnum += 1
+        if bVerbose: traceln(" - page %s - loading table GT" % pnum)
+        loBaseline, dsetTableByRow = doer.loadPage(ndPage, funIndex=funIndex
+                                                   , bIgnoreHeader=bIgnoreHeader)
+        if bVerbose: traceln(" - found %d objects on page" % (len(loBaseline)))
+
+        # make a dictionary of cumulative sets, and the set of all objects
+        lTableRowK = sorted(dsetTableByRow.keys())
+        if bVerbose:
+            traceln(" - found %d rows" % (len(lTableRowK)))
+            traceln(" - found %d objects in the table" % (sum(len(v) for v in dsetTableByRow.values())))
+        lNdCut = ndPage.xpath(sxpCut, namespaces=dNS)
+        if bVerbose:
+            traceln(" - found %d 'S' cuts" % (len(lNdCut)))
+        else:
+            traceln("- loaded %40s " % sFilename
+                    , " %6d rows %6d 'S' cuts" % ( len(lTableRowK)
+                    , len(lNdCut))
+                    , " %6d objects %6d table objects" % (
+                        len(loBaseline)
+                        , sum(len(v)
for v in dsetTableByRow.values()) + ) + ) + loCut = [] + for ndCut in lNdCut: + #now we need to infer the bounding box of that object + (x1, y1), (x2, y2) = PageXml.getPointList(ndCut) #the polygon + # Create the shapely shape + loCut.append(geom.LineString([(x1, y1), (x2, y2)])) + + w,h = float(ndPage.get("imageWidth")), float(ndPage.get("imageHeight")) + # Add a fictive cut at top of page + loCut.append(geom.LineString([(0, 0), (w, 0)])) + # Add a fictive cut at end of page + loCut.append(geom.LineString([(0, h), (w, h)])) + + # order it by line centroid Y + loCut.sort(key=lambda o: o.centroid.y) + + # dcumset is the GT!! + lsetGT = [dsetTableByRow[k] for k in lTableRowK] # list of set of du_index + + # NOW, look at predictions + for filterFun in lFilterFun: + + loBaselineInTable = [o for o in loBaseline if filterFun(o._du_nd)] + if bVerbose: traceln(" - %d objects on page predicted in table (%d out)" % ( + len(loBaselineInTable) + , len(loBaseline) - len(loBaselineInTable))) + + # Now create the list of partitions created by the Cuts + lsetRun = [] + partition = PolygonPartition(loBaselineInTable) + if bCutAbove: + #cut if above the text that led to its creation + setAllPrevIds = set([]) # cumulative set of what was already taken + for oCut in loCut: + lo = partition.getObjectBelowLine(oCut) + setIds = set(funIndex(o) for o in lo) + if setAllPrevIds: + prevRowIds = setAllPrevIds.difference(setIds) # content of previous row + if prevRowIds: + #an empty set is denoting alternative cuts leading to same partition + lsetRun.append(prevRowIds) + setAllPrevIds = setIds + else: + #cut if below the text that led to its creation + cumSetIds = set([]) # cumulative set + for oCut in loCut: + lo = partition.getObjectAboveLine(oCut) + setIds = set(funIndex(o) for o in lo) + rowIds = setIds.difference(cumSetIds) # only last row! + if rowIds: + #an empty set is denoting alternative cuts leading to same partition + lsetRun.append(rowIds) + cumSetIds = setIds +# _debugPartition("run", lsetRun) + llsetRun.append(lsetRun) +# _debugPartition("ref", lsetGT) + return lsetGT, llsetRun + +def op_eval_row(lsFilename, fSimil, bCutAbove, bVerbose=False + , bIgnoreHeader=False): + """ + We load the XML + - get the cut with @DU_type="S" + - get the text objects (geometry=Baseline) + - + """ + nOk, nErr, nMiss = 0, 0, 0 + + if fSimil is None: + #lfSimil = [ i / 100 for i in range(75, 101, 5)] + lfSimil = [ i / 100 for i in range(70, 101, 10)] + else: + lfSimil = [fSimil] + + # we use only BIO+SIO + dOkErrMissOnlyRow = { fSimil:(0,0,0) for fSimil in lfSimil } + dOkErrMissOnlyRow.update({'name':'OnlyRow' + , 'FilterFun':_isBaselineNotO}) + # we use the SIO and the TableRegion + dOkErrMissTableRow = { fSimil:(0,0,0) for fSimil in lfSimil } + dOkErrMissTableRow.update({'name':'TableRow' + , 'FilterFun':_isBaselineInTable}) + ldOkErrMiss = [dOkErrMissOnlyRow, dOkErrMissTableRow] + + sxpCut = './/pc:CutSeparator[@orient="0" and @DU_type="S"]' #how to find the cuts + dNS = {"pc":PageXml.NS_PAGE_XML} + + doer = SkewedCutAnnotator(bCutAbove) + + traceln(" - Cut selector = ", sxpCut) + + # load objects: Baseline and Cuts + for n, sFilename in enumerate(lsFilename): + lsetGT, llsetRun = get_row_partition(doer, sxpCut, dNS + , sFilename + , [dOkErrMiss['FilterFun'] for dOkErrMiss in ldOkErrMiss] + , bCutAbove=True, bVerbose=False + , funIndex=lambda o: o._dom_id + , bIgnoreHeader=bIgnoreHeader + ) + pnum = 1 # only support single-page file... 
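+        # (reminder, assuming util.jaccard implements the usual definition:
+        #  jaccard_distance({1,2,3}, {2,3,4}) = 1 - 2/4 = 0.5, i.e. a
+        #  similarity of 0.5, which a threshold fSimil=0.70 would reject)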
+ for dOkErrMiss, lsetRun in zip(ldOkErrMiss, llsetRun): + for fSimil in lfSimil: + nOk, nErr, nMiss = dOkErrMiss[fSimil] + _nOk, _nErr, _nMiss, _lFound, _lErr, _lMissed = evalPartitions(lsetRun, lsetGT, fSimil, jaccard_distance) + if bVerbose: + traceln(" - - - simil = %.2f" % fSimil) + traceln("----- RUN ----- ") + for s in lsetRun: traceln(" run ", sorted(s)) + traceln("----- REF ----- ") + for s in lsetGT: traceln(" ref ", sorted(s)) + nOk += _nOk + nErr += _nErr + nMiss += _nMiss + if bVerbose or fSimil == 1.0: + _fP, _fR, _fF = computePRF(_nOk, _nErr, _nMiss) + traceln("%4d %8s simil:%.2f P %5.1f R %5.1f F1 %5.1f ok=%6d err=%6d miss=%6d %s page=%d" %( + n+1, dOkErrMiss['name'], fSimil + , _fP, _fR, _fF + , _nOk, _nErr, _nMiss + , os.path.basename(sFilename), pnum)) + dOkErrMiss[fSimil] = (nOk, nErr, nMiss) + + for dOkErrMiss in [dOkErrMissOnlyRow, dOkErrMissTableRow]: + traceln() + name = dOkErrMiss['name'] + for fSimil in lfSimil: + nOk, nErr, nMiss = dOkErrMiss[fSimil] + fP, fR, fF = computePRF(nOk, nErr, nMiss) + traceln("ALL %8s simil:%.2f P %5.1f R %5.1f F1 %5.1f " % (name, fSimil, fP, fR, fF ) + , " " + ,"ok=%d err=%d miss=%d" %(nOk, nErr, nMiss)) + return (nOk, nErr, nMiss) + + +def op_eval_old(lsFilename, fSimil, bDetail=False): + """ + We load the XML + - get the cut with @type="S" + - get the text objects (geometry=Baseline) + - + """ + nOk, nErr, nMiss = 0, 0, 0 + + # OLD STYLE (May'18) + sxpCut = './/pc:CutSeparator[@orient="0" and @type="S"]' #how to find the cuts + + dNS = "./pc:TextEquiv" + + doer = SkewedCutAnnotator(True) + + traceln(" - Cut selector = ", sxpCut) + + def getPolylineAverageXY(ndPolyline): + """ + COPIED FROM tasks.DU_ABPTableCutAnnotator.BaselineCutAnnotator + weighted average X and average Y of a polyline + the weight indicate how long each segment at a given X, or Y, was. + """ + sPoints=ndPolyline.get('points') + lXY = Polygon.parsePoints(sPoints).lXY + + # list of X and Y values and respective weights + lXYWxWy = [((x1+x2)/2.0, abs(y2-y1), # for how long at this X? 
+ (y1+y2)/2.0, abs(x2-x1)) \ + for (x1,y1), (x2, y2) in zip(lXY, lXY[1:])] + fWeightedSumX = sum(x*wx for x, wx, _, _ in lXYWxWy) + fWeightedSumY = sum(y*wy for _, _, y, wy in lXYWxWy) + fSumWeightX = sum( wx for _, wx , _, _ in lXYWxWy) + fSumWeightY = sum( wy for _, _ , _, wy in lXYWxWy) + + Xavg = int(round(fWeightedSumX/fSumWeightX)) if fSumWeightX > 0 else 0 + Yavg = int(round(fWeightedSumY/fSumWeightY)) if fSumWeightY > 0 else 0 + +# Xavg, Yavg = self.moduloSnap(Xavg, Yavg) + + return (Xavg, Yavg) + + def baseline_loader(nd): + """ + load the baseline as done in DU_ABPTableCutAnnotator + """ + x, y = getPolylineAverageXY(nd) + # make a short horizontal line out of a point + return geom.LineString([(x-10,y), (x+10, y)]) + + # load objects: Baseline and Cuts + for sFilename in lsFilename: + traceln("- loading %s"%sFilename) + parser = etree.XMLParser() + doc = etree.parse(sFilename, parser) + root=doc.getroot() + + pnum = 0 + for ndPage in MultiPageXml.getChildByName(root, 'Page'): + pnum += 1 + traceln(" - page %s - loading table GT" % pnum) + loBaseline, dsetTableByRow = doer.loadPage(ndPage, shaper_fun=baseline_loader) + traceln(" - found %d objects on page" % (len(loBaseline))) + # make a dictionary of cumulative sets, and the set of all objects + lTableRowK = sorted(dsetTableByRow.keys(), reverse=True) # bottom to top + traceln(" - found %d objects in the table" % (sum(len(v) for v in dsetTableByRow.values()))) + + lNdCut = ndPage.xpath(sxpCut, namespaces={"pc":PageXml.NS_PAGE_XML}) + traceln(" - found %d 'S' cut" % (len(lNdCut))) + loCut = [] + for ndCut in lNdCut: + #now we need to infer the bounding box of that object + (x1, y1), (x2, y2) = PageXml.getPointList(ndCut) #the polygon + # make sure that the cut is above the baseline that created it + y1 -= 1 + y2 -= 1 + assert y1 == y2 # in this version, the cuts were horizontal + # Create the shapely shape + loCut.append(geom.LineString([(x1, y1), (x2, y2)])) + # order it by line centroid Y + loCut.sort(key=lambda o: o.centroid.y, reverse=True) # from bottom to top + + # dcumset is the GT!! + lsetGT = [dsetTableByRow[k] for k in lTableRowK] # list of set of du_index + + # Now create the list of partitions created by the Cuts, excluding the 'O' + lsetRun = [] + partition = PolygonPartition(loBaseline) + cumSetIds = set([]) # cumulative set + for oCut in loCut: + lo = partition.getObjectBelowLine(oCut) + setIds = set(o._du_index for o in lo if _isBaselineInTable(o._du_nd)) + rowIds = setIds.difference(cumSetIds) # only last row! 
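+                # (sketch: scanning cuts bottom-up, if the first cut has
+                #  {7,8} below it and the next one has {5,6,7,8}, the
+                #  difference {5,6} is exactly the row lying between the
+                #  two cuts)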
+                if rowIds:
+                    # an empty set denotes alternative cuts leading to the same partition
+                    lsetRun.append(rowIds)
+                cumSetIds = setIds
+#             _debugPartition("run", lsetRun)
+#             _debugPartition("ref", lsetGT)
+            _nOk, _nErr, _nMiss, _lFound, _lErr, _lMissed = evalPartitions(lsetRun, lsetGT, fSimil, jaccard_distance)
+            nOk   += _nOk
+            nErr  += _nErr
+            nMiss += _nMiss
+            if bDetail:
+                _fP, _fR, _fF = computePRF(_nOk, _nErr, _nMiss)
+                traceln("ok=%d err=%d miss=%d P=%.1f R=%.1f F1=%.1f %s page=%d" % (
+                      _nOk, _nErr, _nMiss
+                    , _fP, _fR, _fF
+                    , sFilename, pnum))
+
+    fP, fR, fF = computePRF(nOk, nErr, nMiss)
+
+    traceln("SUMMARY == P=%.1f%%\tR=%.1f%%\tF1=%.1f" % (fP, fR, fF))
+    traceln("ok=%d err=%d miss=%d P=%.1f R=%.1f F1=%.1f" % (
+          nOk, nErr, nMiss
+        , fP, fR, fF))
+    return (nOk, nErr, nMiss)
+
+# ------------------------------------------------------------------
+def op_gt_recall(lsFilename, bCutAbove, lDegAngle, fMinHorizProjection=0.05, fCutHeight=25):
+    cAll = Counter()
+    for sFilename in lsFilename:
+        traceln("- loading GT: %s" % sFilename)
+
+        # for the pretty printer to format better...
+        parser = etree.XMLParser(remove_blank_text=True)
+        doc = etree.parse(sFilename, parser)
+        root = doc.getroot()
+
+        doer = SkewedCutAnnotator(bCutAbove, lAngle=[math.radians(x) for x in lDegAngle])
+
+        pnum = 0
+        for ndPage in MultiPageXml.getChildByName(root, 'Page'):
+            pnum += 1
+            traceln(" --- page %s - constructing separator candidates" % pnum)
+
+            # load the page objects and the GT partition (defined by the table) if any
+            loBaseline, dsetTableByRow = doer.loadPage(ndPage)
+            traceln(" - found %d objects on page" % (len(loBaseline)))
+
+            # find almost-horizontal cuts and tag them if GT is available
+            loHCut = doer.findHCut(ndPage, loBaseline, dsetTableByRow, fCutHeight)
+            cAll.update(Counter(o._du_label for o in loHCut))
+
+    lk = sorted(cAll.keys())
+    traceln("GT: ALL CUT Label count: ", "  ".join("%s:%d" % (k, cAll[k]) for k in lk))
+
+
+# ------------------------------------------------------------------
+if __name__ == "__main__":
+    usage = ""
+    parser = OptionParser(usage=usage, version="0.1")
+    parser.add_option("--height", dest="fCutHeight", default=10
+                      , action="store", type=float, help="Minimal height of a cut")
+    parser.add_option("--simil", dest="fSimil", default=None
+                      , action="store", type=float, help="Minimal similarity for associating 2 partitions")
+    parser.add_option("--angle", dest='lsAngle'
+                      , action="store", type="string", default="0"
+                      , help="Allowed cutting angles, in degree, comma-separated")
+    parser.add_option("--cut-below", dest='bCutBelow', action="store_true", default=False
+                      , help="Each object defines one or several cuts below it (instead of above as by default)")
+#     parser.add_option("--cut-above", dest='bCutAbove', action="store_true", default=None
+#                       , help="Each object defines one or several cuts above it (instead of below as by default)")
+    parser.add_option("-v", "--verbose", dest='bVerbose', action="store_true", default=False)
+
+    # ---
+    # parse the command line
+    (options, args) = parser.parse_args()
+
+    options.bCutAbove = not(options.bCutBelow)
+
+    # load mpxml
+    op = args[0]
+    # --------------------------------------
+    if op == "cut":
+        sFilename = args[1]
+        sOutFilename = args[2]
+        traceln("- cutting : %s --> %s" % (sFilename, sOutFilename))
+        lDegAngle = [float(s) for s in options.lsAngle.split(",")]
+        traceln("- Allowed angles (°): %s" % lDegAngle)
+        op_cut(sFilename, sOutFilename, lDegAngle, options.bCutAbove, fCutHeight=options.fCutHeight)
+    # --------------------------------------
+    elif op == "eval":
+        lsFilename = args[1:]
+        traceln("- evaluating cut-based partitions (fSimil=%s): " % options.fSimil, lsFilename)
+        op_eval_row(lsFilename, options.fSimil, options.bCutAbove, options.bVerbose)
+    # --------------------------------------
+    elif op == "eval_bsln":
+        lsFilename = args[1:]
+        traceln("- evaluating baseline-based partitions : ", lsFilename)
+        op_eval_old(lsFilename, options.fSimil, True)
+    # --------------------------------------
+    elif op == "gt_recall":
+        lsFilename = args[1:]
+        traceln("- GT recall : %s" % lsFilename)
+        lDegAngle = [float(s) for s in options.lsAngle.split(",")]
+        traceln("- Allowed angles (°): %s" % lDegAngle)
+        op_gt_recall(lsFilename, options.bCutAbove, lDegAngle, fCutHeight=options.fCutHeight)
+    else:
+        print("Usage: %s [cut|eval|eval_bsln|gt_recall]" % sys.argv[0])
+
+
diff --git a/TranskribusDU/tasks/DU_Table/DU_ABPTable_T.py b/TranskribusDU/tasks/DU_Table/DU_ABPTable_T.py
new file mode 100644
index 0000000..8d2fc3f
--- /dev/null
+++ b/TranskribusDU/tasks/DU_Table/DU_ABPTable_T.py
@@ -0,0 +1,237 @@
+# -*- coding: utf-8 -*-
+
+"""
+    Example DU task for ABP Table that uses the Multi-Type CRF
+
+    Copyright Xerox(C) 2017 H. Déjean, JL Meunier
+
+
+
+
+    Developed for the EU project READ. The READ project has received funding
+    from the European Union's Horizon 2020 research and innovation programme
+    under grant agreement No 674943.
+
+"""
+
+
+
+
+import sys, os
+
+try:  # to ease the use without proper Python installation
+    import TranskribusDU_version
+except ImportError:
+    sys.path.append( os.path.dirname(os.path.dirname( os.path.abspath(sys.argv[0]) )) )
+    import TranskribusDU_version
+
+from common.trace import traceln
+from tasks.DU_Task_Factory import DU_Task_Factory
+from tasks import _checkFindColDir, _exit
+
+from graph.Graph_Multi_SinglePageXml import Graph_MultiSinglePageXml
+from graph.NodeType_PageXml import NodeType_PageXml_type_woText
+from tasks.DU_CRF_Task import DU_CRF_Task
+from graph.FeatureDefinition_PageXml_std_noText_v3 import FeatureDefinition_T_PageXml_StandardOnes_noText_v3
+
+
+
+
+
+class DU_ABPTable_TypedCRF(DU_CRF_Task):
+    """
+    We will do a typed CRF model for a DU task
+    , with the below labels
+    """
+    sXmlFilenamePattern = "*.mpxml"
+
+    sLabeledXmlFilenamePattern = "*.mpxml"
+
+    sLabeledXmlFilenameEXT = ".mpxml"
+
+    #=== CONFIGURATION ====================================================================
+    @classmethod
+    def getConfiguredGraphClass(cls):
+        """
+        In this class method, we must return a configured graph class
+        """
+
+        # ===============================================================================================================
+        # DEFINING THE CLASS OF GRAPH WE USE
+        DU_GRAPH = Graph_MultiSinglePageXml
+
+        lLabels1 = ['RB', 'RI', 'RE', 'RS', 'RO']
+        lIgnoredLabels1 = None
+        # """
+        # if you play with a toy collection, which does not have all expected classes, you can reduce those.
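+        # (RB/RI/RE/RS/RO appear to mirror the BIESO row tags written by the
+        #  table annotation tools: Begin/Inside/End of a multi-line cell,
+        #  Single-line cell, and Outside any table)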
+ # """ + # + # lActuallySeen = None + # if lActuallySeen: + # print "REDUCING THE CLASSES TO THOSE SEEN IN TRAINING" + # lIgnoredLabels = [lLabels[i] for i in range(len(lLabels)) if i not in lActuallySeen] + # lLabels = [lLabels[i] for i in lActuallySeen ] + # print len(lLabels) , lLabels + # print len(lIgnoredLabels) , lIgnoredLabels + # nbClass = len(lLabels) + 1 #because the ignored labels will become OTHER + + nt1 = NodeType_PageXml_type_woText("text" #some short prefix because labels below are prefixed with it + , lLabels1 + , lIgnoredLabels1 + , False #no label means OTHER + , BBoxDeltaFun=lambda v: max(v * 0.066, min(5, v/3)) #we reduce overlap in this way + ) + nt1.setXpathExpr( (".//pc:TextLine" #how to find the nodes + , "./pc:TextEquiv") #how to get their text + ) + DU_GRAPH.addNodeType(nt1) + + nt2 = NodeType_PageXml_type_woText("sprtr" #some short prefix because labels below are prefixed with it + , ['SI', 'SO'] + , None + , False #no label means OTHER + , BBoxDeltaFun=lambda v: max(v * 0.066, min(5, v/3)) #we reduce overlap in this way + ) + nt2.setXpathExpr( (".//pc:SeparatorRegion" #how to find the nodes + , "./pc:TextEquiv") #how to get their text (no text in fact) + ) + DU_GRAPH.addNodeType(nt2) + + + #=== CONFIGURATION ==================================================================== + def __init__(self, sModelName, sModelDir, sComment=None, C=None, tol=None, njobs=None, max_iter=None, inference_cache=None): + + #another way to specify the graph class + # defining a getConfiguredGraphClass is preferred + self.configureGraphClass(self.DU_GRAPH) + + DU_CRF_Task.__init__(self + , sModelName, sModelDir + , dLearnerConfig = { + 'C' : .1 if C is None else C + , 'njobs' : 8 if njobs is None else njobs + , 'inference_cache' : 50 if inference_cache is None else inference_cache + #, 'tol' : .1 + , 'tol' : .05 if tol is None else tol + , 'save_every' : 50 #save every 50 iterations,for warm start + , 'max_iter' : 1000 if max_iter is None else max_iter + } + , sComment=sComment + , cFeatureDefinition=FeatureDefinition_T_PageXml_StandardOnes_noText_v3 + , dFeatureConfig = { + #config for the extractor of nodes of each type + "text": None, + "sprtr": None, + #config for the extractor of edges of each type + "text_text": None, + "text_sprtr": None, + "sprtr_text": None, + "sprtr_sprtr": None + } + ) + + traceln("- classes: ", self.DU_GRAPH.getLabelNameList()) + + self.bsln_mdl = self.addBaseline_LogisticRegression() #use a LR model trained by GridSearch as baseline + + #=== END OF CONFIGURATION ============================================================= + + + def predict(self, lsColDir,sDocId): + """ + Return the list of produced files + """ +# self.sXmlFilenamePattern = "*.a_mpxml" + return DU_CRF_Task.predict(self, lsColDir,sDocId) + + +# ---------------------------------------------------------------------------- +def main(sModelDir, sModelName, options): + doer = DU_ABPTable_TypedCRF(sModelName, sModelDir, + C = options.crf_C, + tol = options.crf_tol, + njobs = options.crf_njobs, + max_iter = options.max_iter, + inference_cache = options.crf_inference_cache) + + + if options.docid: + sDocId=options.docid + else: + sDocId=None + if options.rm: + doer.rm() + return + + lTrn, lTst, lRun, lFold = [_checkFindColDir(lsDir) for lsDir in [options.lTrn, options.lTst, options.lRun, options.lFold]] +# if options.bAnnotate: +# doer.annotateDocument(lTrn) +# traceln('annotation done') +# sys.exit(0) + + ## use. 
a_mpxml files + doer.sXmlFilenamePattern = doer.sLabeledXmlFilenamePattern + + + if options.iFoldInitNum or options.iFoldRunNum or options.bFoldFinish: + if options.iFoldInitNum: + """ + initialization of a cross-validation + """ + splitter, ts_trn, lFilename_trn = doer._nfold_Init(lFold, options.iFoldInitNum, bStoreOnDisk=True) + elif options.iFoldRunNum: + """ + Run one fold + """ + oReport = doer._nfold_RunFoldFromDisk(options.iFoldRunNum, options.warm) + traceln(oReport) + elif options.bFoldFinish: + tstReport = doer._nfold_Finish() + traceln(tstReport) + else: + assert False, "Internal error" + #no more processing!! + exit(0) + #------------------- + + if lFold: + loTstRpt = doer.nfold_Eval(lFold, 3, .25, None) + import graph.GraphModel + sReportPickleFilename = os.path.join(sModelDir, sModelName + "__report.txt") + traceln("Results are in %s"%sReportPickleFilename) + graph.GraphModel.GraphModel.gzip_cPickle_dump(sReportPickleFilename, loTstRpt) + elif lTrn: + doer.train_save_test(lTrn, lTst, options.warm) + try: traceln("Baseline best estimator: %s"%doer.bsln_mdl.best_params_) #for GridSearch + except: pass + traceln(" --- CRF Model ---") + traceln(doer.getModel().getModelInfo()) + elif lTst: + doer.load() + tstReport = doer.test(lTst) + traceln(tstReport) + + if lRun: + doer.load() + lsOutputFilename = doer.predict(lRun,sDocId) + traceln("Done, see in:\n %s"%lsOutputFilename) + + +# ---------------------------------------------------------------------------- +if __name__ == "__main__": + + version = "v.01" + usage, description, parser = DU_Task_Factory.getStandardOptionsParser(sys.argv[0], version) + parser.add_option("--docid", dest='docid', action="store",default=None, help="only process docid") + # --- + #parse the command line + (options, args) = parser.parse_args() + + # --- + try: + sModelDir, sModelName = args + except Exception as e: + traceln("Specify a model folder and a model name!") + _exit(usage, 1, e) + + main(sModelDir, sModelName, options) diff --git a/TranskribusDU/tasks/DU_Table/DU_Table_Annotator.py b/TranskribusDU/tasks/DU_Table/DU_Table_Annotator.py new file mode 100644 index 0000000..a699efe --- /dev/null +++ b/TranskribusDU/tasks/DU_Table/DU_Table_Annotator.py @@ -0,0 +1,85 @@ +# -*- coding: utf-8 -*- + +""" + USAGE: DU_Table_Annotator.py input-folder + + You must run this on your GT collection to create a training collection. + + If you pass a folder, you get a new folder with name postfixed by a_ + + Does 2 things: + + - 1 - + Annotate textlines for Table understanding (finding rows and columns) + + It tags the TextLine, to indicate: + - the table header, vs data, vs other stuff: + @DU_header = 'CH' | 'D' | 'O' + + - the vertical rank in the table cell: + @DU_row = 'B' | 'I' | 'E' | 'S' | 'O' + + - something regarding the number of text in a cell?? + # NO SURE THIS WORKS... + @DU_col = 'M' | 'S' | 'O' + + + - 2 - + Aggregate the borders of the cells by linear regression to reflect them + as a line, which is stored as a SeparatorRegion element. + + + Copyright Naver Labs Europe 2017, 2018 + H. Déjean + JL Meunier + + + + + Developed for the EU project READ. The READ project has received funding + from the European Union's Horizon 2020 research and innovation programme + under grant agreement No 674943. 
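+
+    Example (with a hypothetical collection folder named MyCollection):
+        DU_Table_Annotator.py MyCollection
+    reads MyCollection/col/*.mpxml (or *.pxml) and writes the annotated
+    files to a_MyCollection/col/a_*.mpxml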
+""" +import sys, os + +try: #to ease the use without proper Python installation + import TranskribusDU_version +except ImportError: + sys.path.append( os.path.dirname(os.path.dirname( os.path.abspath(sys.argv[0]) )) ) + import TranskribusDU_version + +from common.trace import traceln +import tasks.DU_Table.DU_ABPTableRCAnnotation + + +if __name__ == "__main__": + try: + #we expect a folder + sInputDir = sys.argv[1] + if not os.path.isdir(sInputDir): raise Exception() + except IndexError: + traceln("Usage: %s " % sys.argv[0]) + exit(1) + + sOutputDir = "a_"+sInputDir + traceln(" - Output will be in ", sOutputDir) + try: + os.mkdir(sOutputDir) + os.mkdir(os.path.join(sOutputDir, "col")) + except: + pass + + lsFilename = [s for s in os.listdir(os.path.join(sInputDir, "col")) if s.endswith(".mpxml") ] + lsFilename.sort() + lsOutFilename = [os.path.join(sOutputDir, "col", "a_"+s) for s in lsFilename] + if not lsFilename: + lsFilename = [s for s in os.listdir(os.path.join(sInputDir, "col")) if s.endswith(".pxml") ] + lsFilename.sort() + lsOutFilename = [os.path.join(sOutputDir, "col", "a_"+s[:-5]+".mpxml") for s in lsFilename] + + lsInFilename = [os.path.join(sInputDir , "col", s) for s in lsFilename] + + traceln(lsFilename) + traceln("%d files to be processed" % len(lsFilename)) + + tasks.DU_Table.DU_ABPTableRCAnnotation.main(lsInFilename, lsOutFilename) diff --git a/TranskribusDU/tasks/DU_Table/DU_Table_CellBorder.py b/TranskribusDU/tasks/DU_Table/DU_Table_CellBorder.py new file mode 100644 index 0000000..5fd988e --- /dev/null +++ b/TranskribusDU/tasks/DU_Table/DU_Table_CellBorder.py @@ -0,0 +1,427 @@ +# -*- coding: utf-8 -*- + +""" + Table Undertsanding + + - given a human-annotated table + - find lines that reflect well the cell borders, for rows and columns + + This is done by linear interpolation of cell borders, bi row and by column. + + Copyright Naver Labs Europe 201 + JL Meunier + + + + + Developed for the EU project READ. The READ project has received funding + from the European Union's Horizon 2020 research and innovation programme + under grant agreement No 674943. 
+ +""" + + + + + +import sys, os, math +import collections + +from lxml import etree +import matplotlib.pyplot as plt +import numpy as np +from scipy.optimize._hungarian import linear_sum_assignment + +try: #to ease the use without proper Python installation + import TranskribusDU_version +except ImportError: + sys.path.append( os.path.dirname(os.path.dirname( os.path.abspath(sys.argv[0]) )) ) + import TranskribusDU_version + +from common.trace import traceln +from xml_formats.PageXml import MultiPageXml , PageXml +from util.Polygon import Polygon +from util.partitionEvaluation import evalPartitions +from util.jaccard import jaccard_distance +from tasks.DU_Table.DU_ABPTableSkewed_CutAnnotator import SkewedCutAnnotator,\ + get_row_partition, _isBaselineInTable, computePRF +# from tasks.DU_ABPTableCutAnnotator import get_col_partition, CutAnnotator +import tasks.DU_Table.DU_ABPTableCutAnnotator + + + +class DocSeparatorException(Exception): + pass + + +def getDocSeparators(sFilename): + """ + return two dictionaries + row -> list of (x1, y1, x2, y2) + col -> list of (x1, y1, x2, y2) + """ + parser = etree.XMLParser() + doc = etree.parse(sFilename, parser) + root = doc.getroot() + lCell= MultiPageXml.getChildByName(root,'TableCell') + if not lCell: + raise DocSeparatorException("No TableCell element in %s" %sFilename) + dRowSep, dColSep = getCellsSeparators(lCell) + del doc + return dRowSep, dColSep + +def getCellsSeparators(lCell): + """ + return two dictionaries + row -> ((x1, y1), (x2, y2)) NOTE: top of row + col -> ((x1, y1), (x2, y2)) NOTE: left of column + """ + dRowSep = {} + dColSep = {} + + # let's collect the segments forming the cell borders, by row, by col + dRowSep_lSgmt = collections.defaultdict(list) + dColSep_lSgmt = collections.defaultdict(list) + for cell in lCell: + row, col, rowSpan, colSpan = [int(cell.get(sProp)) for sProp \ + in ["row", "col", "rowSpan", "colSpan"] ] + coord = cell.xpath("./a:%s" % ("Coords"),namespaces={"a":MultiPageXml.NS_PAGE_XML})[0] + sPoints = coord.get('points') + plgn = Polygon.parsePoints(sPoints) + try: + lT, lR, lB, lL = plgn.partitionSegmentTopRightBottomLeft() + except ZeroDivisionError: + traceln("ERROR: cell %s row=%d col=%d has empty area and is IGNORED" + % (cell.get("id"), row, col)) + continue + #now the top segments contribute to row separator of index: row + dRowSep_lSgmt[row].extend(lT) + #now the bottom segments contribute to row separator of index: row+rowSpan + dRowSep_lSgmt[row+rowSpan].extend(lB) + + dColSep_lSgmt[col].extend(lL) + dColSep_lSgmt[col+colSpan].extend(lR) + + #now make linear regression to draw relevant separators + def getX(lSegment): + lX = list() + for x1,_y1,x2,_y2 in lSegment: + lX.append(x1) + lX.append(x2) + return lX + + def getY(lSegment): + lY = list() + for _x1,y1,_x2,y2 in lSegment: + lY.append(y1) + lY.append(y2) + return lY + + for row, lSegment in dRowSep_lSgmt.items(): + X = getX(lSegment) + Y = getY(lSegment) + #sum(l,()) + lfNorm = [math.sqrt(np.linalg.norm((x2 - x1, y2 - y1))) for x1,y1,x2,y2 in lSegment] + #duplicate each element + sumW = sum(lfNorm) * 2 + W = [fN/sumW for fN in lfNorm for _ in (0,1)] + # a * x + b + a, b = np.polynomial.polynomial.polyfit(X, Y, 1, w=W) + + xmin, xmax = min(X), max(X) + y1 = a + b * xmin + y2 = a + b * xmax + + dRowSep[row] = ((xmin, y1), (xmax, y2)) + + for col, lSegment in dColSep_lSgmt.items(): + X = getX(lSegment) + Y = getY(lSegment) + #sum(l,()) + lfNorm = [math.sqrt(np.linalg.norm((x2 - x1, y2 - y1))) for x1,y1,x2,y2 in lSegment] + #duplicate each 
element
+            sumW = sum(lfNorm) * 2
+            W = [fN/sumW for fN in lfNorm for _ in (0,1)]
+            a, b = np.polynomial.polynomial.polyfit(Y, X, 1, w=W)
+
+            ymin, ymax = min(Y), max(Y)
+            x1 = a + b * ymin
+            x2 = a + b * ymax
+            dColSep[col] = ((x1, ymin), (x2, ymax))
+
+    return dRowSep, dColSep
+
+
+def op_eval_cell(sRowDir, sColDir, fRatio, bCutAbove=True, bVerbose=False):
+    """
+    Takes output DU files from 2 folders
+    - one giving rows
+    - one giving columns
+    from the same input files!
+
+    Compute the quality of the partitioning into cells
+
+    Show results
+    return (nOk, nErr, nMiss)
+    """
+    #lsRowFn, lsColFn = [sorted([_fn for _fn in os.listdir(_sDir) if _fn.lower().endswith("pxml")]) for _sDir in (sRowDir, sColDir)]
+    lsRowFn = sorted([_fn for _fn in os.listdir(sRowDir) if _fn.lower().endswith("_du.mpxml")])
+    lsColFn = sorted([_fn for _fn in os.listdir(sColDir) if _fn.lower().endswith(".mpxml")])
+
+    # checking coherence !!
+    lsRowBFn = [_fn[2:-9] for _fn in lsRowFn]  # 'a_0001_S_Aldersbach_008-01_0064_du.mpxml'
+    lsColBFn = [_fn[4:-6] for _fn in lsColFn]  # 'cut-0001_S_Aldersbach_008-01_0064.mpxml'
+    #lsColBFn = [os.path.basename(_fn) for _fn in lsColFn]  # 'cut-0001_S_Aldersbach_008-01_0064.mpxml'
+    if lsRowBFn != lsColBFn:
+        setRowBFn = set(lsRowBFn)
+        setColBFn = set(lsColBFn)
+        traceln("WARNING: different filenames in each folder")
+        setOnlyCol = setColBFn.difference(setRowBFn)
+        if setOnlyCol:
+            traceln("--- Only in cols:", sorted(setOnlyCol), "\t")
+            traceln("--- Only in cols: %d files" % len(setOnlyCol), "\n\t", )
+        setOnlyRow = setRowBFn.difference(setColBFn)
+        if setOnlyRow:
+            traceln("ERROR: different filenames in each folder")
+            traceln("--- Only in rows:", sorted(setOnlyRow), "\t")
+            traceln("--- Only in rows: %d files" % len(setOnlyRow), "\n\t", )
+            sys.exit(1)
+
+        if setOnlyCol:
+            # ok, let's clean the col list... :-/
+            lsColFn2 = []
+            ibfn = 0
+            for fn in lsColFn:
+                if fn[4:-6] == lsRowBFn[ibfn]:
+                    lsColFn2.append(fn)
+                    ibfn += 1
+            lsColFn = lsColFn2
+            lsColBFn = [_fn[4:-6] for _fn in lsColFn]
+            assert lsRowBFn == lsColBFn
+            traceln("Reconciled file lists... 
%d files now"%len(lsRowBFn)) + del setOnlyCol + +# lfSimil = [ i / 100 for i in range(70, 101, 10)] + lfSimil = [ i / 100.0 for i in [66, 80, 100]] + + dOkErrMissOnlyRow = { fSimil:(0,0,0) for fSimil in lfSimil } + dOkErrMissOnlyRow.update({'name':'OnlyCell' + , 'FilterFun':_isBaselineInTable}) + + dOkErrMiss = dOkErrMissOnlyRow + + def cross_row_col(lsetRow, lsetCol): + lsetCell = [] + for setRow in lsetRow: + for setCol in lsetCol: + setCell = setRow.intersection(setCol) + lsetCell.append(setCell) + return lsetCell + + def evalHungarian(lX,lY,th): + """ + """ + + cost_matrix=np.zeros((len(lX),len(lY)),dtype=float) + + for a,x in enumerate(lX): + for b,y in enumerate(lY): + #print(x,y, jaccard_distance(x,y)) + cost_matrix[a,b]= jaccard_distance(x,y) + + r1,r2 = linear_sum_assignment(cost_matrix) + ltobeDel=[] + for a,i in enumerate(r2): + # print (r1[a],ri) + if 1 - cost_matrix[r1[a],i] < th : + ltobeDel.append(a) + # if bt:print(lX[r1[a]],lY[i],1- cost_matrix[r1[a],i]) + # else: print(lX[i],lY[r1[a]],1-cost_matrix[r1[a],i]) + # else: + # if bt:print(lX[r1[a]],lY[i],1-cost_matrix[r1[a],i]) + # else:print(lX[i],lY[r1[a]],1-cost_matrix[r1[a],i]) + r2 = np.delete(r2,ltobeDel) + r1 = np.delete(r1,ltobeDel) + # print("wwww",len(lX),len(lY),len(r1),len(r2)) + + return len(r1), len(lX)-len(r1), len(lY)-len(r1) + + pnum = 1 # multi-page not supported + + for n, (sBasefilename, sFilename_row, sFilename_col) in enumerate(zip(lsRowBFn + , lsRowFn, lsColFn)): + assert sBasefilename in sFilename_row + assert sBasefilename in sFilename_col + dNS = {"pc":PageXml.NS_PAGE_XML} + if bVerbose: traceln("-"*30) + # Rows... + sxpCut = './/pc:CutSeparator[@orient="0" and @DU_type="S"]' #how to find the cuts + doer = SkewedCutAnnotator(bCutAbove) + #traceln(" - Cut selector = ", sxpCut) + + lsetGT_row, llsetRun_row = get_row_partition(doer, sxpCut, dNS + , os.path.join(sRowDir, sFilename_row) + , [_isBaselineInTable] + , bCutAbove=True, bVerbose=False + , funIndex=lambda o: o._dom_id + ) + #traceln("%d rows in GT"%len(lsetGT_row)) + + [lsetRun_row] = llsetRun_row # per page x filter function + if bVerbose: + for fSimil in lfSimil: + _nOk, _nErr, _nMiss, _lFound, _lErr, _lMissed = evalPartitions(lsetRun_row, lsetGT_row, fSimil, jaccard_distance) + _fP, _fR, _fF = computePRF(_nOk, _nErr, _nMiss) + traceln("%4d %8s simil:%.2f P %5.1f R %5.1f F1 %5.1f ok=%6d err=%6d miss=%6d" %( + n+1, "row", fSimil + , _fP, _fR, _fF + , _nOk, _nErr, _nMiss)) + + # Columns... 
+ sxpCut = './/pc:CutSeparator[@orient="90"]' #how to find the cuts + doer = tasks.DU_ABPTableCutAnnotator.CutAnnotator() + #traceln(" - Cut selector = ", sxpCut) + + # load objects: Baseline and Cuts + lsetGT_col, _lsetDataGT_col, llsetRun_col = tasks.DU_ABPTableCutAnnotator.get_col_partition(doer, sxpCut, dNS + , os.path.join(sColDir, sFilename_col) + , [_isBaselineInTable] + , fRatio + , bVerbose=False + , funIndex=lambda x: x._dom_id + ) + lsetGT_col = [set(_o) for _o in lsetGT_col] # make it a list of set + [lsetRun_col] = llsetRun_col # per page x filter function + lsetRun_col = [set(_o) for _o in lsetRun_col] # make it a list of set + if bVerbose: + for fSimil in lfSimil: + _nOk, _nErr, _nMiss, _lFound, _lErr, _lMissed = evalPartitions(lsetRun_col, lsetGT_col, fSimil, jaccard_distance) + _fP, _fR, _fF = computePRF(_nOk, _nErr, _nMiss) + traceln("%4d %8s simil:%.2f P %5.1f R %5.1f F1 %5.1f ok=%6d err=%6d miss=%6d" %( + n+1, "col", fSimil + , _fP, _fR, _fF + , _nOk, _nErr, _nMiss)) + + lsetGT_cell = cross_row_col(lsetGT_row , lsetGT_col) + lsetRun_cell = cross_row_col(lsetRun_row, lsetRun_col) + + #traceln("%d %d"%(len(lsetGT_cell), len(lsetRun_cell))) + #lsetGT_cell = [_s for _s in lsetGT_cell if _s] + lsetRun_cell = [_s for _s in lsetRun_cell if _s] + #traceln("%d %d"%(len(lsetGT_cell), len(lsetRun_cell))) + + + # FIX + lsetGT_cell = list() + dNsSp = {"pc":"http://schema.primaresearch.org/PAGE/gts/pagecontent/2013-07-15"} + _parser = etree.XMLParser() + _doc = etree.parse(os.path.join(sColDir, sFilename_col), _parser) + for ndCell in _doc.getroot().xpath('//pc:TableCell', namespaces=dNsSp): + setCell = set(_nd.get("id") for _nd in ndCell.xpath('.//pc:TextLine', namespaces=dNsSp)) + if setCell: lsetGT_cell.append(setCell) + # traceln("%d non-empty cells in GT" % len(lsetGT_cell)) + + for fSimil in lfSimil: + nOk, nErr, nMiss = dOkErrMiss[fSimil] + _nOk, _nErr, _nMiss = evalHungarian(lsetRun_cell, lsetGT_cell, fSimil) + _fP, _fR, _fF = computePRF(_nOk, _nErr, _nMiss) +# if bVerbose: +# traceln(" - - - simil = %.2f" % fSimil) +# traceln("----- RUN ----- ") +# for s in lsetRun_cell: traceln(" run ", sorted(s)) +# traceln("----- REF ----- ") +# for s in lsetGT_cell: traceln(" ref ", sorted(s)) + nOk += _nOk + nErr += _nErr + nMiss += _nMiss + traceln("%4d %8s simil:%.2f P %5.1f R %5.1f F1 %5.1f ok=%6d err=%6d miss=%6d %s page=%d" %( + n+1, dOkErrMiss['name'], fSimil + , _fP, _fR, _fF + , _nOk, _nErr, _nMiss + , sBasefilename, pnum)) + dOkErrMiss[fSimil] = (nOk, nErr, nMiss) + +# for dOkErrMiss in [dOkErrMissOnlyRow, dOkErrMissTableRow]: + + traceln() + name = dOkErrMiss['name'] + for fSimil in lfSimil: + nOk, nErr, nMiss = dOkErrMiss[fSimil] + fP, fR, fF = computePRF(nOk, nErr, nMiss) + traceln("ALL %8s simil:%.2f P %5.1f R %5.1f F1 %5.1f " % (name, fSimil, fP, fR, fF ) + , " " + ,"ok=%d err=%d miss=%d" %(nOk, nErr, nMiss)) + + return (nOk, nErr, nMiss) + + +# ------------------------------------------------------------------ +if __name__ == "__main__": + + # list the input files + lsFile = [] + for path in sys.argv: + if os.path.isfile(path): + lsFile.append(path) + elif os.path.isdir(path): + lsFilename = [os.path.join(path, "col", s) for s in os.listdir(os.path.join(path, "col")) if s.endswith(".mpxml") ] + if not lsFilename: + lsFilename = [os.path.join(path, "col", s) for s in os.listdir(os.path.join(path, "col")) if s.endswith(".pxml") and s[-7] in "0123456789"] + lsFilename.sort() + traceln(" folder %s --> %d files" % (path, len(lsFilename))) + lsFile.extend(lsFilename) + 
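+    # NOTE: by convention the documents sit in a "col" subfolder; .mpxml files
+    # are preferred, with .pxml page files as a fallback.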
traceln("%d files to read" % len(lsFile)) + + traceln(lsFilename) + traceln("%d files to be processed" % len(lsFilename)) + + # load the separators (estimated by linear regression) + ldRowSep, ldColSep = [], [] + for sFilename in lsFilename: + try: + dRowSep, dColSep = getDocSeparators(sFilename) + except DocSeparatorException as e: + traceln("\t SKIPPING this file: " + str(e)) + ldRowSep.append(dRowSep) + ldColSep.append(dColSep) + + # Now look at the distribution of the separators' angles + # horizontally, and vertically + fig = plt.figure(1) + fig.canvas.set_window_title("Distribution of separator angles (Obtained by linear regression)") + for i, (bHorizontal, degavg, degmax, ldXYXY) in enumerate([ + (True , 0, 5, ldRowSep) + , (False, 90, 5, ldColSep) + ]): + C = collections.Counter() + + for dXYXY in ldXYXY: + lAngle = [] + for (x1,y1),(x2,y2) in dXYXY.values(): + if bHorizontal: + angle = math.degrees(math.atan((y2-y1) / (x2-x1))) + else: + angle = 90 - math.degrees(math.atan((x2-x1) / (y2 - y1))) + angle = round(angle, 1) + lAngle.append(angle) + C.update(lAngle) + + plt.subplot(211+i) + ltV = list(C.items()) + ltV.sort() + traceln(ltV) + lX = [tV[0] for tV in ltV if abs(tV[0]-degavg) <= degmax] + lY = [tV[1] for tV in ltV if abs(tV[0]-degavg) <= degmax] + if len(lX) < len(ltV): + traceln("WARNING: excluded %d bins (%d values in total) outside of [%.1f, %.1f]" + % (len(ltV)-len(lX) + , sum((tV[1] for tV in ltV if abs(tV[0]-degavg) > degmax)) + , degavg-degmax, degavg+degmax)) + #plt.plot(lX, lY) + plt.scatter(lX, lY) + plt.ylabel("Count") + plt.grid(which='both', axis='both') + plt.xticks(lX) + plt.xlabel("Degrees") + plt.show() + diff --git a/TranskribusDU/tasks/DU_Table/DU_Table_Cell_Edge.py b/TranskribusDU/tasks/DU_Table/DU_Table_Cell_Edge.py new file mode 100644 index 0000000..06d25a0 --- /dev/null +++ b/TranskribusDU/tasks/DU_Table/DU_Table_Cell_Edge.py @@ -0,0 +1,211 @@ +# -*- coding: utf-8 -*- + +""" + DU task for segmenting text in cell, or col, or row using the conjugate + graph after the SW re-engineering by JLM during the 2019 summer. + + As of June 5th, 2015, this is the exemplary code + + Copyright NAVER(C) 2019 Jean-Luc Meunier + + + + Developed for the EU project READ. The READ project has received funding + from the European Union's Horizon 2020 research and innovation programme + under grant agreement No 674943. 
+ +""" + +import sys, os + +try: #to ease the use without proper Python installation + import TranskribusDU_version +except ImportError: + sys.path.append( os.path.dirname(os.path.dirname(os.path.dirname( os.path.abspath(sys.argv[0]) ))) ) + import TranskribusDU_version +TranskribusDU_version + +from common.trace import traceln + +from graph.NodeType_PageXml import defaultBBoxDeltaFun +from graph.NodeType_PageXml import NodeType_PageXml_type +from tasks.DU_Task_Factory import DU_Task_Factory +from tasks.DU_Task_Features import Features_June19_Simple +from tasks.DU_Task_Features import Features_June19_Simple_Separator +from tasks.DU_Task_Features import Features_June19_Simple_Shift +from tasks.DU_Task_Features import Features_June19_Simple_Separator_Shift +from tasks.DU_Task_Features import Features_June19_Full +from tasks.DU_Task_Features import Features_June19_Full_Separator +from tasks.DU_Task_Features import Features_June19_Full_Shift +from tasks.DU_Task_Features import Features_June19_Full_Separator_Shift + +from graph.pkg_GraphBinaryConjugateSegmenter.MultiSinglePageXml \ + import MultiSinglePageXml \ + as ConjugateSegmenterGraph_MultiSinglePageXml + +from graph.pkg_GraphBinaryConjugateSegmenter.MultiSinglePageXml_Separator \ + import MultiSinglePageXml_Separator \ + as ConjugateSegmenterGraph_MultiSinglePageXml_Separator + + +# ---------------------------------------------------------------------------- +# class My_ConjugateNodeType(NodeType_PageXml_type_woText): +class My_ConjugateNodeType(NodeType_PageXml_type): + """ + We need this to extract properly the label from the label attribute of the (parent) TableCell element. + """ + def __init__(self, sNodeTypeName, lsLabel, lsIgnoredLabel=None, bOther=True, BBoxDeltaFun=defaultBBoxDeltaFun): + super(My_ConjugateNodeType, self).__init__(sNodeTypeName, lsLabel, lsIgnoredLabel, bOther, BBoxDeltaFun) + + def parseDocNodeLabel(self, graph_node, defaultCls=None): + """ + Parse and set the graph node label and return its class index + We rely on the standard self.sLabelAttr + raise a ValueError if the label is missing while bOther was not True + , or if the label is neither a valid one nor an ignored one + """ + domnode = graph_node.node + ndParent = domnode.getparent() + sLabel = "%s__%s" % ( ndParent.getparent().get("id") # TABLE ID ! + , ndParent.get(self.sLabelAttr) # e.g. "row" or "col" + ) + + return sLabel + + def setDocNodeLabel(self, graph_node, sLabel): + raise Exception("This should not occur in conjugate mode") + + +class My_ConjugateNodeType_Cell(My_ConjugateNodeType): + """ + For cells, the label is formed by the row __and__ col numberss + """ + def __init__(self, sNodeTypeName, lsLabel, lsIgnoredLabel=None, bOther=True, BBoxDeltaFun=defaultBBoxDeltaFun): + super(My_ConjugateNodeType_Cell, self).__init__(sNodeTypeName, lsLabel, lsIgnoredLabel, bOther, BBoxDeltaFun) + + def parseDocNodeLabel(self, graph_node, defaultCls=None): + """ + Parse and set the graph node label and return its class index + raise a ValueError if the label is missing while bOther was not True, or if the label is neither a valid one nor an ignored one + """ + domnode = graph_node.node + ndParent = domnode.getparent() + sLabel = "%s__%s__%s" % ( ndParent.getparent().get("id") # TABLE ID ! 
+ , ndParent.get("row") + , ndParent.get("col") + ) + return sLabel + + +# ---------------------------------------------------------------------------- +def main(sys_argv_0, sLabelAttribute, cNodeType=My_ConjugateNodeType): + + + def getConfiguredGraphClass(_doer): + """ + In this class method, we must return a configured graph class + """ + # each graph reflects 1 page + if options.bSeparator: + DU_GRAPH = ConjugateSegmenterGraph_MultiSinglePageXml_Separator + else: + DU_GRAPH = ConjugateSegmenterGraph_MultiSinglePageXml + + ntClass = cNodeType + + nt = ntClass(sLabelAttribute #some short prefix because labels below are prefixed with it + , [] # in conjugate, we accept all labels, andNone becomes "none" + , [] + , False # unused + , BBoxDeltaFun=lambda v: max(v * 0.066, min(5, v/3)) #we reduce overlap in this way + ) + nt.setLabelAttribute(sLabelAttribute) + nt.setXpathExpr( (".//pc:TextLine" #how to find the nodes + #, "./pc:TextEquiv") #how to get their text + , ".//pc:Unicode") #how to get their text + ) + DU_GRAPH.addNodeType(nt) + + return DU_GRAPH + + # standard command line options for CRF- ECN- GAT-based methods + usage, parser = DU_Task_Factory.getStandardOptionsParser(sys_argv_0) + parser.add_option("--separator", dest='bSeparator', action="store_true" + , default=False, help="Use the graphical spearators, if any, as edge features.") + parser.add_option("--text" , dest='bText' , action="store_true" + , default=False, help="Use textual information if any, as node and edge features.") + parser.add_option("--edge_vh", "--edge_hv" , dest='bShift' , action="store_true" + , default=False, help="Shift edge feature by range depending on edge type.") + parser.add_option("--jsonocr", dest='bJsonOcr', action="store_true" + , help="I/O is in json") + traceln("VERSION: %s" % DU_Task_Factory.getVersion()) + + # --- + #parse the command line + (options, args) = parser.parse_args() + + try: + sModelDir, sModelName = args + except Exception as e: + traceln("Specify a model folder and a model name!") + DU_Task_Factory.exit(usage, 1, e) + if options.bText : traceln(" - using textual data, if any") + if options.bSeparator: traceln(" - using graphical separators, if any") + if options.bShift : traceln(" - shift edge features by edge type") + + if options.bText: + if options.bSeparator: + if options.bShift: + cFeatureDefinition = Features_June19_Full_Separator_Shift + else: + cFeatureDefinition = Features_June19_Full_Separator + else: + if options.bShift: + cFeatureDefinition = Features_June19_Full_Shift + else: + cFeatureDefinition = Features_June19_Full + else: + if options.bSeparator: + if options.bShift: + cFeatureDefinition = Features_June19_Simple_Separator_Shift + else: + cFeatureDefinition = Features_June19_Simple_Separator + else: + if options.bShift: + cFeatureDefinition = Features_June19_Simple_Shift + else: + cFeatureDefinition = Features_June19_Simple + + # === SETTING the graph type (and its node type) a,d the feature extraction pipe + doer = DU_Task_Factory.getDoer(sModelDir, sModelName + , options = options + , fun_getConfiguredGraphClass= getConfiguredGraphClass + , cFeatureDefinition = cFeatureDefinition + ) + + # == LEARNER CONFIGURATION === + # setting the learner configuration, in a standard way + # (from command line options, or from a JSON configuration file) + dLearnerConfig = doer.getStandardLearnerConfig(options) + + +# # force a balanced weighting +# print("Forcing balanced weights") +# dLearnerConfig['balanced'] = True + + # of course, you can put yours here instead. 
+ doer.setLearnerConfiguration(dLearnerConfig) + + # === GO!! === + # act as per specified in the command line (--trn , --fold-run, ...) + doer.standardDo(options) + + del doer + + +# ---------------------------------------------------------------------------- +if __name__ == "__main__": + # import better_exceptions + # better_exceptions.MAX_LENGTH = None + + main(sys.argv[0], "cell", My_ConjugateNodeType_Cell) diff --git a/TranskribusDU/tasks/DU_Table/DU_Table_Col_Cut.py b/TranskribusDU/tasks/DU_Table/DU_Table_Col_Cut.py new file mode 100644 index 0000000..00a54fb --- /dev/null +++ b/TranskribusDU/tasks/DU_Table/DU_Table_Col_Cut.py @@ -0,0 +1,174 @@ +# -*- coding: utf-8 -*- + +""" + Create column clusters by projection profile + + Copyright Naver Labs Europe(C) 2019 JL Meunier +""" +import sys, os +from optparse import OptionParser + +from lxml import etree + +try: #to ease the use without proper Python installation + import TranskribusDU_version +except ImportError: + sys.path.append( os.path.dirname(os.path.dirname(os.path.dirname( os.path.abspath(sys.argv[0]) ))) ) + import TranskribusDU_version +TranskribusDU_version + +from common.trace import traceln + +from util.Polygon import Polygon +from xml_formats.PageXml import MultiPageXml, PageXml +from graph.pkg_GraphBinaryConjugateSegmenter.GraphBinaryConjugateSegmenter import GraphBinaryConjugateSegmenter +from tasks.DU_Table.DU_ABPTableCutAnnotator import CutAnnotator + + +def main(lsFilename + , fRatio, fMinHLen + , fMinHorizProjection + , fMinVertiProjection=0.05 + ): + + for sFilename in lsFilename: + iDot = sFilename.rindex('.') + if sFilename[:iDot].endswith("_du"): continue + if True: + sOutFilename = sFilename[:iDot] + "_du.mpxml" # + "_du" + sFilename[iDot:] + else: + # to mimic the bug of DU_Task.predict until today... + sOutFilename = sFilename[:iDot-1] + "_du.mpxml" + traceln("- cutting: %s --> %s"%(sFilename, sOutFilename)) + + #for the pretty printer to format better... + parser = etree.XMLParser(remove_blank_text=True) + doc = etree.parse(sFilename, parser) + root=doc.getroot() + + doer = CutAnnotator() + + # # Some grid line will be O or I simply because they are too short. 
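+        # (the commented-out GT-mapping lines below are kept for reference;
+        # only the projection-profile cuts computed by add_cut_to_DOM are
+        # used, per page, to build the column clusters from llX)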
+ # fMinPageCoverage = 0.5 # minimum proportion of the page crossed by a grid line + # # we want to ignore col- and row- spans + #map the groundtruth table separators to our grid, per page (1 in tABP) + # ltlYlX = doer.get_separator_YX_from_DOM(root, fMinPageCoverage) + + # clean any previous cuts: + doer.remove_cuts_from_dom(root) + + # Find cuts and map them to GT + llY, llX = doer.add_cut_to_DOM(root + #, ltlYlX=ltlYlX + , fMinHorizProjection=fMinHorizProjection + , fMinVertiProjection=fMinVertiProjection + , fRatio=fRatio + , fMinHLen=fMinHLen) + + add_cluster_to_dom(root, llX) + + doc.write(sOutFilename, encoding='utf-8', pretty_print=True, xml_declaration=True) + traceln('Clusters and cut separators added into %s'%sOutFilename) + + del doc + + +def add_cluster_to_dom(root, llX): + """ + Cluster the Textline based on the vertical cuts + """ + + for lX, (_iPage, ndPage) in zip(llX, enumerate(MultiPageXml.getChildByName(root, 'Page'))): + w, _h = int(ndPage.get("imageWidth")), int(ndPage.get("imageHeight")) + + lX.append(w) + lX.sort() + # cluster of objects on + imax = len(lX) + dCluster = { i:list() for i in range(imax) } + + #Histogram of projections + lndTextline = MultiPageXml.getChildByName(ndPage, 'TextLine') + + # hack to use addClusterToDom + class MyBlock: + def __init__(self, nd): + self.node = nd + + o = GraphBinaryConjugateSegmenter() + o.lNode = [] + for nd in lndTextline: + o.lNode.append(MyBlock(nd)) + + for iNd, ndTextline in enumerate(lndTextline): + sPoints=MultiPageXml.getChildByName(ndTextline,'Coords')[0].get('points') + try: + x1,_y1,x2,_y2 = Polygon.parsePoints(sPoints).fitRectangle() + xm = (x1 + x2) / 2.0 + bLastColumn = True + for i, xi in enumerate(lX): + if xm <= xi: + dCluster[i].append(iNd) + ndTextline.set("DU_cluster", str(i)) + bLastColumn = False + break + if bLastColumn: + i = imax + dCluster[i].append(iNd) + ndTextline.set("DU_cluster", str(i)) + except ZeroDivisionError: + pass + except ValueError: + pass + + # add clusters + lNdCluster = o.addClusterToDom(dCluster, bMoveContent=False, sAlgo="cut", pageNode=ndPage) + + # add a cut_X attribute to the clusters + for ndCluster in lNdCluster: + i = int(ndCluster.get('name')) + ndCluster.set("cut_X", str(lX[i])) + + +# ---------------------------------------------------------------------------- +if __name__ == "__main__": + usage = """+| +Generate _du.mpxml files. 
+""" + version = "v.01" + parser = OptionParser(usage=usage, version="0.1") + parser.add_option("--ratio", dest='fRatio', action="store" + , type=float + , help="Apply this ratio to the bounding box" + , default=0.66) + parser.add_option("--fMinHLen", dest='fMinHLen', action="store" + , type=float + , help="Do not scale horizontally a bounding box with width lower than this" + , default=75) + + parser.add_option("--fHorizRatio", dest='fMinHorizProjection', action="store" + , type=float + , help="On the horizontal projection profile, it ignores profile lower than this ratio of the page width" + , default=0.05) + + # --- + #parse the command line + (options, args) = parser.parse_args() + + traceln(options) + + if args and all(map(os.path.isfile, args)): + lsFile = args + traceln("Working on files: ", lsFile) + elif args and os.path.isdir(args[0]): + sDir = args[0] + traceln("Working on folder: ", sDir) + lsFile = [os.path.join(sDir, s) for s in os.listdir(sDir) if s.lower().endswith("pxml")] + else: + traceln("Usage : %s " % sys.argv[0], usage) + sys.exit(1) + main(lsFile + , options.fRatio, fMinHLen=options.fMinHLen + , fMinHorizProjection=options.fMinHorizProjection + ) + diff --git a/TranskribusDU/tasks/DU_Table/DU_Table_Col_Edge.py b/TranskribusDU/tasks/DU_Table/DU_Table_Col_Edge.py new file mode 100644 index 0000000..79ec588 --- /dev/null +++ b/TranskribusDU/tasks/DU_Table/DU_Table_Col_Edge.py @@ -0,0 +1,34 @@ +# -*- coding: utf-8 -*- + +""" + DU task for segmenting text in cols using the conjugate graph after the SW + re-engineering by JLM + + Copyright NAVER(C) 2019 Jean-Luc Meunier + + + + Developed for the EU project READ. The READ project has received funding + from the European Union's Horizon 2020 research and innovation programme + under grant agreement No 674943. + +""" + +import sys, os + +try: #to ease the use without proper Python installation + import TranskribusDU_version +except ImportError: + sys.path.append( os.path.dirname(os.path.dirname(os.path.dirname( os.path.abspath(sys.argv[0]) ))) ) + import TranskribusDU_version +TranskribusDU_version + +from tasks.DU_Table.DU_Table_Cell_Edge import main + + +if __name__ == "__main__": + # import better_exceptions + # better_exceptions.MAX_LENGTH = None + + main(sys.argv[0], "col") + diff --git a/TranskribusDU/tasks/DU_Table/DU_Table_ECN.py b/TranskribusDU/tasks/DU_Table/DU_Table_ECN.py new file mode 100644 index 0000000..93b7310 --- /dev/null +++ b/TranskribusDU/tasks/DU_Table/DU_Table_ECN.py @@ -0,0 +1,176 @@ +# -*- coding: utf-8 -*- + +""" + DU task for Table based on ECN + + Copyright NAVER(C) 2018, 2019 Hervé Déjean, Jean-Luc Meunier, Animesh Prasad + + + + + Developed for the EU project READ. The READ project has received funding + from the European Union's Horizon 2020 research and innovation programme + under grant agreement No 674943. 
+ +""" + + + + +import sys, os + +try: #to ease the use without proper Python installation + import TranskribusDU_version +except ImportError: + sys.path.append( os.path.dirname(os.path.dirname( os.path.abspath(sys.argv[0]) )) ) + import TranskribusDU_version + +from common.trace import traceln + +from crf.Graph_MultiPageXml import Graph_MultiPageXml +from crf.Graph_Multi_SinglePageXml import Graph_MultiSinglePageXml +from crf.NodeType_PageXml import NodeType_PageXml_type_woText, NodeType_PageXml_type +from crf.FeatureDefinition_PageXml_std import FeatureDefinition_PageXml_StandardOnes +import gcn.DU_Model_ECN +from tasks.DU_ECN_Task import DU_ECN_Task + + +from crf.FeatureDefinition_PageXml_std_noText import FeatureDefinition_PageXml_StandardOnes_noText + +class DU_Table_ECN(DU_ECN_Task): + """ + ECN Models + """ + bHTR = False # do we have text from an HTR? + bPerPage = False # do we work per document or per page? + #bTextLine = False # if False then act as TextRegion + + sMetadata_Creator = "NLE Document Understanding ECN" + sXmlFilenamePattern = "*.mpxml" + + # sLabeledXmlFilenamePattern = "*.a_mpxml" + sLabeledXmlFilenamePattern = "*.mpxml" #"*mpxml" + + + sLabeledXmlFilenameEXT = ".mpxml" + + dLearnerConfig = None + + #dLearnerConfig = {'nb_iter': 50, + # 'lr': 0.001, + # 'num_layers': 3, + # 'nconv_edge': 10, + # 'stack_convolutions': True, + # 'node_indim': -1, + # 'mu': 0.0, + # 'dropout_rate_edge': 0.0, + # 'dropout_rate_edge_feat': 0.0, + # 'dropout_rate_node': 0.0, + # 'ratio_train_val': 0.15, + # #'activation': tf.nn.tanh, Problem I can not serialize function HERE + # } + # === CONFIGURATION ==================================================================== + @classmethod + def getConfiguredGraphClass(cls): + """ + In this class method, we must return a configured graph class + """ + #lLabels = ['heading', 'header', 'page-number', 'resolution-number', 'resolution-marginalia', 'resolution-paragraph', 'other'] + + #lLabels = ['IGNORE', '577', '579', '581', '608', '32', '3431', '617', '3462', '3484', '615', '49', '3425', '73', '3', '3450', '2', '11', '70', '3451', '637', '77', '3447', '3476', '3467', '3494', '3493', '3461', '3434', '48', '3456', '35', '3482', '74', '3488', '3430', '17', '613', '625', '3427', '3498', '29', '3483', '3490', '362', '638a', '57', '616', '3492', '10', '630', '24', '3455', '3435', '8', '15', '3499', '27', '3478', '638b', '22', '3469', '3433', '3496', '624', '59', '622', '75', '640', '1', '19', '642', '16', '25', '3445', '3463', '3443', '3439', '3436', '3479', '71', '3473', '28', '39', '361', '65', '3497', '578', '72', '634', '3446', '627', '43', '62', '34', '620', '76', '23', '68', '631', '54', '3500', '3480', '37', '3440', '619', '44', '3466', '30', '3487', '45', '61', '3452', '3491', '623', '633', '53', '66', '67', '69', '643', '58', '632', '636', '7', '641', '51', '3489', '3471', '21', '36', '3468', '4', '576', '46', '63', '3457', '56', '3448', '3441', '618', '52', '3429', '3438', '610', '26', '609', '3444', '612', '3485', '3465', '41', '20', '3464', '3477', '3459', '621', '3432', '60', '3449', '626', '628', '614', '47', '3454', '38', '3428', '33', '12', '3426', '3442', '3472', '13', '639', '3470', '611', '6', '40', '14', '3486', '31', '3458', '3437', '3453', '55', '3424', '3481', '635', '64', '629', '3460', '50', '9', '18', '42', '3495', '5', '580'] + #lLabels = [ str(i) for i in range(0,5000)] + lLabels = [ str("%d_%d"%(t,i)) for t in range(0,10) for i in range(0,5000)] + lLabels.append('IGNORE') + # traceln (lLabels) + + lIgnoredLabels = None + 
+ """ + if you play with a toy collection, which does not have all expected classes, you can reduce those. + """ + if cls.bPerPage: + DU_GRAPH = Graph_MultiSinglePageXml # consider each age as if indep from each other + else: + DU_GRAPH = Graph_MultiPageXml + + lActuallySeen = None + if lActuallySeen: + traceln("REDUCING THE CLASSES TO THOSE SEEN IN TRAINING") + lIgnoredLabels = [lLabels[i] for i in range(len(lLabels)) if i not in lActuallySeen] + lLabels = [lLabels[i] for i in lActuallySeen] + traceln(len(lLabels), lLabels) + traceln(len(lIgnoredLabels), lIgnoredLabels) + + if cls.bHTR: + ntClass = NodeType_PageXml_type + else: + #ignore text + ntClass = NodeType_PageXml_type_woText + + # DEFINING THE CLASS OF GRAPH WE USE + nt = ntClass("cell" # some short prefix because labels below are prefixed with it + , lLabels + , lIgnoredLabels + , False # no label means OTHER + , BBoxDeltaFun=lambda v: max(v * 0.066, min(5, v / 3)) + # we reduce overlap in this way + ) + + nt.setLabelAttribute("cell") + nt.setXpathExpr( (".//pc:TextLine" #how to find the nodes + #nt.setXpathExpr( (".//pc:TableCell//pc:TextLine" #how to find the nodes + , "./pc:TextEquiv") + ) + + DU_GRAPH.addNodeType(nt) + + return DU_GRAPH + + def __init__(self, sModelName, sModelDir, sComment=None,dLearnerConfigArg=None): + traceln ( self.bHTR) + + if self.bHTR: + cFeatureDefinition = FeatureDefinition_PageXml_StandardOnes + dFeatureConfig = { 'bMultiPage':False, 'bMirrorPage':False + , 'n_tfidf_node':300, 't_ngrams_node':(2,4), 'b_tfidf_node_lc':False + , 'n_tfidf_edge':300, 't_ngrams_edge':(2,4), 'b_tfidf_edge_lc':False } + else: + cFeatureDefinition = FeatureDefinition_PageXml_StandardOnes_noText + # cFeatureDefinition = FeatureDefinition_PageXml_NoNodeFeat_v3 + dFeatureConfig = {} + + if sComment is None: sComment = sModelName + + if dLearnerConfigArg is not None and "ecn_ensemble" in dLearnerConfigArg: + traceln('ECN_ENSEMBLE') + DU_ECN_Task.__init__(self + , sModelName, sModelDir + , dFeatureConfig=dFeatureConfig + , dLearnerConfig=self.dLearnerConfig if dLearnerConfigArg is None else dLearnerConfigArg + , sComment=sComment + , cFeatureDefinition= cFeatureDefinition + , cModelClass=gcn.DU_Model_ECN.DU_Ensemble_ECN + ) + else: + #Default Case Single Model + DU_ECN_Task.__init__(self + , sModelName, sModelDir + , dFeatureConfig=dFeatureConfig + , dLearnerConfig=self.dLearnerConfig if dLearnerConfigArg is None else dLearnerConfigArg + , sComment= sComment + , cFeatureDefinition=cFeatureDefinition + ) + + #if options.bBaseline: + # self.bsln_mdl = self.addBaseline_LogisticRegression() # use a LR model trained by GridSearch as baseline + + # === END OF CONFIGURATION ============================================================= + def predict(self, lsColDir): + """ + Return the list of produced files + """ + self.sXmlFilenamePattern = "*.mpxml" + return DU_ECN_Task.predict(self, lsColDir) + + + diff --git a/TranskribusDU/tasks/DU_Table/DU_Table_Evaluator.py b/TranskribusDU/tasks/DU_Table/DU_Table_Evaluator.py new file mode 100644 index 0000000..952f4b6 --- /dev/null +++ b/TranskribusDU/tasks/DU_Table/DU_Table_Evaluator.py @@ -0,0 +1,704 @@ +# -*- coding: utf-8 -*- + +""" + Find cuts of a page along different slopes + and annotate them based on the table row content (which defines a partition) + + Copyright Naver Labs Europe 2018 + JL Meunier + + + + + Developed for the EU project READ. 
The READ project has received funding
+    from the European Union's Horizon 2020 research and innovation programme
+    under grant agreement No 674943.
+
+"""
+import sys, os
+from optparse import OptionParser
+from lxml import etree
+from collections import defaultdict
+
+import numpy as np
+
+try:  # to ease the use without a proper Python installation
+    import TranskribusDU_version
+except ImportError:
+    sys.path.append( os.path.dirname(os.path.dirname( os.path.abspath(sys.argv[0]) )) )
+    sys.path.append( os.path.dirname(os.path.dirname(os.path.dirname( os.path.abspath(sys.argv[0]) )) ))
+    import TranskribusDU_version
+
+from common.trace import traceln
+from xml_formats.PageXml import PageXml
+from tasks.DU_Table.DU_ABPTableSkewed_CutAnnotator import op_cut, op_eval_row, op_gt_recall
+from tasks.DU_Table.DU_ABPTableCutAnnotator import CutAnnotator, op_eval_col
+from tasks.DU_Table.DU_Table_CellBorder import op_eval_cell
+from util.partitionEvaluation import evalPartitions
+from util.jaccard import jaccard_distance
+
+from util.hungarian import evalHungarian
+if True:
+    # We fix a bug in this way, to keep old code ready to use
+    evalPartitions = evalHungarian
+
+from graph.NodeType_PageXml import defaultBBoxDeltaFun
+from graph.NodeType_PageXml import NodeType_PageXml_type_woText
+
+from graph.pkg_GraphBinaryConjugateSegmenter.MultiSinglePageXml \
+    import MultiSinglePageXml \
+    as ConjugateSegmenterGraph_MultiSinglePageXml
+
+from graph.pkg_GraphBinaryConjugateSegmenter.MultiSinglePageXml_Separator \
+    import MultiSinglePageXml_Separator \
+    as ConjugateSegmenterGraph_MultiSinglePageXml_Separator
+
+def listFiles(sDir, ext="_du.mpxml"):
+    """
+    return the sorted list of files, in sDir, whose name ends with ext
+    """
+    lsFile = sorted([_fn
+                     for _fn in os.listdir(sDir)
+                     if _fn.lower().endswith(ext)
+                     ])
+    return lsFile
+
+
+def listParallelFiles(lsDocDir):
+    """
+    return one list of files per folder, as a tuple
+
+    Make sure the filenames correspond in each folder
+    """
+    llsFile = [listFiles(sDir) for sDir in lsDocDir]
+
+    if len(lsDocDir) > 1:
+        # correspondence tests
+        lset = [set(l) for l in llsFile]
+        setInter = lset[0].intersection(*lset[1:])
+        setUnion = lset[0].union(*lset[1:])
+        if setInter != setUnion:
+            for setFile, sDir in zip(lset, lsDocDir):
+                setExtra = setFile.difference(setInter)
+                if len(setExtra) > 0:
+                    traceln("\t %s has %d extra files: %s" % (sDir, len(setExtra), sorted(list(setExtra))))
+                else:
+                    traceln("\t %s is OK" % sDir)
+            raise Exception("Folders contain different filenames")
+
+    return llsFile
+
+
+def computePRF(nOk, nErr, nMiss):
+    eps = 0.00001
+    fP = 100 * nOk / (nOk + nErr  + eps)
+    fR = 100 * nOk / (nOk + nMiss + eps)
+    fF = 2 * fP * fR / (fP + fR + eps)
+    return fP, fR, fF
+
+
+class My_ConjugateNodeType(NodeType_PageXml_type_woText):
+# class My_ConjugateNodeType(NodeType_PageXml_type):
+    """
+    We need this to extract properly the label from the label attribute of the (parent) TableCell element.
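+    In conjugate mode, the label value only needs to group the text lines,
+    so it is built from the id of the enclosing table plus the attribute of
+    interest.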
+ """ + def __init__(self, sNodeTypeName, lsLabel, lsIgnoredLabel=None, bOther=True, BBoxDeltaFun=defaultBBoxDeltaFun): + super(My_ConjugateNodeType, self).__init__(sNodeTypeName, lsLabel, lsIgnoredLabel, bOther, BBoxDeltaFun) + + def parseDocNodeLabel(self, graph_node, defaultCls=None): + """ + Parse and set the graph node label and return its class index + raise a ValueError if the label is missing while bOther was not True, or if the label is neither a valid one nor an ignored one + """ + domnode = graph_node.node + ndParent = domnode.getparent() + sLabel = "%s__%s" % ( ndParent.getparent().get("id") # TABLE ID ! + , self.sLabelAttr # e.g. "row" or "col" + ) + return sLabel + + def setDocNodeLabel(self, graph_node, sLabel): + raise Exception("This shoud not occur in conjugate mode") + + +def getConfiguredGraphClass(bSeparator=False): + """ + In this class method, we must return a configured graph class + """ + # each graph reflects 1 page + if bSeparator: + DU_GRAPH = ConjugateSegmenterGraph_MultiSinglePageXml_Separator + else: + DU_GRAPH = ConjugateSegmenterGraph_MultiSinglePageXml + + # ntClass = NodeType_PageXml_type + ntClass = My_ConjugateNodeType + + nt = ntClass("row" #some short prefix because labels below are prefixed with it + , [] # in conjugate, we accept all labels, andNone becomes "none" + , [] + , False # unused + , BBoxDeltaFun=lambda v: max(v * 0.066, min(5, v/3)) #we reduce overlap in this way +# , BBoxDeltaFun= None + + ) + nt.setLabelAttribute("idontcare") + nt.setXpathExpr( (".//pc:TextLine" #how to find the nodes + , ".//pc:Unicode") #how to get their text + ) + DU_GRAPH.addNodeType(nt) + + return DU_GRAPH + +def labelEdges(g,sClusterLevel): + """ + g: DU_graph + label g edges as 0 (continue) or 1 (break) according to GT structure + + """ + Y = np.zeros((len(g.lEdge),2)) + for i,edge in enumerate(g.lEdge): + a, b = edge.A.node, edge.B.node + if sClusterLevel == "cell": + if (a.getparent().get("row"), a.getparent().get("col")) == (b.getparent().get("row"), b.getparent().get("col")): + a.set('DU_cluster', "%s__%s" % (a.get("row"), a.get("col"))) + b.set('DU_cluster', "%s__%s" % (a.get("row"), a.get("col"))) + Y[i][0]=1 + else: Y[i][1]=1 + elif sClusterLevel == "col": + if a.get("col") == b.get("col"): + a.set('DU_cluster', "%s" % (a.get("col"))) + b.set('DU_cluster', "%s" % (a.get("col"))) + Y[i][0]=1 + else: Y[i][1]=1 + elif sClusterLevel == "row": + if a.getparent().get("row") == b.getparent().get("row") and a.getparent().get("row") is not None: + tablea = a.getparent().getparent().get('id') + tableb = b.getparent().getparent().get('id') + a.set('DU_cluster', "%s_%s" % (a.getparent().get("row"),"_" + tablea)) + b.set('DU_cluster', "%s_%s" % (a.getparent().get("row"),"_" + tableb)) + if tablea == tableb: + Y[i][0]=1 + else: + Y[i][1]=1 + else: Y[i][1]=1 + else: + raise Exception("Unknown clustering level: %s"%sClusterLevel) + + return Y + +def eval_oracle(lsRunDir, sClusterLevel + , bIgnoreHeader=True + , bIgnoreOutOfTable=True + , lfSimil=[i / 100.0 for i in [66, 80, 100]] + , xpSelector=".//pc:TextLine"): + """ + evaluate the cluster quality from a run folder + + We assume to have the groundtruth row and col in the files as well as the predicted clusters + """ + assert lsRunDir + dOkErrMiss = { fSimil:(0,0,0) for fSimil in lfSimil } + + DU_GraphClass = getConfiguredGraphClass() + + for sRunDir in lsRunDir: + lsFile = listFiles(sRunDir,ext='.pxml') + traceln("-loaded %d files from %s" % (len(lsFile), sRunDir)) + + for sFilename in lsFile: + + # + lg = 
DU_GraphClass.loadGraphs(DU_GraphClass, [os.path.join(sRunDir, sFilename)], bDetach=False, bLabelled=False, iVerbose=1) + + # cluster -> [node_id] + dGT = defaultdict(list) + dRun = defaultdict(list) + +# doc = etree.parse(os.path.join(sRunDir, sFilename)) + # assume 1 page per doc! + g=lg[0] + rootNd = g.doc.getroot() + #assert len(PageXml.xpath(rootNd, "//pc:Page")) == 1, "NOT YET IMPLEMENTED: eval on multi-page files" + for iPage, ndPage in enumerate(PageXml.xpath(rootNd, "//pc:Page")): + traceln("PAGE %5d OF FILE %s" % (iPage+1, sFilename)) + + try:g = lg[iPage] + except IndexError:continue + Y = labelEdges(g,sClusterLevel) + g.form_cluster(Y) + g.addEdgeToDoc(Y) + + for nd in PageXml.xpath(ndPage, xpSelector): + if bIgnoreHeader and nd.getparent().get("custom") and "table-header" in nd.getparent().get("custom"): continue +# if bIgnoreHeader and nd.get("DU_header") != "D": continue + + ndparent = nd.getparent() + ndid = nd.get("id") + + if sClusterLevel == "cell": + val_gt = "%s__%s" % (ndparent.get("row"), ndparent.get("col")) + if val_gt == 'None__None' and bIgnoreOutOfTable: continue + elif sClusterLevel == "col": + val_gt = ndparent.get("col") + if val_gt == None and bIgnoreOutOfTable: continue + elif sClusterLevel == "row": + val_gt = ndparent.get("row") + if val_gt == None and bIgnoreOutOfTable: continue + else: + raise Exception("Unknown clustering level: %s"%sClusterLevel) + + # distinguish each table! + val_gt = val_gt + "_" + ndparent.getparent().get("id") + + dGT[val_gt].append(ndid) + + val_run = nd.get("DU_cluster") + dRun[val_run].append(ndid) +# assert ndparent.tag.endswith("TableCell"), "expected TableCell got %s" % nd.getparent().tag + + for fSimil in lfSimil: + _nOk, _nErr, _nMiss, _lFound, _lErr, _lMissed = evalPartitions( + list(dRun.values()) + , list(dGT.values()) + , fSimil + , jaccard_distance) + + _fP, _fR, _fF = computePRF(_nOk, _nErr, _nMiss) + + #traceln("simil:%.2f P %5.1f R %5.1f F1 %5.1f ok=%6d err=%6d miss=%6d" %( + traceln("@simil %.2f P %5.1f R %5.1f F1 %5.1f ok=%6d err=%6d miss=%6d" %( + fSimil + , _fP, _fR, _fF + , _nOk, _nErr, _nMiss + )) + # , os.path.basename(sFilename))) + # sFilename = "" # ;-) + + # global count + nOk, nErr, nMiss = dOkErrMiss[fSimil] + nOk += _nOk + nErr += _nErr + nMiss += _nMiss + dOkErrMiss[fSimil] = (nOk, nErr, nMiss) + + traceln() + g.doc.write(os.path.join(sRunDir, sFilename)+'.oracle') + + for fSimil in lfSimil: + nOk, nErr, nMiss = dOkErrMiss[fSimil] + fP, fR, fF = computePRF(nOk, nErr, nMiss) + traceln("ALL_TABLES @simil %.2f P %5.1f R %5.1f F1 %5.1f " % (fSimil, fP, fR, fF ) + , " " + ,"ok=%d err=%d miss=%d" %(nOk, nErr, nMiss)) + + return (nOk, nErr, nMiss) + +def eval_direct(lCriteria, lsDocDir + , bIgnoreHeader=False + , bIgnoreOutOfTable=True + , lfSimil=[i / 100.0 for i in [66, 80, 100]] + , xpSelector=".//pc:TextLine"): + """ + use the row, col, DU_row, DU_col XML attributes to form the partitions + + lCriteria is a list containg "row" or "col" or both + """ + assert lsDocDir + + llsFile = listParallelFiles(lsDocDir) + traceln("-loaded %d files for each criteria"%len(llsFile[0])) + + dOkErrMiss = { fSimil:(0,0,0) for fSimil in lfSimil } + + def _reverseDictionary(d): + rd = defaultdict(list) + for k, v in d.items(): + rd[v].append(k) + return rd + + for i, lsCritFile in enumerate(zip(*llsFile)): + assert len(lCriteria) == len(lsCritFile) + + # node_id -> consolidated_criteria_values + dIdValue = defaultdict(str) + dIdValue_GT = defaultdict(str) + for crit, sFilename, sDir in zip(lCriteria, lsCritFile, 
lsDocDir): + doc = etree.parse(os.path.join(sDir, sFilename)) + rootNd = doc.getroot() + assert len(PageXml.xpath(rootNd, "//pc:Page")) == 1, "NOT YET IMPLEMENTED: eval on multi-page files" + + for nd in PageXml.xpath(rootNd, xpSelector): + ndid = nd.get("id") + val_gt = nd.getparent().get(crit) + if val_gt is None: + if bIgnoreOutOfTable: + continue + else: + val_gt = "-1" + if bIgnoreHeader and nd.get("DU_header") != "D": continue + assert nd.getparent().tag.endswith("TableCell"), "expected TableCell got %s" % nd.getparent().tag + val = nd.get("DU_"+crit) +# import random +# if random.random() < 0.10: +# val = nd.get("DU_"+crit) +# else: +# val = nd.getparent().get(crit) + dIdValue[ndid] += "_%s_" % val + dIdValue_GT[ndid] += "_%s_" % val_gt +# print("**run ", str(dIdValue)) +# print("**GT ", str(dIdValue_GT)) + + # reverse dicitonaries + dValue_lId = _reverseDictionary(dIdValue) + dValue_lId_GT = _reverseDictionary(dIdValue_GT) + +# print("run ", list(dValue_lId.values())) +# print("GT ", list(dValue_lId_GT.values())) + for fSimil in lfSimil: + _nOk, _nErr, _nMiss = evalPartitions( + list(dValue_lId.values()) + , list(dValue_lId_GT.values()) + , fSimil + , jaccard_distance) + + _fP, _fR, _fF = computePRF(_nOk, _nErr, _nMiss) + + traceln("simil:%.2f P %5.1f R %5.1f F1 %5.1f ok=%6d err=%6d miss=%6d %s" %( + fSimil + , _fP, _fR, _fF + , _nOk, _nErr, _nMiss + , os.path.basename(sFilename))) + sFilename = "" # ;-) + nOk, nErr, nMiss = dOkErrMiss[fSimil] + nOk += _nOk + nErr += _nErr + nMiss += _nMiss + dOkErrMiss[fSimil] = (nOk, nErr, nMiss) + traceln() + + for fSimil in lfSimil: + nOk, nErr, nMiss = dOkErrMiss[fSimil] + fP, fR, fF = computePRF(nOk, nErr, nMiss) + traceln("ALL simil:%.2f P %5.1f R %5.1f F1 %5.1f " % (fSimil, fP, fR, fF ) + , " " + ,"ok=%d err=%d miss=%d" %(nOk, nErr, nMiss)) + + return (nOk, nErr, nMiss) + + +def eval_cluster(lsRunDir, sClusterLevel + , bIgnoreHeader=False + , bIgnoreOutOfTable=True + , lfSimil=[i / 100.0 for i in [66, 80, 100]] + , xpSelector=".//pc:TextLine" + , sAlgo=None + , sGroupByAttr=""): + """ + evaluate the cluster quality from a run folder + + We assume to have the groundtruth row and col in the files as well as the predicted clusters + """ + assert lsRunDir + dOkErrMiss = { fSimil:(0,0,0) for fSimil in lfSimil } + traceln(" --- eval_cluster level=%s"%sClusterLevel) + for sRunDir in lsRunDir: + lsFile = listFiles(sRunDir) + traceln("-loaded %d files from %s" % (len(lsFile), sRunDir)) + if not(lsFile) and not os.path.normpath(sRunDir).endswith("col"): + # ... 
checking folders + sRunDir = os.path.join(sRunDir, "col") + lsFile = listFiles(sRunDir) + if lsFile: + traceln("-loaded %d files from %s" % (len(lsFile), sRunDir)) + + if not sAlgo is None: + traceln("Loading cluster @algo='%s'"%sAlgo) + + for sFilename in lsFile: + doc = etree.parse(os.path.join(sRunDir, sFilename)) + rootNd = doc.getroot() + #assert len(PageXml.xpath(rootNd, "//pc:Page")) == 1, "NOT YET IMPLEMENTED: eval on multi-page files" + for iPage, ndPage in enumerate(PageXml.xpath(rootNd, "//pc:Page")): + traceln("PAGE %5d OF FILE %s" % (iPage+1, sFilename)) + # cluster -> [node_id] + dGT = defaultdict(list) + dRun = defaultdict(list) + for nd in PageXml.xpath(ndPage, xpSelector): + if bIgnoreHeader and nd.get("DU_header") != "D": continue + + ndparent = nd.getparent() + ndid = nd.get("id") + + if sClusterLevel == "cell": + val_gt = "%s__%s" % (ndparent.get("row"), ndparent.get("col")) + if val_gt == 'None__None' and bIgnoreOutOfTable: continue + elif sClusterLevel == "col": + val_gt = ndparent.get("col") + if val_gt == None and bIgnoreOutOfTable: continue + elif sClusterLevel == "row": + val_gt = ndparent.get("row") + if val_gt == None and bIgnoreOutOfTable: continue + elif sClusterLevel == 'region': + val_gt = "%s" % (ndparent.get("id")) + if val_gt == 'None' and bIgnoreOutOfTable: continue + else: + raise Exception("Unknown clustering level: %s"%sClusterLevel) + + # distinguish each table! + if sClusterLevel != 'region': + val_gt = "%s_%s" % (val_gt, ndparent.getparent().get("id")) + + dGT[val_gt].append(ndid) + + val_run = nd.get("DU_cluster") + dRun[val_run].append(ndid) + #assert ndparent.tag.endswith("TableCell"), "expected TableCell got %s" % nd.getparent().tag + + if not sAlgo is None: + dRun = defaultdict(list) + lNdCluster = PageXml.xpath(ndPage, ".//pc:Cluster[@algo='%s']"%sAlgo) + # lNdCluster = PageXml.xpath(ndPage, ".//pc:Cluster[@algo='%s' and @rowSpan='1']"%sAlgo) + traceln("Loaded %d cluster @algo='%s'"%(len(lNdCluster), sAlgo)) + for iCluster, ndCluster in enumerate(lNdCluster): + sIDs = ndCluster.get("content") + lndid = sIDs.split() + if lndid: + if sGroupByAttr: + # we group them by the value of an attribute + dRun[ndCluster.get(sGroupByAttr)].extend(lndid) + else: + dRun[str(iCluster)] = lndid + + for fSimil in lfSimil: + _nOk, _nErr, _nMiss = evalPartitions( + list(dRun.values()) + , list(dGT.values()) + , fSimil + , jaccard_distance) + + _fP, _fR, _fF = computePRF(_nOk, _nErr, _nMiss) + + #traceln("simil:%.2f P %5.1f R %5.1f F1 %5.1f ok=%6d err=%6d miss=%6d" %( + traceln("@simil %.2f P %5.1f R %5.1f F1 %5.1f ok=%6d err=%6d miss=%6d" %( + fSimil + , _fP, _fR, _fF + , _nOk, _nErr, _nMiss + )) + # , os.path.basename(sFilename))) + # sFilename = "" # ;-) + + # global count + nOk, nErr, nMiss = dOkErrMiss[fSimil] + nOk += _nOk + nErr += _nErr + nMiss += _nMiss + dOkErrMiss[fSimil] = (nOk, nErr, nMiss) + + traceln() + + for fSimil in lfSimil: + nOk, nErr, nMiss = dOkErrMiss[fSimil] + fP, fR, fF = computePRF(nOk, nErr, nMiss) + traceln("ALL_TABLES @simil %.2f P %5.1f R %5.1f F1 %5.1f " % (fSimil, fP, fR, fF ) + , " " + ,"ok=%d err=%d miss=%d" %(nOk, nErr, nMiss)) + + return (nOk, nErr, nMiss) + + +# ------------------------------------------------------------------ +if __name__ == "__main__": + usage = """ + cut INPUT_FILE OUTPUT_FILE to cut + + eval_cut_row FILE+ eval of partitions from cuts + eval_cut_col FILE+ + eval_cut_cell ROW_DIR COL_DIR + eval_cut_gt_recall FILE+ maximum obtainablerecall by cutting + + eval_direct_row ROW_DIR eval of partitions from 
DU_row index
+    eval_direct_col  COL_DIR
+    eval_direct_cell ROW_DIR COL_DIR
+
+    eval_cluster      RUN_DIR+  eval the quality of the CELL clusters, using the GT clusters defined by table@ID+@row+@col
+    eval_cluster_cell RUN_DIR+  eval the quality of the CELL clusters, using the GT clusters defined by table@ID+@row+@col
+    eval_cluster_col  RUN_DIR+  eval the quality of the COL  clusters, using the GT clusters defined by table@ID     +@col
+    eval_cluster_row  RUN_DIR+  eval the quality of the ROW  clusters, using the GT clusters defined by table@ID+@row
+        --group_by_attr   merge clusters according to the value of their ATTR attribute
+
+    oracle_cell RUN_DIR+  eval the quality of the CELL clusters, using the GT clusters defined by table@ID+@row+@col
+    oracle_col  RUN_DIR+  eval the quality of the COL  clusters, using the GT clusters defined by table@ID     +@col
+    oracle_row  RUN_DIR+  eval the quality of the ROW  clusters, using the GT clusters defined by table@ID+@row
+    if --algo is specified, then the run output is taken from the Cluster definitions
+    """
+    parser = OptionParser(usage=usage, version="0.1")
+    parser.add_option("--cut-height", dest="fCutHeight", default=10
+                      , action="store", type=float
+                      , help="cut, gt_recall: Minimal height of a cut")
+
+    parser.add_option("--simil", dest="lfSimil", default=None
+                      , action="append", type=float
+                      , help="Minimal similarity for associating 2 partitions")
+
+    parser.add_option("--cut-angle", dest='lsAngle'
+                      , action="store", type="string", default="0"
+                      , help="cut, gt_recall: Allowed cutting angles, in degrees, comma-separated")
+
+    parser.add_option("--cut-below", dest='bCutBelow', action="store_true", default=False
+                      , help="cut, eval_row, eval_cell, gt_recall: (OBSOLETE) Each object defines one or several cuts below it (instead of above, as by default)")
+
+#    parser.add_option("--cut-above", dest='bCutAbove', action="store_true", default=None
+#                      , help="Each object defines one or several cuts above it (instead of below as by default)")
+
+    parser.add_option("-v", "--verbose", dest='bVerbose', action="store_true", default=False)
+
+    parser.add_option("--algo", dest='sAlgo', action="store", type="string"
+                      , help="Use the cluster definition by the given algo, not the @DU_cluster attribute of the nodes")
+
+    parser.add_option("--ignore-header", dest='bIgnoreHeader', action="store_true", default=False
+                      , help="eval_row: ignore header text (and ignore empty cells, so ignore header cells!)")
+
+    parser.add_option("--group_by_attr", dest='sGroupByAttr', action="store", type="string"
+                      , help="merge clusters according to the value of their ATTR attribute")
+
+    parser.add_option("--ratio", dest='fRatio', action="store"
+                      , type=float
+                      , help="eval_col, eval_cell: Apply this ratio to the bounding box. 
This is normally useless as the baseline becomes a point (the centroid)" + , default=CutAnnotator.fRATIO) + + + # --- + #parse the command line + (options, args) = parser.parse_args() + + options.bCutAbove = not(options.bCutBelow) + + #load mpxml + try: + op = args[0] + except: + traceln(usage) + sys.exit(1) + + traceln("--- %s ---"%op) + if op in ["eval", "eval_row", "eval_col", "eval_cell"]: + if op == "eval": op = "eval_row" + traceln("DEPRECATED: now use ", op[0:4] + "_cut" + op[4:]) + exit(1) + # -------------------------------------- + if op == "cut": + sFilename = args[1] + sOutFilename = args[2] + traceln("- cutting : %s --> %s" % (sFilename, sOutFilename)) + lDegAngle = [float(s) for s in options.lsAngle.split(",")] + traceln("- Allowed angles (°): %s" % lDegAngle) + op_cut(sFilename, sOutFilename, lDegAngle, options.bCutAbove, fCutHeight=options.fCutHeight) + # -------------------------------------- + elif op.startswith("eval_cut"): + if op == "eval_cut_row": + lsFilename = args[1:] + traceln("- evaluating cut-based ROW partitions (fSimil=%s): " % options.lfSimil[0], lsFilename) + if options.bIgnoreHeader: traceln("- ignoring headers") + op_eval_row(lsFilename, options.lfSimil[0], options.bCutAbove, options.bVerbose + , bIgnoreHeader=options.bIgnoreHeader) + elif op == "eval_cut_col": + lsFilename = args[1:] + traceln("- evaluating cut-based COLUMN partitions (fSimil=%s): " % options.lfSimil[0], lsFilename) + op_eval_col(lsFilename, options.lfSimil[0], options.fRatio, options.bVerbose) + elif op == "eval_cut_cell": + sRowDir,sColDir = args[1:] + traceln("- evaluating cut-based CELL partitions : " , sRowDir,sColDir) + op_eval_cell(sRowDir, sColDir, options.fRatio, options.bCutAbove, options.bVerbose) + elif op == "eval_cut_gt_recall": + lsFilename = args[1:] + traceln("- GT recall : %s" % lsFilename) + lDegAngle = [float(s) for s in options.lsAngle.split(",")] + traceln("- Allowed angles (°): %s" % lDegAngle) + op_gt_recall(lsFilename, options.bCutAbove, lDegAngle, fCutHeight=options.fCutHeight) + else: + raise Exception("Unknown operation: %s"%op) + # -------------------------------------- + elif op.startswith("eval_direct"): + lCrit, lDir = [], [] + if options.bIgnoreHeader: traceln("- ignoring headers") + if op == "eval_direct_row": + sRowDir = args[1] + traceln("- evaluating ROW partitions (lfSimil=%s): " % options.lfSimil, sRowDir) + lCrit, lDir = ["row"], [sRowDir] + elif op == "eval_direct_col": + sColDir = args[1] + traceln("- evaluating COLUMN partitions (lfSimil=%s): " % options.lfSimil, sColDir) + lCrit, lDir = ["col"], [sColDir] + elif op == "eval_direct_cell": + sRowDir,sColDir = args[1:3] + lCrit, lDir = ["row", "col"], [sRowDir, sColDir] + traceln("- evaluating CELL partitions (lfSimil=%s): " % options.lfSimil, lCrit, " ", [sRowDir, sColDir]) + else: + raise Exception("Unknown operation: %s"%op) + if options.lfSimil: + eval_direct(lCrit, lDir + , bIgnoreHeader=options.bIgnoreHeader + , bIgnoreOutOfTable=True + , lfSimil=options.lfSimil + ) + else: + eval_direct(lCrit, lDir + , bIgnoreHeader=options.bIgnoreHeader + , bIgnoreOutOfTable=True + ) + + elif op.startswith("eval_cluster"): + lsRunDir = args[1:] + sClusterLevel = { "eval_cluster" :"cell" + , "eval_cluster_cell" :"cell" + , "eval_cluster_col" :"col" + , "eval_cluster_row" :"row" + , "eval_cluster_region":"region" + }[op] + traceln("- evaluating cluster partitions (lfSimil=%s): " % options.lfSimil, lsRunDir) + if options.sGroupByAttr: + traceln(" - merging run clusters having same value for their @%s 
attribute"%options.sGroupByAttr) + if options.lfSimil: + eval_cluster(lsRunDir, sClusterLevel + , bIgnoreHeader=options.bIgnoreHeader + , bIgnoreOutOfTable=True + , lfSimil=options.lfSimil + , sAlgo=options.sAlgo + , sGroupByAttr=options.sGroupByAttr + ) + else: + eval_cluster(lsRunDir, sClusterLevel + , bIgnoreHeader=options.bIgnoreHeader + , bIgnoreOutOfTable=True + , sAlgo=options.sAlgo + , sGroupByAttr=options.sGroupByAttr + ) + elif op.startswith("oracle"): + lsRunDir = args[1:] + sClusterLevel = { "oracle_cell" :"cell" + , "oracle_col" :"col" + , "oracle_row" :"row" + }[op] + traceln("- evaluating cluster partitions (lfSimil=%s): " % options.lfSimil, lsRunDir) + if options.lfSimil: + eval_oracle(lsRunDir, sClusterLevel + , bIgnoreHeader=options.bIgnoreHeader + , bIgnoreOutOfTable=True + , lfSimil=options.lfSimil + ) + else: + eval_oracle(lsRunDir, sClusterLevel + , bIgnoreHeader=options.bIgnoreHeader + , bIgnoreOutOfTable=True + ) + + + # -------------------------------------- + else: + traceln(usage) + + + + diff --git a/TranskribusDU/tasks/DU_Table/DU_Table_GAT.py b/TranskribusDU/tasks/DU_Table/DU_Table_GAT.py new file mode 100644 index 0000000..6354493 --- /dev/null +++ b/TranskribusDU/tasks/DU_Table/DU_Table_GAT.py @@ -0,0 +1,185 @@ +# -*- coding: utf-8 -*- + +""" + DU task for Table based on ECN + + Copyright NAVER(C) 2018, 2019 Hervé Déjean, Jean-Luc Meunier, Animesh Prasad + + + + + Developed for the EU project READ. The READ project has received funding + from the European Union's Horizon 2020 research and innovation programme + under grant agreement No 674943. + +""" + + + + +import sys, os + +try: #to ease the use without proper Python installation + import TranskribusDU_version +except ImportError: + sys.path.append( os.path.dirname(os.path.dirname( os.path.abspath(sys.argv[0]) )) ) + import TranskribusDU_version + +from common.trace import traceln + +from crf.Graph_MultiPageXml import Graph_MultiPageXml +from crf.Graph_Multi_SinglePageXml import Graph_MultiSinglePageXml +from crf.NodeType_PageXml import NodeType_PageXml_type_woText, NodeType_PageXml_type +from crf.FeatureDefinition_PageXml_std import FeatureDefinition_PageXml_StandardOnes +from crf.FeatureDefinition_PageXml_std_noText import FeatureDefinition_PageXml_StandardOnes_noText +from tasks.DU_ECN_Task import DU_ECN_Task + + +class DU_Table_GAT(DU_ECN_Task): + """ + ECN Models + """ + bHTR = True # do we have text from an HTR? + bPerPage = True # do we work per document or per page? 
+ bTextLine = True # if False then act as TextRegion + + sMetadata_Creator = "NLE Document Understanding GAT" + + + sXmlFilenamePattern = "*.bar_mpxml" + + # sLabeledXmlFilenamePattern = "*.a_mpxml" + sLabeledXmlFilenamePattern = "*.bar_mpxml" + + sLabeledXmlFilenameEXT = ".bar_mpxml" + + + dLearnerConfigOriginalGAT ={ + 'nb_iter': 500, + 'lr': 0.001, + 'num_layers': 2,#2 Train Acc is lower 5 overfit both reach 81% accuracy on Fold-1 + 'nb_attention': 5, + 'stack_convolutions': True, + # 'node_indim': 50 , worked well 0.82 + 'node_indim': -1, + 'dropout_rate_node': 0.0, + 'dropout_rate_attention': 0.0, + 'ratio_train_val': 0.15, + "activation_name": 'tanh', + "patience": 50, + "mu": 0.00001, + "original_model" : True + + } + + + dLearnerConfigNewGAT = {'nb_iter': 500, + 'lr': 0.001, + 'num_layers': 5, + 'nb_attention': 5, + 'stack_convolutions': True, + 'node_indim': -1, + 'dropout_rate_node': 0.0, + 'dropout_rate_attention' : 0.0, + 'ratio_train_val': 0.15, + "activation_name": 'tanh', + "patience":50, + "original_model": False, + "attn_type":0 + } + dLearnerConfig = dLearnerConfigNewGAT + #dLearnerConfig = dLearnerConfigOriginalGAT + # === CONFIGURATION ==================================================================== + @classmethod + def getConfiguredGraphClass(cls): + """ + In this class method, we must return a configured graph class + """ + lLabels = ['heading', 'header', 'page-number', 'resolution-number', 'resolution-marginalia', 'resolution-paragraph', 'other'] + + lIgnoredLabels = None + + """ + if you play with a toy collection, which does not have all expected classes, you can reduce those. + """ + + lActuallySeen = None + if lActuallySeen: + traceln("REDUCING THE CLASSES TO THOSE SEEN IN TRAINING") + lIgnoredLabels = [lLabels[i] for i in range(len(lLabels)) if i not in lActuallySeen] + lLabels = [lLabels[i] for i in lActuallySeen] + traceln(len(lLabels), lLabels) + traceln(len(lIgnoredLabels), lIgnoredLabels) + + + # DEFINING THE CLASS OF GRAPH WE USE + if cls.bPerPage: + DU_GRAPH = Graph_MultiSinglePageXml # consider each age as if indep from each other + else: + DU_GRAPH = Graph_MultiPageXml + + if cls.bHTR: + ntClass = NodeType_PageXml_type + else: + #ignore text + ntClass = NodeType_PageXml_type_woText + + + nt = ntClass("bar" # some short prefix because labels below are prefixed with it + , lLabels + , lIgnoredLabels + , False # no label means OTHER + , BBoxDeltaFun=lambda v: max(v * 0.066, min(5, v / 3)) + # we reduce overlap in this way + ) + nt.setLabelAttribute("DU_sem") + if cls.bTextLine: + nt.setXpathExpr( (".//pc:TextRegion/pc:TextLine" #how to find the nodes + , "./pc:TextEquiv") + ) + else: + nt.setXpathExpr( (".//pc:TextRegion" #how to find the nodes + , "./pc:TextEquiv") #how to get their text + ) + + + DU_GRAPH.addNodeType(nt) + + return DU_GRAPH + + def __init__(self, sModelName, sModelDir, sComment=None,dLearnerConfigArg=None): + if self.bHTR: + cFeatureDefinition = FeatureDefinition_PageXml_StandardOnes + dFeatureConfig = { 'bMultiPage':False, 'bMirrorPage':False + , 'n_tfidf_node':500, 't_ngrams_node':(2,4), 'b_tfidf_node_lc':False + , 'n_tfidf_edge':250, 't_ngrams_edge':(2,4), 'b_tfidf_edge_lc':False } + else: + cFeatureDefinition = FeatureDefinition_PageXml_StandardOnes_noText + dFeatureConfig = { 'bMultiPage':False, 'bMirrorPage':False + , 'n_tfidf_node':None, 't_ngrams_node':None, 'b_tfidf_node_lc':None + , 'n_tfidf_edge':None, 't_ngrams_edge':None, 'b_tfidf_edge_lc':None } + + + if sComment is None: sComment = sModelName + + + 
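+        # delegate to DU_ECN_Task with the GAT model class; an explicit
+        # dLearnerConfigArg, if given, overrides the class-level dLearnerConfig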
+        DU_ECN_Task.__init__(self
+                             , sModelName, sModelDir
+                             , dFeatureConfig=dFeatureConfig
+                             , dLearnerConfig= dLearnerConfigArg if dLearnerConfigArg is not None else self.dLearnerConfig
+                             , sComment=sComment
+                             , cFeatureDefinition=cFeatureDefinition
+                             , cModelClass=DU_Model_GAT  # FIXME: DU_Model_GAT is used but never imported in this file (presumably from the gcn package)
+                             )
+
+        if options.bBaseline:  # NOTE: relies on a module-level 'options', defined only when run as a script
+            self.bsln_mdl = self.addBaseline_LogisticRegression()  # use a LR model trained by GridSearch as baseline
+
+    # === END OF CONFIGURATION =============================================================
+    def predict(self, lsColDir):
+        """
+        Return the list of produced files
+        """
+        self.sXmlFilenamePattern = "*.bar_mpxml"
+        return DU_ECN_Task.predict(self, lsColDir)
+
diff --git a/TranskribusDU/tasks/DU_Table/DU_Table_Row_Edge.py b/TranskribusDU/tasks/DU_Table/DU_Table_Row_Edge.py
new file mode 100644
index 0000000..72f294a
--- /dev/null
+++ b/TranskribusDU/tasks/DU_Table/DU_Table_Row_Edge.py
@@ -0,0 +1,33 @@
+# -*- coding: utf-8 -*-
+
+"""
+    DU task for segmenting text in rows using the conjugate graph after the SW
+    re-engineering by JLM
+
+    Copyright NAVER(C) 2019 Jean-Luc Meunier
+
+
+
+    Developed for the EU project READ. The READ project has received funding
+    from the European Union's Horizon 2020 research and innovation programme
+    under grant agreement No 674943.
+
+"""
+
+import sys, os
+
+try:  #to ease the use without proper Python installation
+    import TranskribusDU_version
+except ImportError:
+    sys.path.append( os.path.dirname(os.path.dirname(os.path.dirname( os.path.abspath(sys.argv[0]) ))) )
+    import TranskribusDU_version
+TranskribusDU_version
+
+from tasks.DU_Table.DU_Table_Cell_Edge import main
+
+
+if __name__ == "__main__":
+    # import better_exceptions
+    # better_exceptions.MAX_LENGTH = None
+
+    main(sys.argv[0], "row")
\ No newline at end of file
diff --git a/TranskribusDU/tasks/DU_Table/DU_Table_Separator_Annotator.py b/TranskribusDU/tasks/DU_Table/DU_Table_Separator_Annotator.py
new file mode 100644
index 0000000..43c09a3
--- /dev/null
+++ b/TranskribusDU/tasks/DU_Table/DU_Table_Separator_Annotator.py
@@ -0,0 +1,195 @@
+# -*- coding: utf-8 -*-
+
+"""
+    USAGE: DU_Table_Separator_Annotator.py input-folder
+
+    You must run this on your GT collection to create a training collection.
+
+    If you pass a folder, you get a new folder with its name prefixed by a_
+
+    It annotates each separator found in the XML files as:
+    - S = a line that consistently separates the items of a table
+    - I = all other lines
+
+    For near-horizontal (i.e. more horizontal than vertical) lines, the maximum
+    @row+@rowSpan-1 of the items above the line must be strictly less than
+    the minimum @row of the items below the line.
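+
+    For example, if the items above a horizontal line span rows 0 to 3
+    (maximum @row+@rowSpan-1 = 3) and every item below it starts at @row >= 4,
+    the line is annotated S; if any item below also has @row <= 3, the line
+    is annotated I.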
+
+    Copyright Naver Labs Europe 2019
+    JL Meunier
+"""
+
+import sys, os
+
+from lxml import etree
+import shapely.geometry as geom
+
+try:  #to ease the use without proper Python installation
+    import TranskribusDU_version
+except ImportError:
+    sys.path.append( os.path.dirname(os.path.dirname( os.path.abspath(sys.argv[0]) )) )
+    import TranskribusDU_version
+    TranskribusDU_version
+
+from common.trace import traceln
+from xml_formats.PageXml import MultiPageXml
+from util.Shape import ShapeLoader, ShapePartition
+
+DEBUG=False
+
+def isBaselineHorizontal(ndText):
+    lNdBaseline = MultiPageXml.getChildByName(ndText ,'Baseline')
+    if lNdBaseline:
+        try:
+            o = ShapeLoader.node_to_LineString(lNdBaseline[0])
+        except Exception:
+            return True
+        (minx, miny, maxx, maxy) = o.bounds
+        return bool(maxx-minx >= maxy-miny)
+    return True
+
+def main(lsFilename, lsOutFilename):
+    #for the pretty printer to format better...
+    parser = etree.XMLParser(remove_blank_text=True)
+    cnt, cntS = 0, 0
+    for sFilename, sOutFilename in zip(lsFilename, lsOutFilename):
+        cntDoc, cntDocS = 0, 0
+
+        doc = etree.parse(sFilename, parser)
+        root = doc.getroot()
+
+        # Separators are not under tableRegion... :-/
+        lNdSep = MultiPageXml.getChildByName(root ,'SeparatorRegion')
+        loSep = [ShapeLoader.node_to_LineString(ndSep) for ndSep in lNdSep]
+        for _o in loSep: _o._bConsistent = True
+
+        if not lNdSep:
+            traceln("Warning: no separator in %s" % sFilename)
+        else:
+            traceln("%25s  %d separators" % (sFilename, len(lNdSep)))
+        lNdTR = MultiPageXml.getChildByName(root ,'TableRegion')
+        for ndTR in lNdTR:
+            lNdCells = MultiPageXml.getChildByName(ndTR ,'TableCell')
+            if not lNdCells:
+                continue
+
+            nbRows = max(int(x.get('row')) for x in lNdCells)
+
+            # build a list of Shapely objects augmented with our own table attributes
+            loText = []
+            for ndCell in lNdCells:
+                minRow = int(ndCell.get('row'))
+                minCol = int(ndCell.get('col'))
+                maxRow = minRow + int(ndCell.get('rowSpan')) - 1
+                maxCol = minCol + int(ndCell.get('colSpan')) - 1
+#                 # ignore cell spanning the whole table height
+#                 if maxRow >= nbRows:
+#                     continue
+                for ndText in MultiPageXml.getChildByName(ndCell ,'TextLine'):
+                    try:
+                        oText = ShapeLoader.node_to_Polygon(ndText)
+                    except Exception:
+                        traceln("WARNING: SKIPPING 1 TextLine: cannot make a polygon from: %s" % etree.tostring(ndText))
+                        continue
+                    # reduce the textbox to a single representative point
+                    (minx, miny, maxx, maxy) = oText.bounds
+
+                    # is the baseline horizontal or vertical??
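+                    # Horizontal text is mapped to a point near its left edge at
+                    # mid-height, vertical text to a point near its top edge at
+                    # mid-width; fDelta keeps that point inside the box.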
+                    fDelta = min((maxx-minx) / 2.0, (maxy-miny) / 2.0)
+                    if isBaselineHorizontal(ndText):
+                        # supposed horizontal text
+                        oText = geom.Point(minx + fDelta , (maxy + miny)/2.0)
+                        ndText.set("Horizontal", "TRUE")
+                    else:
+                        ndText.set("Horizontal", "FALSE")
+                        oText = geom.Point((minx + maxx)/2.0 , miny + fDelta)
+
+                    # considering it as a point, using its centroid,
+                    # does not work well due to very long texts:  oText = oText.centroid
+                    oText._minRow, oText._minCol = minRow, minCol
+                    oText._maxRow, oText._maxCol = maxRow, maxCol
+                    if DEBUG: oText._domnd = ndText
+                    loText.append(oText)
+
+            traceln(" TableRegion  %d texts" % (len(loText)))
+
+            if loText:
+                # checking in turn each separator for table-consistency
+                sp = ShapePartition(loText)
+
+                for oSep in loSep:
+                    (minx, miny, maxx, maxy) = oSep.bounds
+                    if maxx - minx >= maxy - miny:
+                        # supposed horizontal
+                        l = sp.getObjectAboveLine(oSep)
+                        if l:
+                            maxRowBefore = max(_o._maxRow for _o in l)
+                            l = sp.getObjectBelowLine(oSep)
+                            if l:
+                                minRowAfter = min(_o._minRow for _o in l)
+                                if maxRowBefore >= minRowAfter: oSep._bConsistent = False
+                    else:
+                        l1 = sp.getObjectOnLeftOfLine(oSep)
+                        if l1:
+                            maxColBefore = max(_o._maxCol for _o in l1)
+                            l2 = sp.getObjectOnRightOfLine(oSep)
+                            if l2:
+                                minColAfter = min(_o._minCol for _o in l2)
+                                if maxColBefore >= minColAfter:
+                                    oSep._bConsistent = False
+                                    if DEBUG:
+                                        # DEBUG
+                                        for o in l1:
+                                            if o._maxCol >= minColAfter: print("too much on right", etree.tostring(o._domnd))
+                                        for o in l2:
+                                            if o._minCol <= maxColBefore: print("too much on left", etree.tostring(o._domnd))
+            # end of TableRegion
+        # end of document
+        for ndSep, oSep in zip(lNdSep, loSep):
+            if oSep._bConsistent:
+                ndSep.set("DU_Sep", "S")
+                cntDocS += 1
+            else:
+                ndSep.set("DU_Sep", "I")
+            cntDoc += 1
+
+        doc.write(sOutFilename, encoding='utf-8', pretty_print=True, xml_declaration=True)
+        traceln('%.2f%% consistent separators - annotation done for %s  --> %s' % (100*float(cntDocS)/(cntDoc+0.000001), sFilename, sOutFilename))
+
+        del doc
+        cnt, cntS = cnt+cntDoc, cntS+cntDocS
+    traceln('%.2f%% consistent separators - annotation done for %d files' % (100*float(cntS)/(cnt+0.000001), cnt))
+
+
+if __name__ == "__main__":
+    try:
+        #we expect a folder
+        sInputDir = sys.argv[1]
+        if not os.path.isdir(sInputDir): raise ValueError()
+    except (IndexError, ValueError):
+        traceln("Usage: %s input-folder" % sys.argv[0])
+        exit(1)
+
+    sOutputDir = "a_"+sInputDir
+    traceln(" - Output will be in ", sOutputDir)
+    try:
+        os.mkdir(sOutputDir)
+        os.mkdir(os.path.join(sOutputDir, "col"))
+    except FileExistsError:
+        pass
+
+    lsFilename = [s for s in os.listdir(os.path.join(sInputDir, "col")) if s.endswith(".mpxml") ]
+    lsFilename.sort()
+    lsOutFilename = [os.path.join(sOutputDir, "col", "a_"+s) for s in lsFilename]
+    if not lsFilename:
+        lsFilename = [s for s in os.listdir(os.path.join(sInputDir, "col")) if s.endswith(".pxml") ]
+        lsFilename.sort()
+        lsOutFilename = [os.path.join(sOutputDir, "col", "a_"+s[:-5]+".mpxml") for s in lsFilename]
+
+    lsInFilename = [os.path.join(sInputDir , "col", s) for s in lsFilename]
+
+    traceln(lsFilename)
+    traceln("%d files to be processed" % len(lsFilename))
+
+    main(lsInFilename, lsOutFilename)
diff --git a/TranskribusDU/tasks/DU_Table/DU_Table_direct.py b/TranskribusDU/tasks/DU_Table/DU_Table_direct.py
new file mode 100644
index 0000000..005b4be
--- /dev/null
+++ b/TranskribusDU/tasks/DU_Table/DU_Table_direct.py
@@ -0,0 +1,271 @@
+# -*- coding: utf-8 -*-
+
+"""
+    DU task: predicting directly the row or col number
+
+    Copyright NAVER(C) 2019 Jean-Luc Meunier
+
+
+
+    Developed for the EU
project READ. The READ project has received funding + from the European Union's Horizon 2020 research and innovation programme + under grant agreement No 674943. + +""" + +import sys, os +from shutil import copyfile +from collections import defaultdict +from lxml import etree + +try: #to ease the use without proper Python installation + import TranskribusDU_version +except ImportError: + sys.path.append( os.path.dirname(os.path.dirname( os.path.abspath(sys.argv[0]) )) ) + import TranskribusDU_version +TranskribusDU_version + +from common.trace import traceln, trace +from xml_formats.PageXml import PageXml +from tasks.DU_Task_Factory import DU_Task_Factory +from graph.Graph_Multi_SinglePageXml import Graph_MultiSinglePageXml +from graph.NodeType_PageXml import defaultBBoxDeltaFun +from graph.NodeType_PageXml import NodeType_PageXml_type_woText +from graph.FeatureDefinition_PageXml_std_noText import FeatureDefinition_PageXml_StandardOnes_noText + +# ---------------------------------------------------------------------------- + + +sATTRIBUTE = "row" +iMIN, iMAX = 0, 15 +iMAXMAX = 99 +sXPATH = ".//pc:TextLine" + +bOTHER = True # do we need OTHER as label? + +# sXPATH = ".//pc:TextLine[../@%s]" % sATTRIBUTE + +""" +bad approach: the is forced to set the last loaded cell to row_9 +iMIN, iMAX = "0", "9" +sXPATH = ".//pc:TextLine[%s <= ../@%s and ../@%s <= %s]" % (iMIN, sATTRIBUTE, sATTRIBUTE, iMAX) +""" + +# ======= UTILITY ========= +def split_by_max(crit, sDir): + """ + here, we create sub-folders, where the files have the same number of row or col + """ + assert crit in ["row", "col"] + + sColDir= os.path.join(sDir, "col") + traceln("- looking at ", sColDir) + lsFile = [] + for _fn in os.listdir(sColDir): + _fnl = _fn.lower() + if _fnl.endswith("_du.mpxml") or _fnl.endswith("_du.pxml"): + continue + if not(_fnl.endswith(".mpxml") or _fnl.endswith(".pxml")): + continue + lsFile.append(_fn) + traceln(" %d files" % len(lsFile)) + + dCnt = defaultdict(int) + + for sFilename in lsFile: + trace("- %s" % sFilename) + sInFile = os.path.join(sColDir, sFilename) + doc = etree.parse(sInFile) + rootNd = doc.getroot() + vmax = -999 + xp = "//@%s" % crit + try: + vmax = max(int(_nd) for _nd in PageXml.xpath(rootNd, xp)) + assert vmax >= 0 + sToDir = "%s_%s_%d"%(sDir, crit, vmax) + except ValueError: + trace(" ERROR on file %s" % sInFile) + vmax = None + sToDir = "%s_%s_%s"%(sDir, crit, vmax) + del doc + sToColDir = os.path.join(sToDir, "col") + try: + os.mkdir(sToDir) + os.mkdir(sToColDir) + except FileExistsError: pass + copyfile(sInFile, os.path.join(sToColDir, sFilename)) + traceln(" -> ", sToColDir) + dCnt[vmax] += 1 + traceln("WARNING: %d invalid files"%dCnt[None]) + del dCnt[None] + traceln(sorted(dCnt.items())) + + +# ======= DOER ========= + +class My_NodeType_Exception(Exception): + pass + + +class My_NodeType(NodeType_PageXml_type_woText): + """ + We need this to extract properly the label from the label attribute of the (parent) TableCell element. 
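+    For example, a TextLine whose parent TableCell carries row="3" is given
+    the label "row_3" (the sLabelAttr value prefixed by sATTRIBUTE); values
+    above iMAX fall back to the default OTHER label when bOther is True.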
+    """
+    sLabelAttr = sATTRIBUTE
+
+    def __init__(self, sNodeTypeName, lsLabel, lsIgnoredLabel=None, bOther=True
+                 , BBoxDeltaFun=defaultBBoxDeltaFun):
+        super(My_NodeType, self).__init__(sNodeTypeName, lsLabel
+                                          , lsIgnoredLabel=lsIgnoredLabel
+                                          , bOther=bOther
+                                          , BBoxDeltaFun=BBoxDeltaFun)
+
+    def parseDocNodeLabel(self, graph_node, defaultCls=None):
+        """
+        Parse and set the graph node label and return its class index.
+        Raise a My_NodeType_Exception if the label is missing while bOther is not True,
+        or if the label is neither a valid one nor an ignored one.
+        """
+        domnode = graph_node.node
+        sLabel = domnode.getparent().get(self.sLabelAttr)
+        if sLabel is None:
+            if self.bOther:
+                return self.getDefaultLabel()
+            else:
+                raise My_NodeType_Exception("Missing attribute @%s for node id=%s" % (self.sLabelAttr, domnode.get("id")))
+        try:
+            i = int(sLabel)
+            if i > iMAX:
+                if self.bOther:
+                    return self.getDefaultLabel()
+                else:
+                    raise My_NodeType_Exception("Too large label integer value: @%s='%s' for node id=%s" % (self.sLabelAttr, sLabel, domnode.get("id")))
+            else:
+                return sATTRIBUTE + "_" + sLabel
+        except ValueError:  # catch only the int() failure, so our own exceptions propagate
+            raise My_NodeType_Exception("Invalid label value: @%s='%s' for node id=%s" % (self.sLabelAttr, sLabel, domnode.get("id")))
+
+    def setDocNodeLabel(self, graph_node, sLabel):
+        graph_node.node.set("DU_%s" % self.sLabelAttr, sLabel)
+
+
+def getConfiguredGraphClass(doer):
+    """
+    Return a configured graph class
+    """
+    DU_GRAPH = Graph_MultiSinglePageXml  # consider each page as independent from the others
+
+    # ntClass = NodeType_PageXml_type
+    ntClass = My_NodeType
+
+    nt = ntClass(sATTRIBUTE  # some short prefix because the labels below are prefixed with it
+                 , [str(n) for n in range(int(iMIN), int(iMAX)+1)]
+                 # in --noother, we have a strict list of label values!
+                 # ignored labels start past iMAX, since iMAX itself is a valid label
+                 , ["%s_%d"%(sATTRIBUTE, n) for n in range(int(iMAX)+1, int(iMAXMAX)+1)] if bOTHER else []
+                 , bOTHER  # unused
+                 , BBoxDeltaFun=lambda v: max(v * 0.066, min(5, v/3))  # we reduce overlap in this way
+                 )
+    #nt.setLabelAttribute("row")
+    nt.setXpathExpr( (sXPATH        # how to find the nodes
+    #nt1.setXpathExpr( (".//pc:TableCell//pc:TextLine"  # how to find the nodes
+                      , "./pc:TextEquiv")               # how to get their text
+                    )
+    DU_GRAPH.addNodeType(nt)
+
+    return DU_GRAPH
+
+
+if __name__ == "__main__":
+    # import better_exceptions
+    # better_exceptions.MAX_LENGTH = None
+
+    # standard command line options for CRF-, ECN- and GAT-based methods
+    usage, parser = DU_Task_Factory.getStandardOptionsParser(sys.argv[0])
+
+    traceln("VERSION: %s" % DU_Task_Factory.getVersion())
+
+    parser.add_option("--what", dest='sWhat' , action="store", type="string", default="row"
+                      , help='what to predict, e.g. "row", "col"')
+    parser.add_option("--max", dest='iMax' , action="store", type="int"
+                      , help="Maximum number to be found (starts at 0)")
+    parser.add_option("--intable", dest='bInTable' , action="store_true"
+                      , help="Ignore TextLine outside the table region, using the GT")
+    parser.add_option("--noother", dest='bNoOther' , action="store_true"
+                      , help="No 'OTHER' label")
+
+    # ---
+    #parse the command line
+    (options, args) = parser.parse_args()
+
+    if args and args[0] == "split-by-max":
+        assert len(args) == 3, 'expected: split-by-max row|col <folder>'
+        split_by_max(args[1], args[2])
+        exit(0)
+
+
+    # standard arguments
+    try:
+        sModelDir, sModelName = args
+    except Exception as e:
+        traceln("Specify a model folder and a model name!")
+        DU_Task_Factory.exit(usage, 1, e)
+
+    # specific options
+    if options.iMax:            iMAX = options.iMax
+    if options.sWhat:           sATTRIBUTE = options.sWhat
+    if bool(options.bInTable):  sXPATH = sXPATH + "[../@%s]" % sATTRIBUTE
+    if options.bNoOther:        bOTHER = False
+
+    # some verbosity
+    traceln('Prediction "%s" from %d to %d using selector %s (%s)' % (
+        sATTRIBUTE, 0, iMAX, sXPATH
+        , "With label 'OTHER'" if bOTHER else "Without label 'OTHER'" ))
+
+    # standard options
+    doer = DU_Task_Factory.getDoer(sModelDir, sModelName
+                                   , options                    = options
+                                   , fun_getConfiguredGraphClass= getConfiguredGraphClass
+                                   , cFeatureDefinition         = FeatureDefinition_PageXml_StandardOnes_noText
+                                   , dFeatureConfig             = {}
+                                   )
+
+    # setting the learner configuration, in a standard way
+    # (from command line options, or from a JSON configuration file)
+    # dLearnerConfig = doer.getStandardLearnerConfig(options)
+    if options.bECN:
+        dLearnerConfig = {
+            "name"                  : "default_8Lay1Conv",
+            "dropout_rate_edge"     : 0.2,
+            "dropout_rate_edge_feat": 0.2,
+            "dropout_rate_node"     : 0.2,
+            "lr"                    : 0.0001,
+            "mu"                    : 0.0001,
+            "nb_iter"               : 3000,
+            "nconv_edge"            : 1,
+            "node_indim"            : -1,
+            "num_layers"            : 8,
+            "ratio_train_val"       : 0.1,
+            "patience"              : 100,
+            "activation_name"       : "relu",
+            "stack_convolutions"    : False
+        }
+    elif options.bCRF:
+        dLearnerConfig = doer.getStandardLearnerConfig(options)
+    else:
+        raise ValueError("Unsupported method")  # raising a bare string is invalid in Python 3
+
+    if options.max_iter:
+        traceln(" - max_iter=%d" % options.max_iter)
+        dLearnerConfig["nb_iter"] = options.max_iter
+
+    if False:
+        # force a balanced weighting
+        print("Forcing balanced weights")
+        dLearnerConfig['balanced'] = True
+
+    doer.setLearnerConfiguration(dLearnerConfig)
+
+    # act as specified on the command line (--trn, --fold-run, ...)
+ doer.standardDo(options) + + del doer + diff --git a/TranskribusDU/tasks/DU_Table/__init__.py b/TranskribusDU/tasks/DU_Table/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/TranskribusDU/tasks/DU_Table/cluster2Table.py b/TranskribusDU/tasks/DU_Table/cluster2Table.py new file mode 100644 index 0000000..2443fb6 --- /dev/null +++ b/TranskribusDU/tasks/DU_Table/cluster2Table.py @@ -0,0 +1,278 @@ +# -*- coding: utf-8 -*- + +""" +Transform clusters into TableRegion/TableCells and populate them with TextLines + +Created on August 2019 + +Copyright NAVER LABS Europe 2019 +@author: Hervé Déjean +""" + +import sys, os, glob +from optparse import OptionParser +from copy import deepcopy +from collections import Counter +from collections import defaultdict + +from lxml import etree +import numpy as np +from shapely.ops import cascaded_union + + +try: #to ease the use without proper Python installation + import TranskribusDU_version +except ImportError: + sys.path.append( os.path.dirname(os.path.dirname(os.path.dirname( os.path.abspath(sys.argv[0]) )) )) + import TranskribusDU_version +TranskribusDU_version + +from common.trace import traceln, trace +from xml_formats.PageXml import PageXml +from util.Shape import ShapeLoader +dNS = {"pg":"http://schema.primaresearch.org/PAGE/gts/pagecontent/2013-07-15"} +# ---------------------------------------------------------------------------- + +def getClusterCoords(lElts): + + lp = [] + for e in lElts: + try: + lp.append(ShapeLoader.node_to_Polygon(e)) + except ValueError: + pass + contour = cascaded_union([p if p.is_valid else p.convex_hull for p in lp ]) + # print(contour.wkt) + try:spoints = ' '.join("%s,%s"%(int(x[0]),int(x[1])) for x in contour.convex_hull.exterior.coords) + except: + try: spoints = ' '.join("%s,%s"%(int(x[0]),int(x[1])) for x in contour.convex_hull.coords) + # JL got once a: NotImplementedError: Multi-part geometries do not provide a coordinate sequence + except: spoints = "" + return spoints + +def deleteRegionsinDOM(page,lRegionsNd): + [page.remove(c) for c in lRegionsNd] + +def main(sInputDir + , bVerbose=False): + + lSkippedFile = [] + + # filenames without the path + lsFilename = [os.path.basename(name) for name in os.listdir(sInputDir) if name.endswith("_du.mpxml")] + traceln(" - %d .mpxml files to process" % len(lsFilename)) + for sMPXml in lsFilename: + traceln(" - .mpxml FILE : ", sMPXml) + + # 0 - load input file + doc = etree.parse(os.path.join(sInputDir,sMPXml)) + cluster2TableCell(doc,bVerbose) + +# doc.write(os.path.join(sInputDir,sMPXml), +# xml_declaration = True, +# encoding="utf-8", +# pretty_print=True +# #compression=0, #0 to 9 +# ) + + +def propagateTypeToRegion(ndRegion): + """ + compute the most frequent type in the Textlines and assigns it to the new region + """ + dType=Counter() + for t in ndRegion: + dType[t.get('type')]+=1 + mc = dType.most_common(1) + if mc : + if mc[0][0]:ndRegion.set('type',mc[0][0]) + # structure {type:page-number;} + # custom="structure {type:page-number;}" + if mc[0][0]:ndRegion.set('custom',"structure {type:%s;}"%mc[0][0]) + + +def addTableCellsToDom(page,ipage,lc,bVerbose): + """ + create a dom node for each cluster + update DU_cluster for each Textline + """ + # create TableRegion first ! 
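+    # Each cluster below becomes one TableCell node: its id encodes the page and
+    # cluster rank ("p<page>_r<rank>") and its Coords polygon is the convex hull
+    # of its TextLines, as computed by getClusterCoords() above.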
+ + for ic,dC in enumerate(lc): + ndRegion = PageXml.createPageXmlNode('TableCell') + + #update elements + lTL = lc[dC] + print (lTL) +# for id in c.get('content').split(): +# elt = page.xpath('.//*[@id="%s"]'%id)[0] +# elt.getparent().remove(elt) +# ndRegion.append(elt) +# lTL.append((elt)) + ndRegion.set('id',"p%d_r%d"%(ipage,ic)) + coords = PageXml.createPageXmlNode('Coords') + ndRegion.append(coords) + coords.set('points',getClusterCoords(lTL)) + propagateTypeToRegion(ndRegion) + + page.append(ndRegion) + +def getTextLines(ndPage): + lTL= ndPage.xpath(".//pg:TextLine", namespaces=dNS) + dIds= {} + for tl in lTL: dIds[tl.get('id')]=tl + return dIds + + +def getCellCluster(ndPage,xpCluster): + """ + + + """ + from statistics import mean + dColdict=defaultdict(list) + dRowdict=defaultdict(list) + dTmpRowIDPos=defaultdict(list) + dRowIDPos=defaultdict(list) + lClusters= ndPage.xpath(xpCluster, namespaces=dNS) + prevcolid=-1 + rowposition=0 + + dcol2=defaultdict(list) + # assume order by columns!!! + for c in lClusters: + name= c.get('name') + colid,rowid= [ int(i) for i in name.strip('()').split('_I_agglo_') if isinstance(int(i),int)] + dcol2[colid].append(c) + + #for c in lClusters: + for colid in dcol2: + c= dcol2[colid].sort(key=lambda cell:mean([ShapeLoader.node_to_Point(x).y for x in cell])) + name= c.get('name') + colid,rowid= [ int(i) for i in name.strip('()').split('_I_agglo_') if isinstance(int(i),int)] + # why? assume ordered by column?? + if colid != prevcolid:rowposition=-1 + rowposition += 1 + dColdict[colid].extend(c.get('content').split()) + dRowdict[rowid].extend(c.get('content').split()) + prevcolid = colid + + + for key, values in dTmpRowIDPos.items(): + print (key, max(set(values), key = values.count), values) + dRowIDPos[max(set(values), key = values.count)].append(key) #max(set(values), key = values.count) + + lF=defaultdict(list) + for i,pos in enumerate(sorted(dRowIDPos.keys())): + print (i,pos,dRowIDPos[pos], [dTmpRowIDPos[x] for x in dRowIDPos[pos] ]) + lF[i]= dRowIDPos[pos] + ss + #return dColdict,dRowdict,dRowIDPos #lCells + return dColdict,dRowdict,lF #lCells + + +def createTable(dColdict,dRowdict,dRowIDPos,lIds): + """ + + sort rows by avg(y) + """ + from statistics import mean + + #get table dimensions + for x in dRowIDPos.items(): print (x) + nbCols= len(dColdict.keys())+1 + nbRows= len(dRowIDPos.keys())+1 + table = [[ [] for i in range(nbCols)] for j in range(nbRows)] + print (nbRows,nbCols,len(table),len(table[0])) + for irow in sorted(dRowIDPos.keys()): + for jcol in sorted(dColdict.keys()): + if jcol > 0: +# print (irow,jcol) + # compute intersection?? 
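+                # cell (irow, jcol-1) collects the text ids present both in one
+                # of the row clusters mapped to position irow and in column
+                # cluster jcol, i.e. the row/column intersection.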
+ cellij= list([value for row in dRowIDPos[irow] for value in dRowdict[row] if value in dColdict[jcol] ]) + # print(dRowdict[dRowIDPos[irow]] ) + # print(dColdict[jcol]) + #print (irow,jcol-1,[''.join(lIds[id].itertext()).strip() for id in cellij]) + # print (irow,jcol-1,[lIds[id] for id in cellij]) + table[irow][jcol-1]=[lIds[id] for id in cellij] + + # ignore empty row +# table.sort(key=lambda row:mean([ShapeLoader.node_to_Point(x).y for cell in row for x in cell]) ) + for row in table: + print ([len(x) for cell in row for x in cell]) +# for col in row: +# print ([''.join(x.itertext()).strip() for x in col],end='') +# print() + table.sort(key=lambda row:mean([ShapeLoader.node_to_Point(x).y for cell in row for x in cell]) ) + +# for irow,row in enumerate(table): +# lY=[] +# for cell in row: +# if cell != []: +# #print( [ShapeLoader.node_to_Point(x) for x in cell]) +# mean([lY.extend(ShapeLoader.node_to_Point(x).y for x in cell)]) +# +# print (irow,mean(lY)) + + +def cluster2TableCell(doc, fTH=0.5,bVerbose=True): + """ + + """ + root = doc.getroot() + + + xpCluster = ".//pg:Cluster[@algo='(cut_I_agglo)']" + xpTextRegions = ".//pg:TextRegion" + + # get pages + for iPage, ndPage in enumerate(PageXml.xpath(root, "//pc:Page")[24:]): + # get cluster + dColdict,dRowdict,dRowIDPos = getCellCluster(ndPage,xpCluster) #ndPage.xpath(xpCluster, namespaces=dNS) + lIds = getTextLines(ndPage) +# lRegionsNd = ndPage.xpath(xpTextRegions, namespaces=dNS) +# if bVerbose:traceln("\n%d clusters and %d regions found" %(len(dClusters),len(lRegionsNd))) + try: + lCells = createTable(dColdict,dRowdict,dRowIDPos,lIds) + except KeyError: + print(iPage) + return +# addTableCellsToDom(ndPage,iPage+1,lCells,bVerbose) +# if bVerbose:traceln("%d regions created" %(len(dClusters))) +# deleteRegionsinDOM(ndPage, lRegionsNd) + return doc + + + +# ---------------------------------------------------------------------------- +if __name__ == "__main__": + + version = "v.01" + sUsage=""" +Usage: %s + +""" % (sys.argv[0]) + + parser = OptionParser(usage=sUsage) + parser.add_option("-v", "--verbose", dest='bVerbose', action="store_true" + , help="Verbose mode") + (options, args) = parser.parse_args() + + try: + sInputDir = args[0] + except ValueError: + sys.stderr.write(sUsage) + sys.exit(1) + + # ... checking folders + if not os.path.normpath(sInputDir).endswith("col") : sInputDir = os.path.join(sInputDir, "col") + # all must be ok by now + lsDir = [sInputDir] + if not all(os.path.isdir(s) for s in lsDir): + for s in lsDir: + if not os.path.isdir(s): sys.stderr.write("Not a directory: %s\n"%s) + sys.exit(2) + bVerbose=options.bVerbose + main(sInputDir, bVerbose=options.bVerbose) + + traceln("Done.") \ No newline at end of file diff --git a/TranskribusDU/tasks/DU_Table/columnDetection.py b/TranskribusDU/tasks/DU_Table/columnDetection.py new file mode 100644 index 0000000..e4652b8 --- /dev/null +++ b/TranskribusDU/tasks/DU_Table/columnDetection.py @@ -0,0 +1,548 @@ +# -*- coding: utf-8 -*- +""" + + + build Table columns + + H. Déjean + + + copyright Naver 2018 + READ project + + + + + Developed for the EU project READ. The READ project has received funding + from the European Union's Horizon 2020 research and innovation programme + under grant agreement No 674943. 
+""" + + + + +import sys, os.path +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(sys.argv[0])))) + +from lxml import etree + +import common.Component as Component +from common.trace import traceln +import config.ds_xml_def as ds_xml +from ObjectModel.xmlDSDocumentClass import XMLDSDocument +from ObjectModel.XMLDSTEXTClass import XMLDSTEXTClass +from ObjectModel.XMLDSTABLEClass import XMLDSTABLEClass +from ObjectModel.XMLDSCELLClass import XMLDSTABLECELLClass +from ObjectModel.XMLDSTableRowClass import XMLDSTABLEROWClass +from spm.spmTableRow import tableRowMiner +from xml_formats.Page2DS import primaAnalysis +from xml_formats.DS2PageXml import DS2PageXMLConvertor + +class columnDetection(Component.Component): + """ + build table column + """ + usage = "" + version = "v.01" + description = "description: column Detection" + + #--- INIT ------------------------------------------------------------------------------------------------------------- + def __init__(self): + """ + Always call first the Component constructor. + """ + Component.Component.__init__(self, "columnDetection", self.usage, self.version, self.description) + + self.colname = None + self.docid= None + + self.do2DS= False + + # for --test + self.bCreateRef = False + self.evalData = None + + def setParams(self, dParams): + """ + Always call first the Component setParams + Here, we set our internal attribute according to a possibly specified value (otherwise it stays at its default value) + """ + Component.Component.setParams(self, dParams) +# if dParams.has_key("coldir"): +# self.colname = dParams["coldir"] + if "docid" in dParams: + self.docid = dParams["docid"] + if "dsconv" in dParams: + self.do2DS = dParams["dsconv"] + + if "createref" in dParams: + self.bCreateRef = dParams["createref"] + + + +# def createCells(self, table): +# """ +# create new cells using BIESO tags +# @input: tableObeject with old cells +# @return: tableObject with BIES cells +# @precondition: requires columns +# +# """ +# for col in table.getColumns(): +# lNewCells=[] +# # keep original positions +# col.resizeMe(XMLDSTABLECELLClass) +# for cell in col.getCells(): +# # print cell +# curChunk=[] +# lChunks = [] +# # print map(lambda x:x.getAttribute('type'),cell.getObjects()) +# # print map(lambda x:x.getID(),cell.getObjects()) +# cell.getObjects().sort(key=lambda x:x.getY()) +# for txt in cell.getObjects(): +# # print txt.getAttribute("type") +# if txt.getAttribute("type") == 'RS': +# if curChunk != []: +# lChunks.append(curChunk) +# curChunk=[] +# lChunks.append([txt]) +# elif txt.getAttribute("type") in ['RI', 'RE']: +# curChunk.append(txt) +# elif txt.getAttribute("type") == 'RB': +# if curChunk != []: +# lChunks.append(curChunk) +# curChunk=[txt] +# elif txt.getAttribute("type") == 'RO': +# ## add Other as well??? 
+# curChunk.append(txt) +# +# if curChunk != []: +# lChunks.append(curChunk) +# +# if lChunks != []: +# # create new cells +# table.delCell(cell) +# irow= cell.getIndex()[0] +# for i,c in enumerate(lChunks): +# # print map(lambda x:x.getAttribute('type'),c) +# #create a new cell per chunk and replace 'cell' +# newCell = XMLDSTABLECELLClass() +# newCell.setPage(cell.getPage()) +# newCell.setParent(table) +# newCell.setName(ds_xml.sCELL) +# newCell.setIndex(irow+i,cell.getIndex()[1]) +# newCell.setObjectsList(c) +# newCell.resizeMe(XMLDSTEXTClass) +# newCell.tagMe2() +# for o in newCell.getObjects(): +# o.setParent(newCell) +# o.tagMe() +# # table.addCell(newCell) +# lNewCells.append(newCell) +# cell.getNode().getparent().remove(cell.getNode()) +# del(cell) +# col.setObjectsList(lNewCells[:]) +# [table.addCell(c) for c in lNewCells] +# +# # print col.tagMe() + + + def createTable(self,page): + """ + BB of all elements? + """ + + def processPage(self,page): + from util.XYcut import mergeSegments + + ### skrinking to be done: + lCuts, x1, x2 = mergeSegments([(x.getX(),x.getX()+20,x) for x in page.getAllNamedObjects(XMLDSTEXTClass)],0) + for x,y,cut in lCuts: + ll =list(cut) + ll.sort(key=lambda x:x.getY()) + traceln(len(ll)) +# traceln (list(map(lambda x:x.getContent(),ll))) + + def findColumnsInDoc(self,ODoc): + """ + find columns for each table in ODoc + """ + self.lPages = ODoc.getPages() + + # not always? +# self.mergeLineAndCells(self.lPages) + + for page in self.lPages: + traceln("page: %d" % page.getNumber()) + self.processPage(page) + + def run(self,doc): + """ + load dom and find rows + """ + # conver to DS if needed + if self.bCreateRef: + if self.do2DS: + dsconv = primaAnalysis() + doc = dsconv.convert2DS(doc,self.docid) + + refdoc = self.createRef(doc) + return refdoc + # single ref per page + refdoc= self.createRefPerPage(doc) + return None + + if self.do2DS: + dsconv = primaAnalysis() + self.doc = dsconv.convert2DS(doc,self.docid) + else: + self.doc= doc + self.ODoc = XMLDSDocument() + self.ODoc.loadFromDom(self.doc,listPages = range(self.firstPage,self.lastPage+1)) +# self.ODoc.loadFromDom(self.doc,listPages = range(30,31)) + + self.findColumnsInDoc(self.ODoc) + refdoc = self.createRef(self.doc) +# print refdoc.serialize('utf-8', 1) + + if self.do2DS: + # bakc to PageXml + conv= DS2PageXMLConvertor() + lPageXDoc = conv.run(self.doc) + conv.storeMultiPageXml(lPageXDoc,self.getOutputFileName()) +# print self.getOutputFileName() + return None + return self.doc + + + + ################ TEST ################## + + + def testRun(self, filename, outFile=None): + """ + evaluate using ABP new table dataset with tablecell + """ + + self.evalData=None + doc = self.loadDom(filename) + doc =self.run(doc) + self.evalData = self.createRef(doc) + if outFile: self.writeDom(doc) +# return self.evalData.serialize('utf-8',1) + return etree.tostring(self.evalData,encoding='unicode',pretty_print=True) + + + def overlapX(self,zone): + + + [a1,a2] = self.getX(),self.getX()+ self.getWidth() + [b1,b2] = zone.getX(),zone.getX()+ zone.getWidth() + return min(a2, b2) >= max(a1, b1) + + def overlapY(self,zone): + [a1,a2] = self.getY(),self.getY() + self.getHeight() + [b1,b2] = zone.getY(),zone.getY() + zone.getHeight() + return min(a2, b2) >= max(a1, b1) + def signedRatioOverlap(self,z1,z2): + """ + overlap self and zone + return surface of self in zone + """ + [x1,y1,h1,w1] = z1.getX(),z1.getY(),z1.getHeight(),z1.getWidth() + [x2,y2,h2,w2] = z2.getX(),z2.getY(),z2.getHeight(),z2.getWidth() + + fOverlap = 
0.0 + + if self.overlapX(z2) and self.overlapY(z2): + [x11,y11,x12,y12] = [x1,y1,x1+w1,y1+h1] + [x21,y21,x22,y22] = [x2,y2,x2+w2,y2+h2] + + s1 = w1 * h1 + + # possible ? + if s1 == 0: s1 = 1.0 + + #intersection + nx1 = max(x11,x21) + nx2 = min(x12,x22) + ny1 = max(y11,y21) + ny2 = min(y12,y22) + h = abs(nx2 - nx1) + w = abs(ny2 - ny1) + + inter = h * w + if inter > 0 : + fOverlap = inter/s1 + else: + # if overX and Y this is not possible ! + fOverlap = 0.0 + + return fOverlap + + def findSignificantOverlap(self,TOverlap,ref,run): + """ + return + """ + pref,rowref= ref + prun, rowrun= run + if pref != prun: return False + + return rowref.ratioOverlap(rowrun) >=TOverlap + + + def testCPOUM(self, TOverlap, srefData, srunData, bVisual=False): + """ + TOverlap: Threshols used for comparing two surfaces + + + Correct Detections: + under and over segmentation? + """ + + cntOk = cntErr = cntMissed = 0 + + RefData = etree.XML(srefData.strip("\n").encode('utf-8')) + RunData = etree.XML(srunData.strip("\n").encode('utf-8')) +# try: +# RunData = libxml2.parseMemory(srunData.strip("\n"), len(srunData.strip("\n"))) +# except: +# RunData = None +# return (cntOk, cntErr, cntMissed) + lRun = [] + if RunData: + lpages = RunData.xpath('//%s' % ('PAGE')) + for page in lpages: + pnum=page.get('number') + #record level! + lRows = page.xpath(".//%s" % ("ROW")) + lORows = map(lambda x:XMLDSTABLEROWClass(0,x),lRows) + for row in lORows: + row.fromDom(row._domNode) + row.setIndex(row.getAttribute('id')) + lRun.append((pnum,row)) + print (lRun) + + lRef = [] + lPages = RefData.xpath('//%s' % ('PAGE')) + for page in lPages: + pnum=page.get('number') + lRows = page.xpath(".//%s" % ("ROW")) + lORows = map(lambda x:XMLDSTABLEROWClass(0,x),lRows) + for row in lORows: + row.fromDom(row._domNode) + row.setIndex(row.getAttribute('id')) + lRef.append((pnum,row)) + + + refLen = len(lRef) +# bVisual = True + ltisRefsRunbErrbMiss= list() + lRefCovered = [] + for i in range(0,len(lRun)): + iRef = 0 + bFound = False + bErr , bMiss= False, False + runElt = lRun[i] +# print '\t\t===',runElt + while not bFound and iRef <= refLen - 1: + curRef = lRef[iRef] + if runElt and curRef not in lRefCovered and self.findSignificantOverlap(TOverlap,runElt, curRef): + bFound = True + lRefCovered.append(curRef) + iRef+=1 + if bFound: + if bVisual:print("FOUND:", runElt, ' -- ', lRefCovered[-1]) + cntOk += 1 + else: + curRef='' + cntErr += 1 + bErr = True + if bVisual:print("ERROR:", runElt) + if bFound or bErr: + ltisRefsRunbErrbMiss.append( (int(runElt[0]), curRef, runElt,bErr, bMiss) ) + + for i,curRef in enumerate(lRef): + if curRef not in lRefCovered: + if bVisual:print("MISSED:", curRef) + ltisRefsRunbErrbMiss.append( (int(curRef[0]), curRef, '',False, True) ) + cntMissed+=1 + ltisRefsRunbErrbMiss.sort(key=lambda xyztu:xyztu[0]) + +# print cntOk, cntErr, cntMissed,ltisRefsRunbErrbMiss + return (cntOk, cntErr, cntMissed,ltisRefsRunbErrbMiss) + + + def testCompare(self, srefData, srunData, bVisual=False): + """ + as in Shahad et al, DAS 2010 + + Correct Detections + Partial Detections + Over-Segmented + Under-Segmented + Missed + False Positive + + """ + dicTestByTask = dict() + dicTestByTask['T50']= self.testCPOUM(0.50,srefData,srunData,bVisual) +# dicTestByTask['T75']= self.testCPOUM(0.750,srefData,srunData,bVisual) +# dicTestByTask['T100']= self.testCPOUM(0.50,srefData,srunData,bVisual) + + # dicTestByTask['FirstName']= self.testFirstNameRecord(srefData, srunData,bVisual) +# dicTestByTask['Year']= self.testYear(srefData, srunData,bVisual) 
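+        # each entry maps a threshold name to the (cntOk, cntErr, cntMissed,
+        # ltisRefsRunbErrbMiss) tuple returned by testCPOUM()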
+ + return dicTestByTask + + def createColumnsWithCuts(self,lXCuts,table,tableNode,bTagDoc=False): + """ + create column dom node + """ + + prevCut = None + lXCuts.sort() + for index,cut in enumerate(lXCuts): + # first correspond to the table: no rpw + if prevCut is not None: + colNode= etree.Element("COL") + tableNode.append(colNode) + colNode.set('x',str(prevCut)) + colNode.set('width',"{:.2f}".format(cut - prevCut)) + colNode.set('y',str(table.getY())) + colNode.set('height',str(table.getHeight())) + colNode.set('id',str(index-1)) + prevCut= cut + + #last + cut=table.getX2() + colNode= etree.Element("COL") + tableNode.append(colNode) + colNode.set('x',"{:.2f}".format(prevCut)) + colNode.set('width',"{:.2f}".format(cut - prevCut)) + colNode.set('y',str(table.getY())) + colNode.set('height',str(table.getHeight())) + colNode.set('id',str(index)) + + + def createRef(self,doc): + """ + create a ref file from the xml one + """ + self.ODoc = XMLDSDocument() + self.ODoc.loadFromDom(doc,listPages = range(self.firstPage,self.lastPage+1)) + + + root=etree.Element("DOCUMENT") + refdoc=etree.ElementTree(root) + + + for page in self.ODoc.getPages(): + #imageFilename="..\col\30275\S_Freyung_021_0001.jpg" width="977.52" height="780.0"> + pageNode = etree.Element('PAGE') + pageNode.set("number",page.getAttribute('number')) + pageNode.set("pagekey",os.path.basename(page.getAttribute('imageFilename'))) + pageNode.set("width",page.getAttribute('width')) + pageNode.set("height",page.getAttribute('height')) + + root.append(pageNode) + lTables = page.getAllNamedObjects(XMLDSTABLEClass) + for table in lTables: + dCol={} + tableNode = etree.Element('TABLE') + tableNode.set("x",table.getAttribute('x')) + tableNode.set("y",table.getAttribute('y')) + tableNode.set("width",table.getAttribute('width')) + tableNode.set("height",table.getAttribute('height')) + pageNode.append(tableNode) + for cell in table.getAllNamedObjects(XMLDSTABLECELLClass): + try:dCol[int(cell.getAttribute("col"))].append(cell) + except KeyError:dCol[int(cell.getAttribute("col"))] = [cell] + + lXcuts = [] + for colid in sorted(dCol.keys()): + lXcuts.append(min(list(map(lambda x:x.getX(),dCol[colid])))) + self.createColumnsWithCuts(lXcuts,table,tableNode) + + return refdoc + + def createRefPerPage(self,doc): + """ + create a ref file from the xml one + + for DAS 2018: one ref per graph(page) + """ + self.ODoc = XMLDSDocument() + self.ODoc.loadFromDom(doc,listPages = range(self.firstPage,self.lastPage+1)) + + + + dRows={} + for page in self.ODoc.getPages(): + #imageFilename="..\col\30275\S_Freyung_021_0001.jpg" width="977.52" height="780.0"> + pageNode = etree.Element('PAGE') +# pageNode.set("number",page.getAttribute('number')) + #SINGLER PAGE pnum=1 + pageNode.set("number",'1') + + pageNode.set("imageFilename",page.getAttribute('imageFilename')) + pageNode.set("width",page.getAttribute('width')) + pageNode.set("height",page.getAttribute('height')) + + root=etree.Element("DOCUMENT") + refdoc=etree.ElementTree(root) + root.append(pageNode) + + lTables = page.getAllNamedObjects(XMLDSTABLEClass) + for table in lTables: + tableNode = etree.Element('TABLE') + tableNode.set("x",table.getAttribute('x')) + tableNode.set("y",table.getAttribute('y')) + tableNode.set("width",table.getAttribute('width')) + tableNode.set("height",table.getAttribute('height')) + pageNode.append(tableNode) + for cell in table.getAllNamedObjects(XMLDSTABLECELLClass): + try:dRows[int(cell.getAttribute("row"))].append(cell) + except KeyError:dRows[int(cell.getAttribute("row"))] = 
[cell] + + lYcuts = [] + for rowid in sorted(dRows.keys()): +# print rowid, min(map(lambda x:x.getY(),dRows[rowid])) + lYcuts.append(min(list(map(lambda x:x.getY(),dRows[rowid])))) + self.createRowsWithCuts(lYcuts,table,tableNode) + + + self.outputFileName = os.path.basename(page.getAttribute('imageFilename')[:-3]+'ref') + print(self.outputFileName) + self.writeDom(refdoc, bIndent=True) + + return refdoc + + # print refdoc.serialize('utf-8', True) +# self.testCPOUM(0.5,refdoc.serialize('utf-8', True),refdoc.serialize('utf-8', True)) + +if __name__ == "__main__": + + + rdc = columnDetection() + #prepare for the parsing of the command line + rdc.createCommandLineParser() +# rdc.add_option("--coldir", dest="coldir", action="store", type="string", help="collection folder") + rdc.add_option("--docid", dest="docid", action="store", type="string", help="document id") + rdc.add_option("--dsconv", dest="dsconv", action="store_true", default=False, help="convert page format to DS") + rdc.add_option("--createref", dest="createref", action="store_true", default=False, help="create REF file for component") + + rdc.add_option('-f',"--first", dest="first", action="store", type="int", help="first page to be processed") + rdc.add_option('-l',"--last", dest="last", action="store", type="int", help="last page to be processed") + + #parse the command line + dParams, args = rdc.parseCommandLine() + + #Now we are back to the normal programmatic mode, we set the component parameters + rdc.setParams(dParams) + + doc = rdc.loadDom() + doc = rdc.run(doc) + if doc is not None and rdc.getOutputFileName() != '-': + rdc.writeDom(doc, bIndent=True) + diff --git a/TranskribusDU/tasks/DU_Table/rowDetection.py b/TranskribusDU/tasks/DU_Table/rowDetection.py new file mode 100644 index 0000000..db21930 --- /dev/null +++ b/TranskribusDU/tasks/DU_Table/rowDetection.py @@ -0,0 +1,2159 @@ +# -*- coding: utf-8 -*- +""" + + + Build Rows for a BIESO model + + H. Déjean + + + copyright Xerox 2017, Naver 2017, 2018 + READ project + + + + + Developed for the EU project READ. The READ project has received funding + from the European Union's Horizon 2020 research and innovation programme + under grant agreement No 674943. 
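+
+    TextLines are expected to carry a DU_row attribute holding a BIESO tag:
+    B(egin), I(nside), E(nd), S(ingleton) or O(ther). A new row chunk starts
+    on B or S, while I, E and O continue the current chunk (see createCells).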
+""" + + + + +import sys, os.path +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(sys.argv[0])))) + +import collections +from lxml import etree +from sklearn.metrics import adjusted_rand_score +from sklearn.metrics import homogeneity_score +from sklearn.metrics import completeness_score +import common.Component as Component +from common.trace import traceln +import config.ds_xml_def as ds_xml +from ObjectModel.xmlDSDocumentClass import XMLDSDocument +from ObjectModel.XMLDSTEXTClass import XMLDSTEXTClass +from ObjectModel.XMLDSTABLEClass import XMLDSTABLEClass +from ObjectModel.XMLDSCELLClass import XMLDSTABLECELLClass +from ObjectModel.XMLDSTableRowClass import XMLDSTABLEROWClass +from ObjectModel.XMLDSTableColumnClass import XMLDSTABLECOLUMNClass +from spm.spmTableRow import tableRowMiner +from xml_formats.Page2DS import primaAnalysis +from util.partitionEvaluation import evalPartitions, jaccard, iuo +from util.geoTools import sPoints2tuplePoints +from shapely.geometry import Polygon +from shapely import affinity +from shapely.ops import cascaded_union + + + +class RowDetection(Component.Component): + """ + row detection + @precondition: column detection done, BIES tagging done for text elements + + 11/9/2018: last idea: suppose the cell segmentation good enough: group cells which are unambiguous + with the cell in the (none empty) next column . + 12/11/2018: already done in mergehorinzontalCells !! + 12/11/2018: assume perfect cells: build simple: take next lright as same row + then look for elements belonging to several rows + """ + usage = "" + version = "v.1.1" + description = "description: rowDetection from BIO textlines" + + #--- INIT ------------------------------------------------------------------------------------------------------------- + def __init__(self): + """ + Always call first the Component constructor. 
+ """ + Component.Component.__init__(self, "RowDetection", self.usage, self.version, self.description) + + self.colname = None + self.docid= None + + self.do2DS= False + + self.THHighSupport = 0.20 + self.bYCut = False + self.bCellOnly = False + # for --test + self.bCreateRef = False + self.bCreateRefCluster = False + + self.BTAG= 'B' + self.STAG = 'S' + self.bNoTable = False + self.bEvalCluster=False + self.evalData = None + + def setParams(self, dParams): + """ + Always call first the Component setParams + Here, we set our internal attribute according to a possibly specified value (otherwise it stays at its default value) + """ + Component.Component.setParams(self, dParams) +# if dParams.has_key("coldir"): +# self.colname = dParams["coldir"] + if "docid" in dParams: + self.docid = dParams["docid"] + if "dsconv" in dParams: + self.do2DS = dParams["dsconv"] + + if "createref" in dParams: + self.bCreateRef = dParams["createref"] + + + if "bNoColumn" in dParams: + self.bNoTable = dParams["bNoColumn"] + + if "createrefCluster" in dParams: + self.bCreateRefCluster = dParams["createrefCluster"] + + if "evalCluster" in dParams: + self.bEvalCluster = dParams["evalCluster"] + + if "thhighsupport" in dParams: + self.THHighSupport = dParams["thhighsupport"] * 0.01 + + if 'BTAG' in dParams: self.BTAG = dParams["BTAG"] + if 'STAG' in dParams: self.STAG = dParams["STAG"] + + if 'YCut' in dParams: self.bYCut = dParams["YCut"] + if 'bCellOnly' in dParams: self.bCellOnly = dParams["bCellOnly"] + + def createCells(self, table): + """ + create new cells using BIESO tags + @input: tableObeject with old cells + @return: tableObject with BIES cells + @precondition: requires columns + + if DU_col = M : ignore + + """ +# print ('nbcells:',len(table.getAllNamedObjects(XMLDSTABLECELLClass))) + table._lObjects = [] + lSkipped =[] + for col in table.getColumns(): +# print (col) + + lNewCells=[] + # keep original positions + try:col.resizeMe(XMLDSTABLECELLClass) + except: pass + # in order to ignore existing cells from GT: collect all objects from cells + lObjects = [txt for cell in col.getCells() for txt in cell.getObjects() ] + lObjects.sort(key=lambda x:x.getY()) + + curChunk=[] + lChunks = [] + for txt in lObjects: + # do no yse it for the moment + if txt.getAttribute("DU_col") == 'Mx': + lSkipped.append(txt) + elif txt.getAttribute("DU_row") == self.STAG: + if curChunk != []: + lChunks.append(curChunk) + curChunk=[] + lChunks.append([txt]) + elif txt.getAttribute("DU_row") in ['I', 'E']: + curChunk.append(txt) + elif txt.getAttribute("DU_row") == self.BTAG: + if curChunk != []: + lChunks.append(curChunk) + curChunk=[txt] + elif txt.getAttribute("DU_row") == 'O': + ## add Other as well??? 
no + curChunk.append(txt) +# pass + + if curChunk != []: + lChunks.append(curChunk) + + if lChunks != []: + # create new cells +# table.delCell(cell) + irow= txt.getParent().getIndex()[0] + for i,c in enumerate(lChunks): + #create a new cell per chunk and replace 'cell' + newCell = XMLDSTABLECELLClass() + newCell.setPage(txt.getParent().getPage()) + newCell.setParent(table) + newCell.setName(ds_xml.sCELL) +# newCell.setIndex(irow+i,txt.getParent().getIndex()[1]) + newCell.setIndex(i,txt.getParent().getIndex()[1]) + newCell.setObjectsList(c) +# newCell.addAttribute('type','new') + newCell.resizeMe(XMLDSTEXTClass) + newCell.tagMe2() + for o in newCell.getObjects(): + o.setParent(newCell) + o.tagMe() +# contour = self.createContourFromListOfElements(newCell.getObjects()) +# if contour is not None: +# # newCell.addAttribute('points',','.join("%s,%s"%(x[0],x[1]) for x in contour.lXY)) +# newCell.addAttribute('points',','.join("%s,%s"%(x[0],x[1]) for x in contour)) +# newCell.tagMe2() + +# table.addCell(newCell) + lNewCells.append(newCell) +# if txt.getParent().getNode().getparent() is not None: txt.getParent().getNode().getparent().remove(txt.getParent().getNode()) +# del(txt.getParent()) + #delete all cells + + for cell in col.getCells(): +# print (cell) + try: + if cell.getNode().getparent() is not None: cell.getNode().getparent().remove(cell.getNode()) + except: pass + + [table.delCell(cell) for cell in col.getCells() ] +# print ('\t nbcells 2:',len(table.getAllNamedObjects(XMLDSTABLECELLClass))) + col._lcells= [] + col._lObjects=[] +# print (col.getAllNamedObjects(XMLDSTABLECELLClass)) + [table.addCell(c) for c in lNewCells] + [col.addCell(c) for c in lNewCells] +# print ('\t nbcells 3:',len(table.getAllNamedObjects(XMLDSTABLECELLClass))) + +# print ('\tnbcells:',len(table.getAllNamedObjects(XMLDSTABLECELLClass))) + + + + def matchCells(self,table): + """ + use lcs (dtw?) 
for matching + dtw: detect merging situation + for each col: match with next col + + series 1 : col1 set of cells + series 2 : col2 set of cells + distance = Yoverlap + + """ + dBest = {} + #in(self.y2, tb.y2) - max(self.y1, tb.y1) + def distY(c1,c2): + o = min(c1.getY2() , c2.getY2()) - max(c1.getY() , c2.getY()) + if o < 0: + return 1 +# d = (2* (min(c1.getY2() , c2.getY2()) - max(c1.getY() , c2.getY()))) / (c1.getHeight() + c2.getHeight()) +# print(c1,c1.getY(),c1.getY2(), c2,c2.getY(),c2.getY2(),o,d) + return 1 - (1 * (min(c1.getY2() , c2.getY2()) - max(c1.getY() , c2.getY()))) / min(c1.getHeight() , c2.getHeight()) + + laErr=[] + for icol, col in enumerate(table.getColumns()): + lc = col.getCells() + laErr + lc.sort(key=lambda x:x.getY()) + if icol+1 < table.getNbColumns(): + col2 = table.getColumns()[icol+1] + if col2.getCells() != []: + cntOk,cntErr,cntMissed, lFound,lErr,lMissed = evalPartitions(lc,col2.getCells(), .25,distY) + [laErr.append(x) for x in lErr if x not in laErr] + [laErr.remove(x) for x,y in lFound if x in laErr] + # lErr: cell not matched in col1 + # lMissed: cell not matched in col2 + print (col,col2,cntOk,cntErr,cntMissed,lErr) #, lFound,lErr,lMissed) + for x,y in lFound: + dBest[x]=y + else: + [laErr.append(x) for x in lc if x not in laErr] + # create row + #sort keys by x + skeys = sorted(dBest.keys(),key=lambda x:x.getX()) + lcovered=[] + llR=[] + for key in skeys: +# print (key,lcovered) + if key not in lcovered: + lcovered.append(key) + nextC = dBest[key] +# print ("\t",key,nextC,lcovered) + lrow = [key] + while nextC: + lrow.append(nextC) + lcovered.append(nextC) + try: + nextC=dBest[nextC] + except KeyError: + print ('\txx\t',lrow) + llR.append(lrow) + nextC=None + + for lrow in llR: + contour = self.createContourFromListOfElements(lrow) + if contour is not None: + spoints = ','.join("%s,%s"%(x[0],x[1]) for x in contour) + r = XMLDSTABLEROWClass(1) + r.setParent(table) + r.addAttribute('points',spoints) + r.tagMe('VV') + + def assessCuts(self,table,lYCuts): + """ + input: table, ycuts + output: + """ + # features or values ? + try:lYCuts = map(lambda x:x.getValue(),lYCuts) + except:pass + + lCells = table.getCells() + prevCut = table.getY() + irowIndex = 0 + lRows= [] + dCellOverlap = {} + for _,cut in enumerate(lYCuts): + row=[] + if cut - prevCut > 0: + [b1,b2] = prevCut, cut + for c in lCells: + [a1, a2] = c.getY(),c.getY() + c.getHeight() + if min(a2, b2) >= max(a1, b1): + row.append(c) + lRows.append(row) + irowIndex += 1 + prevCut = cut + + ## BIO coherence + + + + def buildLineCandidates(self,table): + """ + return a lits of lines corresponding to top row line candidates + """ + + + def mineTableRowPattern(self,table): + """ + find rows and columns patterns in terms of typographical position // mandatory cells,... + input: a set of rows (table) + action: seq mining of rows + output: pattern + + Mining at table/page level + # of cells per row + # of cells per colmun + # cell with content (j: freq ; i: freq) + + Sequential pattern:(itemset: setofrows; item cells?) + + """ + # which col is mandatory + # text alignment in cells (per col) + for row in table.getRows(): +# self.mineTypography() + a = row.computeSkewing() + + + """ + skewing detection: use synthetic data !! 
+ simply scan row by row with previous row and adjust with coherence + + """ + + + def getSkewingRepresentation(self,lcuts): + """ + input: list of featureObject + output: skewed cut (a,b) + alog: for each feature: get the text nodes baselines and create a skewed line (a,b) + """ + + + + def miningSeparatorShape(self,table,lCuts): +# import numpy as np + from shapely.geometry import MultiLineString + for cut in lCuts: + xordered= list(cut.getNodes()) + print(cut,[x.getX() for x in xordered]) + xordered.sort(key = lambda x:x.getX()) + lSeparators = [ (x.getX(),x.getY()) for x in [xordered[0],xordered[-1]]] + print( lSeparators) + ml = MultiLineString(lSeparators) + print (ml.wkt) +# X = [x[0] for x in lSeparators] +# Y = [x[1] for x in lSeparators] +# print(X,Y) +# a, b = np.polynomial.polynomial.polyfit(X, Y, 1) +# xmin, xmax = table.getX(), table.getX2() +# y1 = a + b * xmin +# y2 = a + b * xmax +# print (y1,y2) +# print ([ (x.getObjects()[0].getBaseline().getY(),x.getObjects()[0].getBaseline().getAngle(),x.getY()) for x in xordered]) + + def processRows(self, table, predefinedCuts=[]): + """ + Apply mining to get Y cuts for rows + + If everything is centered? + Try thnum= [5,10,20,30,40,50] and keep better coherence! + + Then adjust skewing ? using features values: for c in lYcuts: print (c, [x.getY() for x in c.getNodes()]) + + replace columnMining by cell matching from col to col!! + simply best match (max overlap) between two cells NONONO + + perform chk of cells (tagging is now very good!) and use it for column mining (chk + remaining cells) + + """ +# self.matchCells(table) +# return + fMaxCoherence = 0.0 + rowMiner= tableRowMiner() + # % of columns needed + lTHSUP= [0.2,0.3,0.4] +# lTHSUP= [0.2] + bestTHSUP =None + bestthnum= None + bestYcuts = None + for thnum in [10,20,30]: # must be correlated with leading/text height? +# for thnum in [30]: # must be correlated with leading/text height? + +# for thnum in [50]: # must be correlated with leading/text height? + + """ + 07/1/2018: to be replace by HChunks + for each hchunks: % of cuts(beginning) = validate the top line as segmentor + ## hchunk at cell level : if yes select hchunks at textline level as well? + """ + lLYcuts = rowMiner.columnMining(table,thnum,lTHSUP,predefinedCuts) +# print (lLYcuts) + # get skewing represenation +# [ x.setValue(x.getValue()-0) for x in lYcuts ] + for iy,lYcuts in enumerate(lLYcuts): +# print ("%s %s " %(thnum, lTHSUP[iy])) +# lYcuts.sort(key= lambda x:x.getValue()) +# self.miningSeparatorShape(table,lYcuts) +# self.assessCuts(table, lYcuts) +# self.createRowsWithCuts2(table,lYcuts) + table.createRowsWithCuts(lYcuts) + table.reintegrateCellsInColRow() + coherence = self.computeCoherenceScore(table) + if coherence > fMaxCoherence: + fMaxCoherence = coherence + bestYcuts= lYcuts[:] + bestTHSUP = lTHSUP[iy] + bestthnum= thnum +# else: break +# print ('coherence Score for (%s,%s): %f\t%s'%(thnum,lTHSUP[iy],coherence,bestYcuts)) + if bestYcuts is not None: + ### create the separation with the hullcontour : row as polygon!! + ## if no intersection with previous row : OK + ## if intersection +# print (bestYcuts) +# for y in bestYcuts: +# ## get top elements of the cells to build the boundary ?? +# print ('%s %s'%(y.getValue(),[(c.getX(),c.getY()) for c in sorted(y.getNodes(),key=lambda x:x.getX())])) + ## what about elements outside the cut (beforeà) + ## try "skew option and evaluate""!! 
+ ## take max -H + ## take skew + table.createRowsWithCuts(bestYcuts) + table.reintegrateCellsInColRow() + for row in table.getRows(): + row.addAttribute('points',"0,0") + contour = self.createContourFromListOfElements([x for c in row.getCells() for x in c.getObjects()]) + if contour is not None: + spoints = ','.join("%s,%s"%(x[0],x[1]) for x in contour) + row.addAttribute('points',spoints) +# print (len(table.getPage().getAllNamedObjects(XMLDSTABLECELLClass))) + table.buildNDARRAY() +# self.mineTableRowPattern(table) + +# def defineRowTopBoundary(self,row,ycut): +# """ +# define a top row boundary +# """ + + def findBoundaryLinesFromChunks(self,table,lhckh): + """ + create lines from chunks (create with cells) + + take each chunk and create (a,b) with top contour + """ + + from util.Polygon import Polygon as dspp + import numpy as np + + dTop_lSgmt = collections.defaultdict(list) + for chk in lhckh: + sPoints = chk.getAttribute('points') #.replace(',',' ') + spoints = ' '.join("%s,%s"%((x,y)) for x,y in zip(*[iter(sPoints.split(','))]*2)) + it_sXsY = (sPair.split(',') for sPair in spoints.split(' ')) + plgn = dspp((float(sx), float(sy)) for sx, sy in it_sXsY) + try: + lT, lR, lB, lL = plgn.partitionSegmentTopRightBottomLeft() + dTop_lSgmt[chk].extend(lT) + except ValueError: pass + #now make linear regression to draw relevant separators + def getX(lSegment): + lX = list() + for x1,y1,x2,y2 in lSegment: + lX.append(x1) + lX.append(x2) + return lX + + def getY(lSegment): + lY = list() + for x1,y1,x2,y2 in lSegment: + lY.append(y1) + lY.append(y2) + return lY + + + dAB = collections.defaultdict(list) + icmpt=0 + for icol, lSegment in dTop_lSgmt.items(): #sorted(dTop_lSgmt.items()): + print (icol,lSegment) + X = getX(lSegment) + Y = getY(lSegment) + #sum(l,()) + lfNorm = [np.linalg.norm([[x1,y1], [x2,y2]]) for x1,y1,x2,y2 in lSegment] + #duplicate each element + W = [fN for fN in lfNorm for _ in (0,1)] + + # a * x + b + a, b = np.polynomial.polynomial.polyfit(X, Y, 1, w=W) + xmin, xmax = min(X), max(X) + y1 = a + b * (0) + y2 = a + b * table.getX2() + dAB[b].append((a,b)) + rowline = XMLDSTABLEROWClass(icmpt) + rowline.setPage(table.getPage()) + rowline.setParent(table) + icmpt+=1 +# table.addColumn(rowline) # prevx1, prevymin,x1, ymin, x2, ymax, prevx2, prevymax)) + rowline.addAttribute('points',"%s,%s %s,%s"%(0, y1, table.getX2(),y2)) +# rowline.setX(prevxmin) +# rowline.setY(prevy1) +# rowline.setHeight(y2 - prevy1) +# rowline.setWidth(xmax- xmin) + rowline.tagMe('SeparatorRegion') + +# print (a,b) + + +# for b in sorted(dAB.keys()): +# print (b,dAB[b]) + + + + def processRows3(self,table,predefinedCuts=[] ): + """ + build rows: + for a given cell: if One single Y overlapping cell in the next column: integrate it in the row + + """ + from tasks.TwoDChunking import TwoDChunking + hchk = TwoDChunking() + lElts=[] + [lElts.append(x) for col in table.getColumns() for x in col.getCells()] + lhchk = hchk.HorizonalChunk(table.getPage(),lElts=lElts,bStrict=False) + +# lRows = [] +# curRow = [] +# for col in table.getColumns(): +# lcells = col.getCells() + + def processRows2(self,table,predefinedCuts=[]): + """ + Apply mining to get Y cuts for rows + + """ + from tasks.TwoDChunking import TwoDChunking + + hchk = TwoDChunking() + lhchk = hchk.HorizonalChunk(table.getPage(),lElts=table.getCells()) + + # create bounday lines from lhckh +# lYcuts = self.findBoundaryLinesFromChunks(table,lhchk) + +# lYcuts.sort(key= lambda x:x.getValue()) +# self.getSkewingRepresentation(lYcuts) +# 
self.assessCuts(table, lYcuts)
+# self.createRowsWithCuts2(table,lYcuts)
+# table.createRowsWithCuts(lYcuts)
+# table.reintegrateCellsInColRow()
+#
+# table.buildNDARRAY()
+
+ def checkInputFormat(self,lPages):
+ """
+ delete regions : copy region elements to the page object
+ unlink subnodes
+ """
+ for page in lPages:
+ lTables = page.getAllNamedObjects(XMLDSTABLEClass)
+ for table in lTables:
+ lRegions = table.getAllNamedObjects("CELL")
+ lElts=[]
+ [lElts.extend(x.getObjects()) for x in lRegions]
+ [table.addObject(x,bDom=True) for x in lElts]
+ [table.removeObject(x,bDom=True) for x in lRegions]
+
+ def processYCuts(self,ODoc):
+ from util.XYcut import mergeSegments
+
+ self.checkInputFormat(ODoc.getPages())
+ for page in ODoc.getPages():
+ traceln("page: %d" % page.getNumber())
+ lTables = page.getAllNamedObjects(XMLDSTABLEClass)
+ for table in lTables:
+ print ('nb Y: %s'% len(set([round(x.getY()) for x in page.getAllNamedObjects(XMLDSTEXTClass)])),len(page.getAllNamedObjects(XMLDSTEXTClass)))
+# lCuts, _, _ = mergeSegments([(x.getY(),x.getY() + x.getHeight(),x) for x in page.getAllNamedObjects(XMLDSTEXTClass)],0)
+# for i, (y,_,cut) in enumerate(lCuts):
+# ll =list(cut)
+# ll.sort(key=lambda x:x.getY())
+# #add column
+# myRow= XMLDSTABLEROWClass(i)
+# myRow.setPage(page)
+# myRow.setParent(table)
+# table.addObject(myRow)
+# myRow.setY(y)
+# myRow.setX(table.getX())
+# myRow.setWidth(table.getWidth())
+# if i +1 < len(lCuts):
+# myRow.setHeight(lCuts[i+1][0]-y)
+# else: # use table
+# myRow.setHeight(table.getY2()-y)
+# table.addRow(myRow)
+# print (myRow)
+# myRow.tagMe(ds_xml.sROW)
+
+
+ def mergeHorizontalCells(self,table):
+ """
+ merge cell a to b in next col iff b overlaps horizontally with a (using right border from points)
+ input: a table, with candidate cells
+ output: cluster of cells as row candidates
+
+
+ simply ignore cells which overlap several cells in the next column
+ then: extend row candidates if needed
+
+
+ if no column known: simply take the first cell in lright if cells in lright do not X overlap (the first nearest w/o issue)
+ """
+ # first create an index for horizontal neighbours
+ lNBNeighboursNextCol=collections.defaultdict(list)
+ lNBNeighboursPrevCol=collections.defaultdict(list)
+ for cell in table.getCells():
+ # get next col
+ icol = cell.getIndex()[1]
+ if icol < table.getNbColumns()-1:
+ # sorted() returns a new, sorted list (list.sort would mutate the column's own list)
+ nextColCells = sorted(table.getColumns()[icol+1].getCells(),key=lambda x:x.getY())
+ lHOverlap = [c for c in nextColCells if cell.signedRatioOverlapY(c)> 1]
+ # if no overlap: take icol + 2
+ lNBNeighboursNextCol[cell].extend(lHOverlap)
+ if icol > 1:
+ prevColCells = sorted(table.getColumns()[icol-1].getCells(),key=lambda x:x.getY())
+ lHOverlap = [c for c in prevColCells if cell.signedRatioOverlapY(c)> 1]
+ # if no overlap: take icol - 2
+ lNBNeighboursPrevCol[cell].extend(lHOverlap)
+
+
+ lcovered=[]
+ for icol,col in enumerate(table.getColumns()):
+ sortedC = sorted(col.getCells(),key=lambda x:x.getY())
+ for cell in sortedC:
+ if len(lNBNeighboursNextCol[cell]) < 2 and len(lNBNeighboursPrevCol[cell]) < 2:
+ if cell not in lcovered:
+ print(type(cell.getContent()))
+ print ('START :', icol,cell, cell.getContent(),cell.getY(),cell.getY2())
+ lcovered.append(cell)
+ lcurRow = [cell]
+ iicol=icol
+ curCell = cell
+ while iicol < table.getNbColumns()-1:
+ nextColCells = sorted(table.getColumns()[iicol+1].getCells(),key=lambda x:x.getY())
+ for c in nextColCells:
+ if len(lNBNeighboursNextCol[c]) < 2 and len(lNBNeighboursPrevCol[c]) < 2:
+ if curCell.signedRatioOverlapY(c) > 0.25 * curCell.getHeight():
+ lcovered.append(c)
+ lcurRow.append(c)
+ print (curCell, curCell.getY(),curCell.getHeight(),c, curCell.signedRatioOverlapY(c),c.getY(), c.getHeight(),list(map(lambda x:x.getContent(),lcurRow)))
+ curCell = c
+ iicol +=1
+ print ("FINAL", list(map(lambda x:(x,x.getContent()),lcurRow)) )
+ print ("\t", list(map(lambda x:x.getIndex(),lcurRow)) )
+ if len(lcurRow)>1:
+ # create a contour for visualization
+ # order by col: get top and bottom polylines for them
+ contour = self.createContourFromListOfElements(lcurRow)
+ if contour is not None:
+ spoints = ','.join("%s,%s"%(x[0],x[1]) for x in contour)
+ r = XMLDSTABLEROWClass(1)
+ r.setParent(table)
+ r.addAttribute('points',spoints)
+ r.tagMe('HH')
+
+
+# def mergeHorizontalTextLines(self,table):
+# """
+# merge text lines which are aligned
+# input: a table, with candidate textlines
+# output: cluster of textlines as row candidates
+#
+# """
+# from shapely.geometry import Polygon as pp
+# from rtree import index
+#
+# cellidx = index.Index()
+# lTexts = []
+# lPText=[]
+# lReverseIndex = {}
+# # Populate R-tree index with bounds of grid cells
+# it=0
+# for cell in table.getCells():
+# for text in cell.getObjects():
+# tt = pp( [(text.getX(),text.getY()),(text.getX2(),text.getY()),(text.getX2(),text.getY2()), ((text.getX(),text.getY2()))] )
+# lTexts.append(text)
+# lPText.append(tt)
+# cellidx.insert(it, tt.bounds)
+# it += 1
+# lReverseIndex[tt.bounds] = text
+#
+# lcovered=[]
+# lfulleval= []
+# for text in lTexts:
+# if text not in lcovered:
+# # print ('START :', text, text.getContent())
+# lcovered.append(text)
+# lcurRow = [text]
+# curText= text
+# while curText is not None:
+# # print (curText, lcurRow)
+# # sPoints = text.getAttribute('points')
+# sPoints = curText.getAttribute('blpoints')
+# # print (sPoints)
+# # modify for creating a line to the right
+# # take the rightmost X
+# lastx,lasty = list([(float(x),float(y)) for x,y in zip(*[iter(sPoints.split(','))]*2)])[-1]
+# # polytext = pp([(float(x),float(y)) for x,y in zip(*[iter(sPoints.split(','))]*2)])
+# polytext = pp([(lastx,lasty-10),(lastx+1000,lasty-10),(lastx+1000,lasty),(lastx,lasty)])
+# # print([(lastx,lasty-10),(lastx+1000,lasty-10),(lastx+1000,lasty),(lastx,lasty)])
+# ltover = [lPText[pos] for pos in cellidx.intersection(polytext.bounds)]
+# ltover.sort(key=lambda x:x.centroid.coords[0])
+# lnextStep=[]
+# # print ('\tnext\t',list(map(lambda x:lReverseIndex[x.bounds].getContent(),ltover)))
+#
+# for t1 in ltover:
+# # here conditions: vertical projection and Y overlap ; not area!
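+# # illustration of the probe below: polytext is a thin rectangle (height 10,
+# # width 1000) cast to the right of the last baseline point (lastx,lasty); the
+# # rtree index returns the texts whose bounds intersect it, and a text qualifies
+# # as the continuation of the line only if the intersection area exceeds 0.1
+# # and it is not covered yet.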
+# if polytext.intersection(t1).area > 0.1: #t1.area*0.5: +# if t1 not in lnextStep and lReverseIndex[t1.bounds] not in lcovered: +# lnextStep.append(t1) +# if lnextStep != []: +# lnextStep.sort(key=lambda x:x.centroid.coords[0]) +# # print ('\t',list(map(lambda x:(lReverseIndex[x.bounds].getX(),lReverseIndex[x.bounds].getContent()),lnextStep))) +# nextt = lnextStep[0] +# lcurRow.append(lReverseIndex[nextt.bounds]) +# lcovered.append(lReverseIndex[nextt.bounds]) +# curText = lReverseIndex[nextt.bounds] +# else:curText = None +# +# # print ("FINAL", list(map(lambda x:(x,x.getContent()),lcurRow)) ) +# # print ("FINAL", list(map(lambda x:(x,x.getParent()),lcurRow)) ) +# lfulleval.append(self.comptureClusterHomogeneity(lcurRow,0)) +# +# if len(lcurRow)>1: +# # create a contour for visualization +# # order by col: get top and bottom polylines for them +# contour = self.createContourFromListOfElements(lcurRow) +# spoints = ','.join("%s,%s"%(x[0],x[1]) for x in contour) +# r = XMLDSTABLEROWClass(1) +# r.setParent(table) +# r.addAttribute('points',spoints) +# r.tagMe('VV') +# r.tagMe() +# +# print (sum(lfulleval)/len(lfulleval)) + + + def mergeHorVerTextLines(self,table): + """ + build HV lines + """ + from util import TwoDNeighbourhood as TwoDRel + lTexts = [] + if self.bNoTable: + lTexts = table.getAllNamedObjects(XMLDSTEXTClass) + else: + for cell in table.getCells(): + # bug to be fixed!! + if cell.getRowSpan() == 1 and cell.getColSpan() == 1: + lTexts.extend(set(cell.getObjects())) + + for e in lTexts: + e.lright=[] + e.lleft=[] + e.ltop=[] + e.lbottom=[] + lVEdge = TwoDRel.findVerticalNeighborEdges(lTexts) + for a,b in lVEdge: + a.lbottom.append( b ) + b.ltop.append(a) + for elt in lTexts: + # dirty! + elt.setHeight(max(5,elt.getHeight()-3)) + elt.setWidth(max(5,elt.getWidth()-3)) + TwoDRel.rotateMinus90degOLD(elt) + lHEdge = TwoDRel.findVerticalNeighborEdges(lTexts) + for elt in lTexts: +# elt.tagMe() + TwoDRel.rotatePlus90degOLD(elt) +# return + for a,b in lHEdge: + a.lright.append( b ) + b.lleft.append(a) +# ss + for elt in lTexts: + elt.lleft.sort(key = lambda x:x.getX(),reverse=True) +# elt.lright.sort(key = lambda x:x.getX()) + if len(elt.lright) > 1: + elt.lright = [] + elt.lright.sort(key = lambda x:elt.signedRatioOverlapY(x),reverse=True) +# print (elt, elt.getY(), elt.lright) + elt.ltop.sort(key = lambda x:x.getY()) + if len(elt.lbottom) >1: + elt.lbottom = [] + elt.lbottom.sort(key = lambda x:elt.signedRatioOverlapX(x),reverse=True) + + + + # Horizontal + lTexts.sort(key = lambda x:x.getX()) + lcovered=[] + lfulleval = [] + for text in lTexts: + if text not in lcovered: +# print ('START :', text, text.getContent()) + lcovered.append(text) + lcurRow = [text] + curText= text + while curText is not None: + try: + nextT = curText.lright[0] +# print ('\t',[(x,curText.signedRatioOverlapY(x)) for x in curText.lright]) + if nextT not in lcovered: + lcurRow.append(nextT) + lcovered.append(nextT) + curText = nextT + except IndexError:curText = None + +# print ("FINAL", list(map(lambda x:(x,x.getContent()),lcurRow)) ) +# lfulleval.append(self.comptureClusterHomogeneity(lcurRow,0)) + if len(lcurRow) > 1: + # create a contour for visualization + # order by col: get top and bottom polylines for them + contour = self.createContourFromListOfElements(lcurRow) + if contour is not None: + spoints = ','.join("%s,%s"%(x[0],x[1]) for x in contour) + r = XMLDSTABLEROWClass(1) + r.setParent(table) + r.addAttribute('points',spoints) + r.tagMe('HH') +# print (sum(lfulleval)/len(lfulleval)) + + + # Vertical 
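+ # note: the vertical pass below, like the horizontal pass above, is an instance
+ # of "follow the unique neighbour" chaining; a generic illustrative sketch (not
+ # used by the code):
+ #     def chainElements(lElts, next_of):
+ #         """next_of(e) returns the unique neighbour of e, or None"""
+ #         lChains, lSeen = [], set()
+ #         for e in lElts:
+ #             if e in lSeen: continue
+ #             lCur = [e]; lSeen.add(e); cur = e
+ #             while True:
+ #                 nxt = next_of(cur)
+ #                 if nxt is None or nxt in lSeen: break
+ #                 lCur.append(nxt); lSeen.add(nxt); cur = nxt
+ #             lChains.append(lCur)
+ #         return lChains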
+ lTexts.sort(key = lambda x:x.getY()) + lcovered=[] + lfulleval = [] + for text in lTexts: + if text not in lcovered: +# print ('START :', text, text.getContent()) + lcovered.append(text) + lcurCol = [text] + curText= text + while curText is not None: + try: + nextT = curText.lbottom[0] +# print ('\t',[(x,curText.signedRatioOverlapY(x)) for x in curText.lright]) + if nextT not in lcovered and len(nextT.lbottom) == 1: + lcurCol.append(nextT) + lcovered.append(nextT) + curText = nextT + except IndexError:curText = None + +# print ("FINAL", list(map(lambda x:(x,x.getContent()),lcurCol)) ) +# lfulleval.append(self.comptureClusterHomogeneity(lcurCol,1)) + if len(lcurCol)>1: + # create a contour for visualization + # order by col: get top and bottom polylines for them + contour = self.createContourFromListOfElements(lcurCol) + if contour is not None: + spoints = ','.join("%s,%s"%(x[0],x[1]) for x in contour) + r = XMLDSTABLEROWClass(1) + r.setParent(table) + r.addAttribute('points',spoints) +# r.setDimensions(...) + r.tagMe('VV') +# print (sum(lfulleval)/len(lfulleval)) + + + def mergeHorVerCells(self,table): + """ + build HV chunks cells + """ + from util import TwoDNeighbourhood as TwoDRel + lTexts = [] + for cell in table.getCells(): + # bug to be fixed!! + if cell.getRowSpan() == 1 and cell.getColSpan() == 1: +# lTexts.extend(set(cell.getObjects())) + lTexts.append(cell) + + for e in lTexts: + e.lright=[] + e.lleft=[] + e.ltop=[] + e.lbottom=[] + lVEdge = TwoDRel.findVerticalNeighborEdges(lTexts) + for a,b in lVEdge: + a.lbottom.append( b ) + b.ltop.append(a) + for elt in lTexts: + # dirty! + elt.setHeight(max(5,elt.getHeight()-3)) + elt.setWidth(max(5,elt.getWidth()-3)) + TwoDRel.rotateMinus90degOLD(elt) + lHEdge = TwoDRel.findVerticalNeighborEdges(lTexts) + for elt in lTexts: +# elt.tagMe() + TwoDRel.rotatePlus90degOLD(elt) +# return + for a,b in lHEdge: + a.lright.append( b ) + b.lleft.append(a) +# ss + for elt in lTexts: + elt.lleft.sort(key = lambda x:x.getX(),reverse=True) +# elt.lright.sort(key = lambda x:x.getX()) + elt.lright.sort(key = lambda x:elt.signedRatioOverlapY(x),reverse=True) + if len(elt.lright) >1: + elt.lright = [] +# print (elt, elt.getY(), elt.lright) + elt.ltop.sort(key = lambda x:x.getY()) + elt.lbottom.sort(key = lambda x:elt.signedRatioOverlapX(x),reverse=True) + + + # Horizontal + lTexts.sort(key = lambda x:x.getX()) + lcovered=[] + lfulleval = [] + for text in lTexts: + if text not in lcovered: +# print ('START :', text, text.getContent()) + lcovered.append(text) + lcurRow = [text] + curText= text + while curText is not None: + try: + nextT = curText.lright[0] +# print ('\t',[(x,curText.signedRatioOverlapY(x)) for x in curText.lright]) + if nextT not in lcovered: + lcurRow.append(nextT) + lcovered.append(nextT) + curText = nextT + except IndexError:curText = None + + print ("FINAL", list(map(lambda x:(x,x.getContent()),lcurRow)) ) +# lfulleval.append(self.comptureClusterHomogeneity(lcurRow,0)) + if len(lcurRow) > 1: + # create a contour for visualization + # order by col: get top and bottom polylines for them + contour = self.createContourFromListOfElements(lcurRow) + if contour is not None: + spoints = ','.join("%s,%s"%(x[0],x[1]) for x in contour) + r = XMLDSTABLEROWClass(1) + r.setParent(table) + r.addAttribute('points',spoints) + r.tagMe('HH') +# print (sum(lfulleval)/len(lfulleval)) + + +# # Vertical +# lTexts.sort(key = lambda x:x.getY()) +# lcovered=[] +# lfulleval = [] +# for text in lTexts: +# if text not in lcovered: +# # print ('START :', text, 
text.getContent()) +# lcovered.append(text) +# lcurCol = [text] +# curText= text +# while curText is not None: +# try: +# nextT = curText.lbottom[0] +# # print ('\t',[(x,curText.signedRatioOverlapY(x)) for x in curText.lright]) +# if nextT not in lcovered: +# lcurCol.append(nextT) +# lcovered.append(nextT) +# curText = nextT +# except IndexError:curText = None +# +# # print ("FINAL", list(map(lambda x:(x,x.getContent()),lcurRow)) ) +# lfulleval.append(self.comptureClusterHomogeneity(lcurCol,1)) +# if len(lcurCol)>1: +# # create a contour for visualization +# # order by col: get top and bottom polylines for them +# contour = self.createContourFromListOfElements(lcurCol) +# if contour is not None: +# spoints = ','.join("%s,%s"%(x[0],x[1]) for x in contour) +# r = XMLDSTABLEROWClass(1) +# r.setParent(table) +# r.addAttribute('points',spoints) +# r.tagMe('VV') +# print (sum(lfulleval)/len(lfulleval)) + + def createContourFromListOfElements(self, lElts): + """ + create a polyline from a list of elements + input : list of elements + output: Polygon object + """ + from shapely.geometry import Polygon as pp + from shapely.ops import cascaded_union + lP = [] + for elt in lElts: + + sPoints = elt.getAttribute('points') + if sPoints is None: + lP.append(pp([(elt.getX(),elt.getY()),(elt.getX(),elt.getY2()), (elt.getX2(),elt.getY2()),(elt.getX2(),elt.getY())] )) + else: + lP.append(pp([(float(x),float(y)) for x,y in zip(*[iter(sPoints.split(','))]*2)])) + try:ss = cascaded_union(lP) + except ValueError: +# print(lElts,lP) + return None + if not ss.is_empty: + return list(ss.convex_hull.exterior.coords) + else: return None + + + def comptureClusterHomogeneity(self,c,dir): + """ + % of elements belonging to the same structre + dir: 0 : row, 1 column + """ + + ldict = collections.defaultdict(list) + [ ldict[elt.getParent().getIndex()[dir]].append(elt) for elt in c] + lstat = ([(k,len(ldict[k])) for k in ldict]) + total = sum([x[1] for x in lstat]) + leval = (max(([len(ldict[x])/total for x in ldict]))) + return leval + + def findRowsInDoc(self,ODoc): + """ + find rows for each table in document + input: a document + output: a document where tables have rows + """ + from tasks.TwoDChunking import TwoDChunking + + self.lPages = ODoc.getPages() +# hchk = TwoDChunking() + # not always? 
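+# illustrative usage sketch (mirroring run() further below): findRowsInDoc expects
+# an XMLDSDocument already loaded from the DS xml, e.g.:
+#     odoc = XMLDSDocument()
+#     odoc.loadFromDom(doc, listPages=range(self.firstPage, self.lastPage+1))
+#     self.findRowsInDoc(odoc)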
+# self.mergeLineAndCells(self.lPages) + + for page in self.lPages: + traceln("page: %d" % page.getNumber()) +# print (len(page.getAllNamedObjects(XMLDSTABLECELLClass))) + lTables = page.getAllNamedObjects(XMLDSTABLEClass) + for table in lTables: + # col as polygon + self.getPolylinesForRowsColumns(table) +# self.getPolylinesForRows(table) +# rowscuts = list(map(lambda r:r.getY(),table.getRows())) + rowscuts=[] +# traceln ('initial cuts:',rowscuts) + self.createCells(table) +# lhchk = hchk.HorizonalChunk(page,lElts=table.getCells()) +# hchk.VerticalChunk(page,tag=XMLDSTEXTClass) +# self.mergeHorizontalCells(table) + # then merge overlaping then sort Y and index : then insert ambiguous textlines +# self.mergeHorizontalTextLines(table) +# self.mergeHorVerTextLines(table) +# self.processRows3(table) + if self.bCellOnly: + continue +# self.mergeHorizontalCells(table) +# self.mergeHorVerCells(table) + self.processRows(table,rowscuts) +# self.mineTableRowPattern(table) + table.tagMe() + if self.bNoTable: + self.mergeHorVerTextLines(page) + + + +# def extendLines(self,table): +# """ +# Extend textlines up to table width using baseline +# input:table +# output: table with extended baselines +# """ +# for col in table.getColumns(): +# for cell in col.getCells(): +# for elt in cell.getObjects(): +# if elt.getWidth()> 100: +# #print ([ (x.getObjects()[0].getBaseline().getY(),x.getObjects()[0].getBaseline().getAngle(),x.getY()) for x in xordered]) +# print (elt,elt.getBaseline().getAngle(), elt.getBaseline().getBx(),elt.getBaseline().getPoints()) +# newBl = [(table.getX(),elt.getBaseline().getAngle()* table.getX() + elt.getBaseline().getBx()), +# (table.getX2(),elt.getBaseline().getAngle()* table.getX2() + elt.getBaseline().getBx()) +# ] +# elt.getBaseline().setPoints(newBl) +# myPoints = '%f,%f,%f,%f'%(newBl[0][0],newBl[0][1],newBl[1][0],newBl[1][1]) +# elt.addAttribute('blpoints',myPoints) + + +# sys.exit(0) + + def getPolylinesForRowsColumns(self,table): + """ + input: list of cells (=table) + output: columns defined by polylines (not Bounding box) + """ + import numpy as np + from util.Polygon import Polygon + from shapely.geometry import Polygon as pp +# from shapely.ops import cascaded_union + from rtree import index + + cellidx = index.Index() + lCells = [] + lReverseIndex = {} + # Populate R-tree index with bounds of grid cells + for pos, cell in enumerate(table.getCells()): + # assuming cell is a shapely object + cc = pp( [(cell.getX(),cell.getY()),(cell.getX2(),cell.getY()),(cell.getX2(),cell.getY2()), ((cell.getX(),cell.getY2()))] ) + lCells.append(cc) + cellidx.insert(pos, cc.bounds) + lReverseIndex[cc.bounds] = cell + + + dColSep_lSgmt = collections.defaultdict(list) + dRowSep_lSgmt = collections.defaultdict(list) + for cell in table.getCells(): + row, col, rowSpan, colSpan = [int(cell.getAttribute(sProp)) for sProp \ + in ["row", "col", "rowSpan", "colSpan"] ] + sPoints = cell.getAttribute('points') #.replace(',',' ') +# print (cell,sPoints) + spoints = ' '.join("%s,%s"%((x,y)) for x,y in zip(*[iter(sPoints.split(','))]*2)) + it_sXsY = (sPair.split(',') for sPair in spoints.split(' ')) + plgn = Polygon((float(sx), float(sy)) for sx, sy in it_sXsY) +# print (plgn.getBoundingBox(),spoints) + try: + lT, lR, lB, lL = plgn.partitionSegmentTopRightBottomLeft() + #now the top segments contribute to row separator of index: row + dRowSep_lSgmt[row].extend(lT) + dRowSep_lSgmt[row+rowSpan].extend(lB) + dColSep_lSgmt[col].extend(lL) + dColSep_lSgmt[col+colSpan].extend(lR) + except ValueError: pass + 
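+ # worked note (illustration): each cell border segment feeds the separator of its
+ # row/col index; the regression below uses both endpoints of every segment, each
+ # weighted by the segment length, so longer borders dominate the fit, e.g.:
+ #     import numpy as np
+ #     X, Y = [0, 80], [10, 11]        # endpoints of one 80px-long border segment
+ #     W = [80.0, 80.0]                # both endpoints carry the segment norm
+ #     a, b = np.polynomial.polynomial.polyfit(X, Y, 1, w=W)   # y = a + b*x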
+ #now make linear regression to draw relevant separators + def getX(lSegment): + lX = list() + for x1,y1,x2,y2 in lSegment: + lX.append(x1) + lX.append(x2) + return lX + + def getY(lSegment): + lY = list() + for x1,y1,x2,y2 in lSegment: + lY.append(y1) + lY.append(y2) + return lY + + prevx1 , prevx2 , prevymin , prevymax = None,None,None,None #table.getX(),table.getX(),table.getY(),table.getY2() + + + # erase columns: + table.eraseColumns() + icmpt=0 + for icol, lSegment in sorted(dColSep_lSgmt.items()): + X = getX(lSegment) + Y = getY(lSegment) + #sum(l,()) + lfNorm = [np.linalg.norm([[x1,y1], [x2,y2]]) for x1,y1,x2,y2 in lSegment] + #duplicate each element + W = [fN for fN in lfNorm for _ in (0,1)] + + # a * x + b + a, b = np.polynomial.polynomial.polyfit(Y, X, 1, w=W) + + ymin, ymax = min(Y), max(Y) + x1 = a + b * ymin + x2 = a + b * ymax + if prevx1: + col = XMLDSTABLECOLUMNClass() + col.setPage(table.getPage()) + col.setParent(table) + col.setIndex(icmpt) + icmpt+=1 + table.addColumn(col) + col.addAttribute('points',"%s,%s %s,%s,%s,%s %s,%s"%(prevx1, prevymin,x1, ymin, x2, ymax, prevx2, prevymax)) + col.setX(prevx1) + col.setY(prevymin) + col.setHeight(ymax- ymin) + col.setWidth(x2-prevx1) + col.tagMe() +# from shapely.geometry import Polygon as pp + polycol = pp([(prevx1, prevymin),(x1, ymin), (x2, ymax), (prevx2, prevymax)] ) +# print ((prevx1, prevymin),(x1, ymin), (x2, ymax), (prevx2, prevymax)) +# colCells = cascaded_union([cells[pos] for pos in cellidx.intersection(polycol.bounds)]) + colCells = [lCells[pos] for pos in cellidx.intersection(polycol.bounds)] + for cell in colCells: + try: + if polycol.intersection(cell).area > cell.area*0.5: + col.addCell(lReverseIndex[cell.bounds]) + except: + pass + + prevx1 , prevx2 , prevymin , prevymax = x1, x2, ymin, ymax + + + def getPolylinesForRows(self,table): + """ + input: list of candidate cells (=table) + output: "rows" defined by top polylines + """ + import numpy as np + from util.Polygon import Polygon + from shapely.geometry import Polygon as pp +# from shapely.ops import cascaded_union + from rtree import index + + cellidx = index.Index() + lCells = [] + lReverseIndex = {} + # Populate R-tree index with bounds of grid cells + for pos, cell in enumerate(table.getCells()): + # assuming cell is a shapely object + cc = pp( [(cell.getX(),cell.getY()),(cell.getX2(),cell.getY()),(cell.getX2(),cell.getY2()), ((cell.getX(),cell.getY2()))] ) + lCells.append(cc) + cellidx.insert(pos, cc.bounds) + lReverseIndex[cc.bounds] = cell + + + dColSep_lSgmt = collections.defaultdict(list) + dRowSep_lSgmt = collections.defaultdict(list) + for cell in table.getCells(): + row, col, rowSpan, colSpan = [int(cell.getAttribute(sProp)) for sProp \ + in ["row", "col", "rowSpan", "colSpan"] ] + sPoints = cell.getAttribute('points') #.replace(',',' ') +# print (cell,sPoints) + spoints = ' '.join("%s,%s"%((x,y)) for x,y in zip(*[iter(sPoints.split(','))]*2)) + it_sXsY = (sPair.split(',') for sPair in spoints.split(' ')) + plgn = Polygon((float(sx), float(sy)) for sx, sy in it_sXsY) +# print (plgn.getBoundingBox(),spoints) + try: + lT, lR, lB, lL = plgn.partitionSegmentTopRightBottomLeft() + #now the top segments contribute to row separator of index: row + dRowSep_lSgmt[row].extend(lT) + dRowSep_lSgmt[row+rowSpan].extend(lB) + dColSep_lSgmt[col].extend(lL) + dColSep_lSgmt[col+colSpan].extend(lR) + except ValueError: pass + + #now make linear regression to draw relevant separators + def getX(lSegment): + lX = list() + for x1,y1,x2,y2 in lSegment: + lX.append(x1) 
+ lX.append(x2) + return lX + + def getY(lSegment): + lY = list() + for x1,y1,x2,y2 in lSegment: + lY.append(y1) + lY.append(y2) + return lY + + prevxmin , prevxmax , prevy1 , prevy2 = None,None,None,None #table.getX(),table.getX(),table.getY(),table.getY2() + + + # erase columns: + table.eraseColumns() + icmpt=0 + for _, lSegment in sorted(dRowSep_lSgmt.items()): + X = getX(lSegment) + Y = getY(lSegment) + #sum(l,()) + lfNorm = [np.linalg.norm([[x1,y1], [x2,y2]]) for x1,y1,x2,y2 in lSegment] + #duplicate each element + W = [fN for fN in lfNorm for _ in (0,1)] + + # a * x + b + a, b = np.polynomial.polynomial.polyfit(X, Y, 1, w=W) + xmin, xmax = min(X), max(X) + y1 = a + b * xmin + y2 = a + b * xmax + + if prevy1: + col = XMLDSTABLEROWClass(icmpt) + col.setPage(table.getPage()) + col.setParent(table) + icmpt+=1 + table.addColumn(col) # prevx1, prevymin,x1, ymin, x2, ymax, prevx2, prevymax)) + col.addAttribute('points',"%s,%s %s,%s,%s,%s %s,%s"%(prevxmin, prevy1, prevxmax,prevy2, xmax,y2, prevxmax,y1)) + col.setX(prevxmin) + col.setY(prevy1) + col.setHeight(y2 - prevy1) + col.setWidth(xmax- xmin) + col.tagMe() +# from shapely.geometry import Polygon as pp +# polycol = pp([(prevx1, prevymin),(x1, ymin), (x2, ymax), (prevx2, prevymax)] ) +# # print ((prevx1, prevymin),(x1, ymin), (x2, ymax), (prevx2, prevymax)) +# # colCells = cascaded_union([cells[pos] for pos in cellidx.intersection(polycol.bounds)]) +# colCells = [lCells[pos] for pos in cellidx.intersection(polycol.bounds)] +# for cell in colCells: +# if polycol.intersection(cell).area > cell.area*0.5: +# col.addCell(lReverseIndex[cell.bounds]) + + + prevy1 , prevy2 , prevxmin , prevxmax = y1, y2, xmin, xmax + + for cell in table.getCells(): + del cell._lAttributes['points'] + + + def testscale(self,ltexts): + return + for t in ltexts: + if True or t.getAttribute('id')[-4:] == '1721': +# print (t) + # print (etree.tostring(t.getNode())) + shrinked = affinity.scale(t.toPolygon(),3,-0.8) + # print (list(t.toPolygon().exterior.coords), list(shrinked.exterior.coords)) + ss = ",".join(["%s,%s"%(x,y) for x,y in shrinked.exterior.coords]) + # print (ss) + t.getNode().set("points",ss) + # print (etree.tostring(t.getNode())) + + + + + def testshapely(self,Odoc): + for page in Odoc.lPages: + self.testscale(page.getAllNamedObjects(XMLDSTEXTClass)) + traceln("page: %d" % page.getNumber()) +# lTables = page.getAllNamedObjects(XMLDSTABLEClass) +# for table in lTables: +# table.testPopulate() + + def run(self,doc): + """ + load dom and find rows + """ + # conver to DS if needed + if self.bCreateRef: + if self.do2DS: + dsconv = primaAnalysis() + doc = dsconv.convert2DS(doc,self.docid) + + refdoc = self.createRef(doc) + return refdoc + # single ref per page +# refdoc= self.createRefPerPage(doc) +# return None + + if self.bCreateRefCluster: + if self.do2DS: + dsconv = primaAnalysis() + doc = dsconv.convert2DS(doc,self.docid) + +# refdoc = self.createRefCluster(doc) + refdoc = self.createRefPartition(doc) + + return refdoc + + if self.do2DS: + dsconv = primaAnalysis() + self.doc = dsconv.convert2DS(doc,self.docid) + else: + self.doc= doc + self.ODoc = XMLDSDocument() + self.ODoc.loadFromDom(self.doc,listPages = range(self.firstPage,self.lastPage+1)) +# self.testshapely(self.ODoc) +# # self.ODoc.loadFromDom(self.doc,listPages = range(30,31)) + if self.bYCut: + self.processYCuts(self.ODoc) + else: + self.findRowsInDoc(self.ODoc) + return self.doc + + + + def computeCoherenceScore(self,table): + """ + input: table with rows, BIEOS tagged textlines + BIO now ! 
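+ (worked illustration: a cell with a single textline tagged S counts 1; a cell
+ with textlines tagged B,I,I counts 3; tagged B,B,I it counts only 2, the
+ second B being incoherent)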
+ output: coherence score + + coherence score: float + percentage of textlines those BIESO tagged is 'coherent with the row segmentation' + + """ + coherenceScore = 0 + nbTotalTextLines = 0 + for row in table.getRows(): + for cell in row.getCells(): + nbTextLines = len(cell.getObjects()) + nbTotalTextLines += nbTextLines + if nbTextLines == 1 and cell.getObjects()[0].getAttribute("DU_row") == self.STAG: coherenceScore+=1 + else: + for ipos, textline in enumerate(cell.getObjects()): + if ipos == 0: + if textline.getAttribute("DU_row") in [self.BTAG]: coherenceScore += 1 + else: + if textline.getAttribute("DU_row") in ['I']: coherenceScore += 1 +# if ipos == nbTextLines-1: +# if textline.getAttribute("DU_row") in ['E']: coherenceScore += 1 +# if ipos not in [0, nbTextLines-1]: +# if textline.getAttribute("DU_row") in ['I']: coherenceScore += 1 + + if nbTotalTextLines == 0: return 0 + else : return coherenceScore /nbTotalTextLines + + ################ TEST ################## + + + def testRun(self, filename, outFile=None): + """ + evaluate using ABP new table dataset with tablecell + """ + + self.evalData=None + doc = self.loadDom(filename) + doc =self.run(doc) + if self.bEvalCluster: + self._evalData = self.createRunPartition( self.ODoc) +# self.evalData = self.createRefCluster(doc) + else: + self.evalData = self.createRef(doc) + if outFile: self.writeDom(doc) + return etree.tostring(self._evalData,encoding='unicode',pretty_print=True) + + + + + + def testCluster(self, srefData, srunData, bVisual=False): + """ + + + + + + + + + + + + NEED to work at page level !!?? + then average? + """ + cntOk = cntErr = cntMissed = 0 + + RefData = etree.XML(srefData.strip("\n").encode('utf-8')) + RunData = etree.XML(srunData.strip("\n").encode('utf-8')) + + lPages = RefData.xpath('//%s' % ('PAGE[@number]')) + lRefKeys={} + dY = {} + lY={} + dIDMap={} + for page in lPages: + pnum=page.get('number') + key=page.get('pagekey') + dIDMap[key]={} + lY[key]=[] + dY[key]={} + xpath = ".//%s" % ("R") + lrows = page.xpath(xpath) + if len(lrows) > 0: + for i,row in enumerate(lrows): + xpath = ".//@id" + lids = row.xpath(xpath) + for id in lids: + # with spanning an element can belong to several rows? 
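+ # bookkeeping sketch (illustration): lY[key] keeps one reference row label per
+ # textline id (the first row wins when an id spans several rows) and dIDMap[key]
+ # maps each id to its slot in lY[key], so the run labels lX can be aligned below.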
+ if id not in dY[key]: + dY[key][id]=i + lY[key].append(i) + dIDMap[key][id]=len(lY[key])-1 + try:lRefKeys[key].append((pnum,key,lids)) + except KeyError:lRefKeys[key] = [(pnum,key,lids)] + rand_score = completeness = homogen_score = 0 + if RunData is not None: + lpages = RunData.xpath('//%s' % ('PAGE[@number]')) + for page in lpages: + pnum=page.get('number') + key=page.get('pagekey') + if key in lRefKeys: + lX=[-1 for i in range(len(dIDMap[key]))] + xpath = ".//%s" % ("ROW") + lrows = page.xpath(xpath) + if len(lrows) > 0: + for i,row in enumerate(lrows): + xpath = ".//@id" + lids = row.xpath(xpath) + for id in lids: + lX[ dIDMap[key][id]] = i + + #adjusted_rand_score(ref,run) + rand_score += adjusted_rand_score(lY[key],lX) + completeness += completeness_score(lY[key], lX) + homogen_score += homogeneity_score(lY[key], lX) + + ltisRefsRunbErrbMiss= list() + return (rand_score/len(lRefKeys), completeness/len(lRefKeys), homogen_score/len(lRefKeys),ltisRefsRunbErrbMiss) + + + def testGeometry(self, th, srefData, srunData, bVisual=False): + """ + compare geometrical zones (dtw + iou) + :param + + returns tuple (cntOk, cntErr, cntMissed,ltisRefsRunbErrbMiss + + """ + + cntOk = cntErr = cntMissed = 0 + ltisRefsRunbErrbMiss = list() + RefData = etree.XML(srefData.strip("\n").encode('utf-8')) + RunData = etree.XML(srunData.strip("\n").encode('utf-8')) + + lPages = RefData.xpath('//%s' % ('PAGE[@number]')) + + for ip,page in enumerate(lPages): + lY=[] + key=page.get('pagekey') + xpath = ".//%s" % ("ROW") + lrows = page.xpath(xpath) + if len(lrows) > 0: + for col in lrows: + xpath = ".//@points" + lpoints = col.xpath(xpath) + colgeo = cascaded_union([ Polygon(sPoints2tuplePoints(p)) for p in lpoints]) + if lpoints != []: + lY.append(colgeo) + + if RunData is not None: + lpages = RunData.xpath('//%s' % ('PAGE[@pagekey="%s"]' % key)) + lX=[] + if lpages != []: + for page in lpages[0]: + xpath = ".//%s" % ("ROW") + lrows = page.xpath(xpath) + if len(lrows) > 0: + for col in lrows: + xpath = ".//@points" + lpoints = col.xpath(xpath) + if lpoints != []: + lX.append( Polygon(sPoints2tuplePoints(lpoints[0]))) + lX = list(filter(lambda x:x.is_valid,lX)) + ok , err , missed,lfound,lerr,lmissed = evalPartitions(lX, lY, th,iuo) + cntOk += ok + cntErr += err + cntMissed +=missed + [ltisRefsRunbErrbMiss.append((ip, y1.bounds, x1.bounds,False, False)) for (x1,y1) in lfound] + [ltisRefsRunbErrbMiss.append((ip, y1.bounds, None,False, True)) for y1 in lmissed] + [ltisRefsRunbErrbMiss.append((ip, None, x1.bounds,True, False)) for x1 in lerr] + +# ltisRefsRunbErrbMiss.append(( lfound, ip, ok,err, missed)) +# print (key, cntOk , cntErr , cntMissed) + return (cntOk , cntErr , cntMissed,ltisRefsRunbErrbMiss) + + def testCluster2(self, th, srefData, srunData, bVisual=False): + """ + + +
+ + + + + + + + + NEED to work at page level !!?? + then average? + """ + RefData = etree.XML(srefData.strip("\n").encode('utf-8')) + RunData = etree.XML(srunData.strip("\n").encode('utf-8')) + + lPages = RefData.xpath('//%s' % ('PAGE[@number]')) + for page in lPages: + lY=[] + key=page.get('pagekey') + xpath = ".//%s" % ("ROW") + lrows = page.xpath(xpath) + if len(lrows) > 0: + for row in lrows: + xpath = ".//@id" + lid = row.xpath(xpath) + if lid != []: + lY.append(lid) +# print (row.xpath(xpath)) + + if RunData is not None: + lpages = RunData.xpath('//%s' % ('PAGE[@pagekey="%s"]' % key)) + lX=[] + for page in lpages[:1]: + xpath = ".//%s" % ("ROW") + lrows = page.xpath(xpath) + if len(lrows) > 0: + for row in lrows: + xpath = ".//@id" + lid = row.xpath(xpath) + if lid != []: + lX.append( lid) + cntOk , cntErr , cntMissed,lf,le,lm = evalPartitions(lX, lY, th,jaccard) +# print ( cntOk , cntErr , cntMissed) + ltisRefsRunbErrbMiss= list() + return (cntOk , cntErr , cntMissed,ltisRefsRunbErrbMiss) + + + def overlapX(self,zone): + + + [a1,a2] = self.getX(),self.getX()+ self.getWidth() + [b1,b2] = zone.getX(),zone.getX()+ zone.getWidth() + return min(a2, b2) >= max(a1, b1) + + def overlapY(self,zone): + [a1,a2] = self.getY(),self.getY() + self.getHeight() + [b1,b2] = zone.getY(),zone.getY() + zone.getHeight() + return min(a2, b2) >= max(a1, b1) + def signedRatioOverlap(self,z1,z2): + """ + overlap self and zone + return surface of self in zone + """ + [x1,y1,h1,w1] = z1.getX(),z1.getY(),z1.getHeight(),z1.getWidth() + [x2,y2,h2,w2] = z2.getX(),z2.getY(),z2.getHeight(),z2.getWidth() + + fOverlap = 0.0 + + if self.overlapX(z2) and self.overlapY(z2): + [x11,y11,x12,y12] = [x1,y1,x1+w1,y1+h1] + [x21,y21,x22,y22] = [x2,y2,x2+w2,y2+h2] + + s1 = w1 * h1 + + # possible ? + if s1 == 0: s1 = 1.0 + + #intersection + nx1 = max(x11,x21) + nx2 = min(x12,x22) + ny1 = max(y11,y21) + ny2 = min(y12,y22) + h = abs(nx2 - nx1) + w = abs(ny2 - ny1) + + inter = h * w + if inter > 0 : + fOverlap = inter/s1 + else: + # if overX and Y this is not possible ! + fOverlap = 0.0 + + return fOverlap + + def findSignificantOverlap(self,TOverlap,ref,run): + """ + return + """ + pref,rowref= ref + prun, rowrun= run + if pref != prun: return False + + return rowref.ratioOverlap(rowrun) >=TOverlap + + + def testCPOUM(self, TOverlap, srefData, srunData, bVisual=False): + """ + TOverlap: Threshols used for comparing two surfaces + + + Correct Detections: + under and over segmentation? + """ + + cntOk = cntErr = cntMissed = 0 + + RefData = etree.XML(srefData.strip("\n").encode('utf-8')) + RunData = etree.XML(srunData.strip("\n").encode('utf-8')) +# try: +# RunData = libxml2.parseMemory(srunData.strip("\n"), len(srunData.strip("\n"))) +# except: +# RunData = None +# return (cntOk, cntErr, cntMissed) + lRun = [] + if RunData is not None: + lpages = RunData.xpath('//%s' % ('PAGE')) + for page in lpages: + pnum=page.get('number') + #record level! 
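+ # what follows (illustration): the ROW elements of each RUN page are wrapped into
+ # XMLDSTABLEROWClass objects keyed by page number, then greedily matched against
+ # the REF rows through findSignificantOverlap.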
+ lRows = page.xpath(".//%s" % ("ROW")) + lORows = map(lambda x:XMLDSTABLEROWClass(0,x),lRows) + for row in lORows: + row.fromDom(row._domNode) + row.setIndex(row.getAttribute('id')) + lRun.append((pnum,row)) +# print (lRun) + + lRef = [] + lPages = RefData.xpath('//%s' % ('PAGE')) + for page in lPages: + pnum=page.get('number') + lRows = page.xpath(".//%s" % ("ROW")) + lORows = map(lambda x:XMLDSTABLEROWClass(0,x),lRows) + for row in lORows: + row.fromDom(row._domNode) + row.setIndex(row.getAttribute('id')) + lRef.append((pnum,row)) + + + refLen = len(lRef) +# bVisual = True + ltisRefsRunbErrbMiss= list() + lRefCovered = [] + for i in range(0,len(lRun)): + iRef = 0 + bFound = False + bErr , bMiss= False, False + runElt = lRun[i] +# print '\t\t===',runElt + while not bFound and iRef <= refLen - 1: + curRef = lRef[iRef] + if runElt and curRef not in lRefCovered and self.findSignificantOverlap(TOverlap,runElt, curRef): + bFound = True + lRefCovered.append(curRef) + iRef+=1 + if bFound: + if bVisual:print("FOUND:", runElt, ' -- ', lRefCovered[-1]) + cntOk += 1 + else: + curRef='' + cntErr += 1 + bErr = True + if bVisual:print("ERROR:", runElt) + if bFound or bErr: + ltisRefsRunbErrbMiss.append( (int(runElt[0]), curRef, runElt,bErr, bMiss) ) + + for i,curRef in enumerate(lRef): + if curRef not in lRefCovered: + if bVisual:print("MISSED:", curRef) + ltisRefsRunbErrbMiss.append( (int(curRef[0]), curRef, '',False, True) ) + cntMissed+=1 + ltisRefsRunbErrbMiss.sort(key=lambda xyztu:xyztu[0]) + +# print cntOk, cntErr, cntMissed,ltisRefsRunbErrbMiss + return (cntOk, cntErr, cntMissed,ltisRefsRunbErrbMiss) + + + def testCompare(self, srefData, srunData, bVisual=False): + """ + as in Shahad et al, DAS 2010 + + Correct Detections + Partial Detections + Over-Segmented + Under-Segmented + Missed + False Positive + + """ + dicTestByTask = dict() + if self.bEvalCluster: +# dicTestByTask['CLUSTER']= self.testCluster(srefData,srunData,bVisual) + dicTestByTask['CLUSTER100']= self.testCluster2(1.0,srefData,srunData,bVisual) + dicTestByTask['CLUSTER90']= self.testCluster2(0.9,srefData,srunData,bVisual) + dicTestByTask['CLUSTER80']= self.testCluster2(0.8,srefData,srunData,bVisual) +# dicTestByTask['CLUSTER50']= self.testCluster2(0.5,srefData,srunData,bVisual) + + else: + dicTestByTask['T80']= self.testGeometry(0.50,srefData,srunData,bVisual) +# dicTestByTask['T50']= self.testCPOUM(0.50,srefData,srunData,bVisual) + + + return dicTestByTask + + + def createRowsWithCuts2(self,table,lYCuts): + """ + input: lcells, horizontal lcuts + output: list of rows populated with appropriate cells (main overlap) + + Algo: create cell chunks and determine (a,b) for the cut (a.X +b = Y) + does not solve everything ("russian mountains" in weddings) + """ + from tasks.TwoDChunking import TwoDChunking + if lYCuts == []: + return + + #reinit rows + self._lrows = [] + + #build horizontal chunks + hchk = TwoDChunking() + hchk.HorizonalChunk(table.getPage(),tag=XMLDSTABLECELLClass) + + +# #get all texts +# lTexts = [] +# [ lTexts.extend(colcell.getObjects()) for col in table.getColumns() for colcell in col.getObjects()] +# lTexts.sort(lambda x:x.getY()) +# +# #initial Y: table top border +# prevCut = self.getY() +# +# # ycuts: features or float +# try:lYCuts = map(lambda x:x.getValue(),lYCuts) +# except:pass +# +# itext = 0 +# irowIndex = 0 +# lrowcells = [] +# lprevrowcells = [] +# prevRowCoherenceScore = 0 +# for irow,cut in enumerate(lYCuts): +# yrow = prevCut +# y2 = cut +# h = cut - prevCut +# lrowcells =[] +# while 
lTexts[itext].getY() <= cut: +# lrowcells.append(lTexts[itext]) +# itext += 1 +# if lprevrowcells == []: +# pass +# else: +# # a new row: evaluate if this is better to create it or to merge ltext with current row +# # check coherence of new texts +# # assume columns! +# coherence = self.computeCoherenceScoreForRows(lrowcells) +# coherenceMerge = self.computeCoherenceScoreForRows(lrowcells+lprevrowcells) +# if prevRowCoherenceScore + coherence > coherenceMerge: +# cuthere +# else: +# merge +# + + + + def createRowsWithCuts(self,lYCuts,table,tableNode,bTagDoc=False): + """ + REF XML + """ + prevCut = None +# prevCut = table.getY() + + lYCuts.sort() + for index,cut in enumerate(lYCuts): + # first correspond to the table: no rpw + if prevCut is not None: + rowNode= etree.Element("ROW") + if bTagDoc: + tableNode.append(rowNode) + else: + tableNode.append(rowNode) + rowNode.set('y',str(prevCut)) + rowNode.set('height',str(cut - prevCut)) + rowNode.set('x',str(table.getX())) + rowNode.set('width',str(table.getWidth())) + rowNode.set('id',str(index-1)) + + prevCut= cut + #last + cut=table.getY2() + rowNode= etree.Element("ROW") + tableNode.append(rowNode) + rowNode.set('y',str(prevCut)) + rowNode.set('height',str(cut - prevCut)) + rowNode.set('x',str(table.getX())) + rowNode.set('width',str(table.getWidth())) + rowNode.set('id',str(index)) + + + def createRefCluster(self,doc): + """ + Ref: a row = set of textlines + """ + self.ODoc = XMLDSDocument() + self.ODoc.loadFromDom(doc,listPages = range(self.firstPage,self.lastPage+1)) + + + root=etree.Element("DOCUMENT") + refdoc=etree.ElementTree(root) + + + for page in self.ODoc.getPages(): + pageNode = etree.Element('PAGE') + pageNode.set("number",page.getAttribute('number')) + pageNode.set("pagekey",os.path.basename(page.getAttribute('imageFilename'))) + pageNode.set("width",page.getAttribute('width')) + pageNode.set("height",page.getAttribute('height')) + + root.append(pageNode) + lTables = page.getAllNamedObjects(XMLDSTABLEClass) + for table in lTables: + dRows={} + tableNode = etree.Element('TABLE') + tableNode.set("x",table.getAttribute('x')) + tableNode.set("y",table.getAttribute('y')) + tableNode.set("width",table.getAttribute('width')) + tableNode.set("height",table.getAttribute('height')) + pageNode.append(tableNode) + for cell in table.getAllNamedObjects(XMLDSTABLECELLClass): + try:dRows[int(cell.getAttribute("row"))].extend(cell.getObjects()) + except KeyError:dRows[int(cell.getAttribute("row"))] = cell.getObjects() + + for rowid in sorted(dRows.keys()): + rowNode= etree.Element("ROW") + tableNode.append(rowNode) + for elt in dRows[rowid]: + txtNode = etree.Element("TEXT") + txtNode.set('id',elt.getAttribute('id')) + rowNode.append(txtNode) + + return refdoc + + def createRef(self,doc): + """ + create a ref file from the xml one + """ + self.ODoc = XMLDSDocument() + self.ODoc.loadFromDom(doc,listPages = range(self.firstPage,self.lastPage+1)) + + + root=etree.Element("DOCUMENT") + refdoc=etree.ElementTree(root) + + + for page in self.ODoc.getPages(): + #imageFilename="..\col\30275\S_Freyung_021_0001.jpg" width="977.52" height="780.0"> + pageNode = etree.Element('PAGE') + pageNode.set("number",page.getAttribute('number')) + pageNode.set("pagekey",os.path.basename(page.getAttribute('imageFilename'))) + pageNode.set("width",str(page.getAttribute('width'))) + pageNode.set("height",str(page.getAttribute('height'))) + + root.append(pageNode) + lTables = page.getAllNamedObjects(XMLDSTABLEClass) + for table in lTables: + print (table) + dRows={} + 
tableNode = etree.Element('TABLE') + tableNode.set("x",str(table.getAttribute('x'))) + tableNode.set("y",str(table.getAttribute('y'))) + tableNode.set("width",str(table.getAttribute('width'))) + tableNode.set("height",str(table.getAttribute('height'))) + for cell in table.getAllNamedObjects(XMLDSTABLECELLClass): + print (cell) + try:dRows[int(cell.getAttribute("row"))].append(cell) + except KeyError:dRows[int(cell.getAttribute("row"))] = [cell] + lYcuts = [] + for rowid in sorted(dRows.keys()): +# print rowid, min(map(lambda x:x.getY(),dRows[rowid])) + lYcuts.append(min(list(map(lambda x:x.getY(),dRows[rowid])))) + if lYcuts != []: + pageNode.append(tableNode) + self.createRowsWithCuts(lYcuts,table,tableNode) + + return refdoc + + def createRefPerPage(self,doc): + """ + create a ref file from the xml one + + for DAS 2018 + """ + self.ODoc = XMLDSDocument() + self.ODoc.loadFromDom(doc,listPages = range(self.firstPage,self.lastPage+1)) + + + + dRows={} + for page in self.ODoc.getPages(): + #imageFilename="..\col\30275\S_Freyung_021_0001.jpg" width="977.52" height="780.0"> + pageNode = etree.Element('PAGE') +# pageNode.set("number",page.getAttribute('number')) + #SINGLER PAGE pnum=1 + pageNode.set("number",'1') + + pageNode.set("imageFilename",page.getAttribute('imageFilename')) + pageNode.set("width",page.getAttribute('width')) + pageNode.set("height",page.getAttribute('height')) + + root=etree.Element("DOCUMENT") + refdoc=etree.ElementTree(root) + root.append(pageNode) + + lTables = page.getAllNamedObjects(XMLDSTABLEClass) + for table in lTables: + tableNode = etree.Element('TABLE') + tableNode.set("x",table.getAttribute('x')) + tableNode.set("y",table.getAttribute('y')) + tableNode.set("width",table.getAttribute('width')) + tableNode.set("height",table.getAttribute('height')) + pageNode.append(tableNode) + for cell in table.getAllNamedObjects(XMLDSTABLECELLClass): + try:dRows[int(cell.getAttribute("row"))].append(cell) + except KeyError:dRows[int(cell.getAttribute("row"))] = [cell] + + lYcuts = [] + for rowid in sorted(dRows.keys()): +# print rowid, min(map(lambda x:x.getY(),dRows[rowid])) + lYcuts.append(min(list(map(lambda x:x.getY(),dRows[rowid])))) + self.createRowsWithCuts(lYcuts,table,tableNode) + + + self.outputFileName = os.path.basename(page.getAttribute('imageFilename')[:-3]+'ref') +# print(self.outputFileName) + self.writeDom(refdoc, bIndent=True) + + return refdoc + + # print refdoc.serialize('utf-8', True) +# self.testCPOUM(0.5,refdoc.serialize('utf-8', True),refdoc.serialize('utf-8', True)) + + def createRefPartition(self,doc): + """ + Ref: a row = set of textlines + :param doc: dox xml + returns a doc (ref format): each column contains a set of ids (textlines ids) + """ + self.ODoc = XMLDSDocument() + self.ODoc.loadFromDom(doc,listPages = range(self.firstPage,self.lastPage+1)) + + + root=etree.Element("DOCUMENT") + refdoc=etree.ElementTree(root) + + + for page in self.ODoc.getPages(): + pageNode = etree.Element('PAGE') + pageNode.set("number",page.getAttribute('number')) + pageNode.set("pagekey",os.path.basename(page.getAttribute('imageFilename'))) + pageNode.set("width",str(page.getAttribute('width'))) + pageNode.set("height",str(page.getAttribute('height'))) + + root.append(pageNode) + lTables = page.getAllNamedObjects(XMLDSTABLEClass) + for table in lTables: + dCols={} + tableNode = etree.Element('TABLE') + tableNode.set("x",table.getAttribute('x')) + tableNode.set("y",table.getAttribute('y')) + tableNode.set("width",str(table.getAttribute('width'))) + 
tableNode.set("height",str(table.getAttribute('height'))) + pageNode.append(tableNode) + for cell in table.getAllNamedObjects(XMLDSTABLECELLClass): + try:dCols[int(cell.getAttribute("row"))].extend(cell.getObjects()) + except KeyError:dCols[int(cell.getAttribute("row"))] = cell.getObjects() + + for rowid in sorted(dCols.keys()): + rowNode= etree.Element("ROW") + tableNode.append(rowNode) + for elt in dCols[rowid]: + txtNode = etree.Element("TEXT") + txtNode.set('id',elt.getAttribute('id')) + rowNode.append(txtNode) + + return refdoc + + + def createRunPartition(self,doc): + """ + Ref: a row = set of textlines + :param doc: dox xml + returns a doc (ref format): each column contains a set of ids (textlines ids) + """ +# self.ODoc = doc #XMLDSDocument() +# self.ODoc.loadFromDom(doc,listPages = range(self.firstPage,self.lastPage+1)) + + + root=etree.Element("DOCUMENT") + refdoc=etree.ElementTree(root) + + + for page in self.ODoc.getPages(): + pageNode = etree.Element('PAGE') + pageNode.set("number",page.getAttribute('number')) + pageNode.set("pagekey",os.path.basename(page.getAttribute('imageFilename'))) + pageNode.set("width",str(page.getAttribute('width'))) + pageNode.set("height",str(page.getAttribute('height'))) + + root.append(pageNode) + tableNode = etree.Element('TABLE') + tableNode.set("x","0") + tableNode.set("y","0") + tableNode.set("width","0") + tableNode.set("height","0") + pageNode.append(tableNode) + + table = page.getAllNamedObjects(XMLDSTABLEClass)[0] + lRows = table.getRows() + for row in lRows: + cNode= etree.Element("ROW") + tableNode.append(cNode) + for elt in row.getAllNamedObjects(XMLDSTEXTClass): + txtNode= etree.Element("TEXT") + txtNode.set('id',elt.getAttribute('id')) + cNode.append(txtNode) + + return refdoc + +if __name__ == "__main__": + + + rdc = RowDetection() + #prepare for the parsing of the command line + rdc.createCommandLineParser() +# rdc.add_option("--coldir", dest="coldir", action="store", type="string", help="collection folder") + rdc.add_option("--docid", dest="docid", action="store", type="string", help="document id") + rdc.add_option("--dsconv", dest="dsconv", action="store_true", default=False, help="convert page format to DS") + rdc.add_option("--createref", dest="createref", action="store_true", default=False, help="create REF file for component") + rdc.add_option("--createrefC", dest="createrefCluster", action="store_true", default=False, help="create REF file for component (cluster of textlines)") + rdc.add_option("--evalC", dest="evalCluster", action="store_true", default=False, help="evaluation using clusters (of textlines)") + rdc.add_option("--cell", dest="bCellOnly", action="store_true", default=False, help="generate cell candidate from BIO (no row)") + rdc.add_option("--nocolumn", dest="bNoColumn", action="store_true", default=False, help="no existing table/colunm)") +# rdc.add_option("--raw", dest="bRaw", action="store_true", default=False, help="no existing table/colunm)") + + rdc.add_option("--YC", dest="YCut", action="store_true", default=False, help="use Ycut") + rdc.add_option("--BTAG", dest="BTAG", action="store", default='B',type="string", help="BTAG = B or S") + rdc.add_option("--STAG", dest="STAG", action="store", default='S',type="string", help="STAG = S or None") + + rdc.add_option("--thhighsupport", dest="thhighsupport", action="store", type="int", default=33,help="TH for high support", metavar="NN") + + rdc.add_option('-f',"--first", dest="first", action="store", type="int", help="first page to be processed") + 
rdc.add_option('-l',"--last", dest="last", action="store", type="int", help="last page to be processed") + + #parse the command line + dParams, args = rdc.parseCommandLine() + + #Now we are back to the normal programmatic mode, we set the component parameters + rdc.setParams(dParams) + doc = rdc.loadDom() + doc = rdc.run(doc) + if doc is not None and rdc.getOutputFileName() != '-': + rdc.writeDom(doc, bIndent=True) + diff --git a/TranskribusDU/tasks/DU_Table/tableWorkflow.py b/TranskribusDU/tasks/DU_Table/tableWorkflow.py new file mode 100644 index 0000000..6e42086 --- /dev/null +++ b/TranskribusDU/tasks/DU_Table/tableWorkflow.py @@ -0,0 +1,533 @@ +# -*- coding: utf-8 -*- +""" + + + IETableWorkflow.py + + Process a full collection with table analysis and IE extraction + Based on template + + H. Déjean + + copyright Xerox 2017 + READ project + + + + + Developed for the EU project READ. The READ project has received funding + from the European Union's Horizon 2020 research and innovation programme + under grant agreement No 674943. +""" + + + + +import sys, os +import logging +import glob +from optparse import OptionParser + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(sys.argv[0])))) + +try: + import TranskribusPyClient_version +except ImportError: + sys.path.append( os.path.join( os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname( os.path.abspath(sys.argv[0]) )))) + , "TranskribusPyClient", "src" )) + sys.path.append( os.path.join( os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname( os.path.abspath(sys.argv[0]) )))) + , "TranskribusPyClient", "src" )) + import TranskribusPyClient_version + +from TranskribusPyClient.client import TranskribusClient +from TranskribusCommands.TranskribusDU_transcriptUploader import TranskribusDUTranscriptUploader +from TranskribusCommands.Transkribus_downloader import TranskribusDownloader +from TranskribusCommands.do_analyzeLayoutNew import DoLAbatch +from TranskribusCommands.do_htrRnn import DoHtrRnn +from TranskribusCommands import sCOL + +import common.Component as Component +from TranskribusPyClient.common.trace import traceln, trace + +from xml_formats.PageXml import PageXml +from tasks.DU_ABPTable_T import DU_ABPTable_TypedCRF +from xml_formats.PageXml import MultiPageXml +from xml_formats.Page2DS import primaAnalysis +from xml_formats.DS2PageXml import DS2PageXMLConvertor +from tasks.rowDetection import RowDetection +from ObjectModel.xmlDSDocumentClass import XMLDSDocument + +class TableProcessing(Component.Component): + usage = "" + version = "v.01" + description = "description: table layout analysis based on template" + + + sCOL = "col" + sMPXMLExtension = ".mpxml" + + def __init__(self): + """ + Always call first the Component constructor. 
+ """ + Component.Component.__init__(self, "TableProcessing", self.usage, self.version, self.description) + + self.colid = None + self.docid= None + + self.bFullCol = False + # generate MPXML using Ext + self.useExtForMPXML = False + + self.bRegenerateMPXML = False + + self.sRowModelName = None + self.sRowModelDir = None + + self.sHTRmodel = None + self.sDictName = None + + def setParams(self, dParams): + """ + Always call first the Component setParams + Here, we set our internal attribute according to a possibly specified value (otherwise it stays at its default value) + """ + Component.Component.setParams(self, dParams) + if "coldir" in dParams: + self.coldir = dParams["coldir"] + if "colid" in dParams: + self.colid = dParams["colid"] + if "colid" in dParams: + self.docid = dParams["docid"] + if "useExt" in dParams: + self.useExtForMPXML = dParams["useExt"] + + if 'regMPXML' in dParams: + self.bRegenerateMPXML=True + + if "rowmodelname" in dParams: + self.sRowModelName = dParams["rowmodelname"] + if "rowmodeldir" in dParams: + self.sRowModelDir = dParams["rowmodeldir"] + + if "htrmodel" in dParams: + self.sHTRmodel = dParams["htrmodel"] + if "dictname" in dParams: + self.sDictName = dParams["dictname"] + + + # Connection to Transkribus + self.myTrKCient = None + self.persist = False + self.loginInfo = False + if dParams.has_key("server"): + self.server = dParams["server"] + if dParams.has_key("persist"): + self.persist = dParams["persist"] + if dParams.has_key("login"): + self.loginInfo = dParams["login"] + + def login(self,trnskrbs_client, trace=None, traceln=None): + """ + deal with the complicated login variants... + -trace and traceln are optional print methods + return True or raises an exception + """ + DEBUG=True + bOk = False + if self.persist: + #try getting some persistent session token + if DEBUG and trace: trace(" ---login--- Try reusing persistent session ... ") + try: + bOk = trnskrbs_client.reusePersistentSession() + if DEBUG and traceln: traceln("OK!") + except: + if DEBUG and traceln: traceln("Failed") + + if not bOk: + if self.loginInfo: + login, pwd = self.loginInfo, self.pwd + else: + if trace: DEBUG and trace(" ---login--- no login provided, looking for stored credentials... ") + login, pwd = trnskrbs_client.getStoredCredentials(bAsk=False) + if DEBUG and traceln: traceln("OK") + + if DEBUG and traceln: trace(" ---login--- logging onto Transkribus as %s "%login) + trnskrbs_client.auth_login(login, pwd) + if DEBUG and traceln: traceln("OK") + bOk = True + + return bOk + + def downloadCollection(self,colid,destDir,docid,bNoImg=True,bForce=False): + """ + download colID + + replace destDir by '.' ? + """ + destDir="." 
+#                 options.server, proxies, loggingLevel=logging.WARN)
+        # download
+        downloader = TranskribusDownloader(self.myTrKCient.getServerUrl(), self.myTrKCient.getProxies())
+        downloader.setSessionId(self.myTrKCient.getSessionId())
+        traceln("- Downloading collection %s to folder %s" % (colid, os.path.abspath(destDir)))
+#         col_ts, colDir = downloader.downloadCollection(colid, destDir, bForce=options.bForce, bNoImage=options.bNoImage)
+        col_ts, colDir, ldocids, dFileListPerDoc = downloader.downloadCollection(colid, destDir, bForce=bForce, bNoImage=bNoImg, sDocId=docid)
+        traceln("- Done")
+
+        with open(os.path.join(colDir, "config.txt"), "w") as fd:
+            fd.write("server=%s\nforce=%s\nstrict=%s\n" % (self.server, True, False))
+
+        downloader.generateCollectionMultiPageXml(os.path.join(colDir, TableProcessing.sCOL), dFileListPerDoc, False)
+
+        traceln('- Done, see in %s' % colDir)
+
+        return ldocids
+
+    def upLoadDocument(self, colid, coldir, docid, sNote="", sTranscripExt='.mpxml'):
+        """
+        upload the transcripts of document docid to collection colid
+        """
+
+#                 options.server, proxies, loggingLevel=logging.WARN)
+        # upload
+#         uploader = TranskribusTranscriptUploader(self.server, self.proxies)
+        uploader = TranskribusDUTranscriptUploader(self.myTrKCient.getServerUrl(), self.myTrKCient.getProxies())
+        uploader.setSessionId(self.myTrKCient.getSessionId())
+        traceln("- uploading document %s to collection %s" % (docid, colid))
+        uploader.uploadDocumentTranscript(colid, docid, os.path.join(coldir, sCOL), sNote, 'NLE Table', sTranscripExt, iVerbose=False)
+        traceln("- Done")
+        return
+
+    def applyLA_URO(self, colid, docid, nbpages):
+        """
+        apply the textline finder
+        """
+        # do the job...
+#         if options.trp_doc:
+#             trpdoc = json.load(codecs.open(options.trp_doc, "rb", 'utf-8'))
+#             docId, sPageDesc = doer.buildDescription(colId, options.docid, trpdoc)
+
+        traceln('process %s pages...' % nbpages)
+        lretJobIDs = []
+        for i in range(1, nbpages + 1):
+            LA = DoLAbatch(self.myTrKCient.getServerUrl(), self.myTrKCient.getProxies())
+            LA._trpMng.setSessionId(self.myTrKCient.getSessionId())
+            LA.setSessionId(self.myTrKCient.getSessionId())
+            _, sPageDesc = LA.buildDescription(colid, "%s/%s" % (docid, i))
+            sPageDesc = LA.jsonToXMLDescription(sPageDesc)
+            _, lJobIDs = LA.run(colid, sPageDesc, "CITlabAdvancedLaJob", False)
+            traceln(lJobIDs)
+            lretJobIDs.extend(lJobIDs)
+            traceln("- LA running for page %d job:%s" % (i, lJobIDs))
+        return lretJobIDs
+
+
+    def applyHTRForRegions(self, colid, docid, nbpages, modelname, dictionary):
+        """
+        apply an HTR model at region level
+        """
+
+        htrComp = DoHtrRnn(self.myTrKCient.getServerUrl(), self.myTrKCient.getProxies())
+        htrComp._trpMng.setSessionId(self.myTrKCient.getSessionId())
+        htrComp.setSessionId(self.myTrKCient.getSessionId())
+
+        _, sPageDesc = htrComp.buildDescription(colid, "%s/%s" % (docid, nbpages))
+
+        sPages = "1-%d" % (nbpages)
+        sModelID = None
+        # get modelID
+        lColModels = self.myTrKCient.listRnns(colid)
+        for model in lColModels:
+#             print(model['htrId'], type(model['htrId']), modelname, type(modelname))
+            if str(model['htrId']) == str(modelname):
+                sModelID = model['htrId']
+                traceln('model id = %s' % sModelID)
+            # some old? models do not have the params field
+#             try: traceln("%s\t%s\t%s" % (model['htrId'], model['name'], model['params']))
+#             except KeyError: traceln("%s\t%s\tno params" % (model['htrId'], model['name']))
+        if sModelID is None: raise Exception("no model ID found for %s" % (modelname))
+        ret = htrComp.htrRnnDecode(colid, sModelID, dictionary, docid, sPageDesc, bDictTemp=False)
+        traceln(ret)
+        return ret
+
+    def applyHTR(self, colid, docid, nbpages, modelname, dictionary):
+        """
+        apply HTR on docid
+
+        the htr id is needed: we have the htr model name
+        """
+        htrComp = DoHtrRnn(self.myTrKCient.getServerUrl(), self.myTrKCient.getProxies())
+        htrComp._trpMng.setSessionId(self.myTrKCient.getSessionId())
+        htrComp.setSessionId(self.myTrKCient.getSessionId())
+
+        _, sPageDesc = htrComp.buildDescription(colid, "%s/%s" % (docid, nbpages))
+
+        sPages = "1-%d" % (nbpages)
+        sModelID = None
+        # get modelID
+        lColModels = self.myTrKCient.listRnns(colid)
+        for model in lColModels:
+#             print(model['htrId'], type(model['htrId']), modelname, type(modelname))
+            if str(model['htrId']) == str(modelname):
+                sModelID = model['htrId']
+                traceln('model id = %s' % sModelID)
+            # some old? models do not have the params field
+#             try: traceln("%s\t%s\t%s" % (model['htrId'], model['name'], model['params']))
+#             except KeyError: traceln("%s\t%s\tno params" % (model['htrId'], model['name']))
+        if sModelID is None: raise Exception("no model ID found for %s" % (modelname))
+        ret = htrComp.htrRnnDecode(colid, sModelID, dictionary, docid, sPageDesc, bDictTemp=False)
+        traceln(ret)
+        return ret
+
+
+    def extractFileNamesFromMPXML(self, mpxmldoc):
+        """
+        to ensure correct file order!
+
+        duplicated from performCVLLA.py
+        """
+        xmlpath = os.path.abspath(os.path.join(self.coldir, sCOL, self.docid))
+
+        lNd = PageXml.getChildByName(mpxmldoc.getroot(), 'Page')
+#         for i in lNd: print(i)
+        return ["%s%s%s.xml" % (xmlpath, os.sep, x.get('imageFilename')[:-4]) for x in lNd]
+
+
+    def processDocument(self, coldir, colid, docid, dom=None):
+        """
+        process a single document
+
+        1 python ../../src/xml_formats/PageXml.py trnskrbs_5400/col/17442 --ext=pxml
+        2 python ../../src/tasks/performCVLLA.py --coldir=trnskrbs_5400/ --docid=17442 -i trnskrbs_5400/col/17442.mpxml --bl --regTL --form
+        3 python ../../src/tasks/DU_ABPTable_T.py modelMultiType tableRow2 --run=trnskrbs_5400
+        4 python ../../src/xml_formats/Page2DS.py --pattern=trnskrbs_5400/col/17442_du.mpxml -o trnskrbs_5400/xml/17442.ds_xml --docid=17442
+        5 python src/IE_test.py -i trnskrbs_5400/xml/17442.ds_xml -o trnskrbs_5400/out/17442.ds_xml
+        6 python ../../../TranskribusPyClient/src/TranskribusCommands/TranskribusDU_transcriptUploader.py --nodu trnskrbs_5400 5400 17442
+        7 python ../../../TranskribusPyClient/src/TranskribusCommands/do_htrRnn.py 5400 17442
+
+        wait
+        8 python ../../../TranskribusPyClient/src/TranskribusCommands/Transkribus_downloader.py 5400 --force
+        # convert to ds
+        9 python ../../src/xml_formats/Page2DS.py --pattern=trnskrbs_5400/col/17442.mpxml -o trnskrbs_5400/xml/17442.ds_xml --docid=17442
+        10 python src/IE_test.py -i trnskrbs_5400/xml/17442.ds_xml -o trnskrbs_5400/out/17442.ds_xml --doie --usetemplate
+        """
+
+        # create the Transkribus client
+        self.myTrKCient = TranskribusClient(sServerUrl=self.server, proxies={}, loggingLevel=logging.WARN)
+        # login
+        _ = self.login(self.myTrKCient, trace=trace, traceln=traceln)
+
+#         self.downloadCollection(colid, coldir, docid, bNoImg=False, bForce=True)
+
+        ## load dom
+        if dom is None:
+            self.inputFileName = os.path.abspath(os.path.join(coldir, TableProcessing.sCOL, docid + TableProcessing.sMPXMLExtension))
+            mpxml_doc = self.loadDom()
+            nbPages = MultiPageXml.getNBPages(mpxml_doc)
+        else:
+            # load provided mpxml
+            mpxml_doc = dom
+            nbPages = MultiPageXml.getNBPages(mpxml_doc)
+
+#         ### table registration: need to compute/select??? the template
+#         # perform LA separator, table registration, baseline with normalization
+#         # python ../../src/tasks/performCVLLA.py --coldir=trnskrbs_5400/ --docid=17442 -i trnskrbs_5400/col/17442.mpxml --bl --regTL --form
+#         tableregtool = LAProcessor()
+#         # latool.setParams(dParams)
+#         tableregtool.coldir = coldir
+#         tableregtool.docid = docid
+#         tableregtool.bTemplate, tableregtool.bSeparator, tableregtool.bBaseLine, tableregtool.bRegularTextLine = True, False, False, False
+#         # creates xml and a new mpxml
+#         mpxml_doc, nbPages = tableregtool.performLA(mpxml_doc)
+#
+#
+
+#         self.upLoadDocument(colid, coldir, docid, sNote='NLE workflow;table reg done')
+
+        lJobIDs = self.applyLA_URO(colid, docid, nbPages)
+        return
+
+        bWait = True
+        assert lJobIDs != []
+        jobid = lJobIDs[-1]
+        traceln("waiting for job %s" % jobid)
+        while bWait:
+            dInfo = self.myTrKCient.getJobStatus(jobid)
+            bWait = dInfo['state'] not in ['FINISHED', 'FAILED']
+
+        ## coldir???
+        self.downloadCollection(colid, coldir, docid, bNoImg=True, bForce=True)
+
+        ## STOP HERE FOR DAS next testset:
+        return
+
+        # tag text for BIES cell
+        # python ../../src/tasks/DU_ABPTable_T.py modelMultiType tableRow2 --run=trnskrbs_5400
+        """
+        needed: doer = DU_ABPTable_TypedCRF(sModelName, sModelDir)
+        """
+        doer = DU_ABPTable_TypedCRF(self.sRowModelName, self.sRowModelDir)
+        doer.load()
+        ## needed: predict at file level, and do not store the dom, but return it
+        rowpath = os.path.join(coldir, "col")
+        BIESFiles = doer.predict([rowpath], docid)
+        BIESDom = self.loadDom(BIESFiles[0])
+#         res = BIESDom.saveFormatFileEnc('test.mpxml', "UTF-8", True)
+
+        # MPXML2DS
+        # python ../../src/xml_formats/Page2DS.py --pattern=trnskrbs_5400/col/17442_du.mpxml -o trnskrbs_5400/xml/17442.ds_xml --docid=17442
+        dsconv = primaAnalysis()
+        DSBIESdoc = dsconv.convert2DS(BIESDom, self.docid)
+
+        # create XMLDOC object
+        self.ODoc = XMLDSDocument()
+        self.ODoc.loadFromDom(DSBIESdoc)  #, listPages=range(self.firstPage, self.lastPage+1))
+        # create rows
+        # python src/IE_test.py -i trnskrbs_5400/xml/17442.ds_xml -o trnskrbs_5400/out/17442.ds_xml
+        rdc = RowDetection()
+        rdc.findRowsInDoc(self.ODoc)
+
+        # python ../../src/xml_formats/DS2PageXml.py -i trnskrbs_5400/out/17442.ds_xml --multi
+        # DS2MPXML
+        DS2MPXML = DS2PageXMLConvertor()
+        lPageXml = DS2MPXML.run(self.ODoc.getDom())
+        if lPageXml != []:
+#             if DS2MPXML.bMultiPages:
+            newDoc = MultiPageXml.makeMultiPageXmlMemory(list(map(lambda xy: xy[0], lPageXml)))
+            outputFileName = os.path.join(self.coldir, sCOL, self.docid + TableProcessing.sMPXMLExtension)
+            newDoc.write(outputFileName, xml_declaration=True, encoding="UTF-8", pretty_print=True)
+#             else:
+#                 DS2MPXML.storePageXmlSetofFiles(lPageXml)
+
+        return
+
+        # upload
+        # python ../../../TranskribusPyClient/src/TranskribusCommands/TranskribusDU_transcriptUploader.py --nodu trnskrbs_5400 5400 17442
+        self.upLoadDocument(colid, coldir, docid, sNote='NLE workflow;table row done')
+
+
+        ## apply HTR
+        ## how to deal with specific dictionaries?
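The job-status wait loops in processDocument above and below poll getJobStatus continuously; a helper along these lines would sleep between polls. This is a minimal sketch, not part of the committed file — the helper name and sleep interval are assumptions:

    import time

    def waitForJobDone(client, jobid, lEndStates=('FINISHED', 'FAILED', 'CANCELED'), iSleepSec=2):
        """Poll the Transkribus job status until it reaches a final state.
        client is the TranskribusClient used above; sleeping avoids hammering the server."""
        while True:
            dInfo = client.getJobStatus(jobid)
            if dInfo['state'] in lEndStates:
                return dInfo
            time.sleep(iSleepSec)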
+
+        ## here we need to know the ontology and the template
+
+        nbPages = 1
+        jobid = self.applyHTR(colid, docid, nbPages, self.sHTRmodel, self.sDictName)
+        bWait = True
+        traceln("waiting for job %s" % jobid)
+        while bWait:
+            dInfo = self.myTrKCient.getJobStatus(jobid)
+            bWait = dInfo['state'] not in ['FINISHED', 'FAILED', 'CANCELED']
+
+        # download where???
+        # python ../../../TranskribusPyClient/src/TranskribusCommands/Transkribus_downloader.py 5400 --force
+        # coldir is not right!! coldir must refer to the parent folder!
+        self.downloadCollection(colid, coldir, docid, bNoImg=True, bForce=True)
+
+        # done!!
+
+        # IE extraction
+        ## not here: specific to a use case
+        # python src/IE_test.py -i trnskrbs_5400/xml/17442.ds_xml -o trnskrbs_5400/out/17442.ds_xml --doie --usetemplate
+
+
+    def processCollection(self, coldir):
+        """
+        process all files in a collection
+        need mpxml files
+        """
+        lsDocFilename = sorted(glob.iglob(os.path.join(coldir, "*" + TableProcessing.sMPXMLExtension)))
+        lDocId = []
+        for sDocFilename in lsDocFilename:
+            sDocId = os.path.basename(sDocFilename)[:-len(TableProcessing.sMPXMLExtension)]
+            try:
+                docid = int(sDocId)
+                lDocId.append(docid)
+            except ValueError:
+                traceln("Warning: folder %s : %s invalid docid, IGNORING IT" % (self.coldir, sDocId))
+                continue
+
+        # process each document
+        for docid in lDocId:
+            traceln("Processing %s : %s " % (self.coldir, docid))
+            self.processDocument(self.coldir, self.colid, docid)
+            traceln("\tProcessing done for %s : %s " % (self.coldir, docid))
+
+
+    def processParameters(self):
+        """
+        what to do with the parameters provided on the command line
+        """
+        if self.colid is None:
+            print('collection id missing!')
+            sys.exit(1)
+
+        # no docid given: process the full collection
+        self.bFullCol = self.docid is None
+
+        if self.bRegenerateMPXML and self.docid is not None:
+            l = glob.glob(os.path.join(self.coldir, sCOL, self.docid, "*.pxml"))
+            doc = MultiPageXml.makeMultiPageXml(l)
+            outputFileName = os.path.join(self.coldir, sCOL, self.docid + TableProcessing.sMPXMLExtension)
+            doc.write(outputFileName, xml_declaration=True, encoding="UTF-8", pretty_print=True)
+            return doc
+        return None
+
+    def run(self):
+        """
+        process at collection level or document level
+        """
+        newMPXML = self.processParameters()
+        if self.bFullCol:
+            self.processCollection(self.colid)
+        else:
+            self.processDocument(self.coldir, self.colid, self.docid, newMPXML)
+
+if __name__ == "__main__":
+
+
+    ## parser for cloud connection
+    parser = OptionParser()
+
+
+    tableprocessing = TableProcessing()
+    tableprocessing.createCommandLineParser()
+
+    tableprocessing.parser.add_option("-s", "--server"  , dest='server', action="store", type="string", default="https://transkribus.eu/TrpServer", help="Transkribus server URL")
+
+    tableprocessing.parser.add_option("-l", "--login"   , dest='login' , action="store", type="string", help="Transkribus login (consider storing your credentials in 'transkribus_credentials.py')")
+    tableprocessing.parser.add_option("-p", "--pwd"     , dest='pwd'   , action="store", type="string", help="Transkribus password")
+
+    tableprocessing.parser.add_option("--persist"       , dest='persist', action="store_true", help="Try using an existing persistent session, or log in and persist the session.")
+
+    tableprocessing.parser.add_option("--https_proxy"   , dest='https_proxy', action="store", type="string", help="proxy, e.g. http://cornillon:8000")
+
+    tableprocessing.parser.add_option("--pxml", dest="regMPXML", action="store_true", help="recreate MPXML from PXML")
+
+    tableprocessing.parser.add_option("--coldir", dest="coldir", action="store", type="string", help="collection folder")
+    tableprocessing.parser.add_option("--colid", dest="colid", action="store", type="string", help="collection id")
+
+    tableprocessing.parser.add_option("--docid", dest="docid", action="store", type="string", help="document id")
+    tableprocessing.parser.add_option("--useExt", dest="useExt", action="store", type="string", help="generate mpxml using page file .ext")
+
+    ## ROW
+    tableprocessing.parser.add_option("--rowmodel", dest="rowmodelname", action="store", type="string", help="row model name")
+    tableprocessing.parser.add_option("--rowmodeldir", dest="rowmodeldir", action="store", type="string", help="row model directory")
+    ## HTR
+    tableprocessing.parser.add_option("--htrid", dest="htrmodel", action="store", type="string", help="HTR model id")
+    tableprocessing.parser.add_option("--dictname", dest="dictname", action="store", type="string", help="dictionary for HTR")
+
+#     tableprocessing.add_option('-f', "--first", dest="first", action="store", type="int", help="first page to be processed")
+#     tableprocessing.add_option('-l', "--last",  dest="last",  action="store", type="int", help="last page to be processed")
+
+    # parse the command line
+    dParams, args = tableprocessing.parseCommandLine()
+    # Now we are back in the normal programmatic mode; we set the component parameters
+    tableprocessing.setParams(dParams)
+
+    tableprocessing.run()
+
+    
\ No newline at end of file
diff --git a/TranskribusDU/tasks/DU_Task.py b/TranskribusDU/tasks/DU_Task.py
index 2f60c9e..285182c 100644
--- a/TranskribusDU/tasks/DU_Task.py
+++ b/TranskribusDU/tasks/DU_Task.py
@@ -5,18 +5,7 @@
     Copyright Xerox(C) 2016, 2017 JL. Meunier
 
-    This program is free software: you can redistribute it and/or modify
-    it under the terms of the GNU General Public License as published by
-    the Free Software Foundation, either version 3 of the License, or
-    (at your option) any later version.
-
-    This program is distributed in the hope that it will be useful,
-    but WITHOUT ANY WARRANTY; without even the implied warranty of
-    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-    GNU General Public License for more details.
-
-    You should have received a copy of the GNU General Public License
-    along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
     Developed for the EU project READ. 
The READ project has received funding @@ -27,7 +16,9 @@ import sys, os, glob, datetime import json from importlib import import_module -import random +from io import StringIO +import traceback +import lxml.etree as etree import numpy as np @@ -40,15 +31,18 @@ except ImportError: sys.path.append( os.path.dirname(os.path.dirname( os.path.abspath(sys.argv[0]) )) ) import TranskribusDU_version +TranskribusDU_version from common.trace import trace, traceln -from common.chrono import chronoOn, chronoOff +from common.chrono import chronoOn, chronoOff, pretty_time_delta from common.TestReport import TestReportConfusion from xml_formats.PageXml import MultiPageXml -from graph.GraphModel import GraphModel, GraphModelException +from graph.GraphModel import GraphModel, GraphModelException, GraphModelNoEdgeException +from graph.Graph_JsonOCR import Graph_JsonOCR +from graph.Graph_DOM import Graph_DOM import graph.FeatureDefinition -from tasks import _checkFindColDir, _exit +from tasks import _checkFindColDir class DU_Task: @@ -89,7 +83,7 @@ class DU_Task: cFeatureDefinition = None # FeatureDefinition_PageXml_StandardOnes #I keep this for backward compa - sMetadata_Creator = "NLE Document Understanding" + sMetadata_Creator = "NLE Document Understanding: DU_Task" sMetadata_Comments = "" #dGridSearch_LR_conf = {'C':[0.1, 0.5, 1.0, 2.0] } #Grid search parameters for LR baseline method training @@ -101,6 +95,8 @@ class DU_Task: iNbNodeType = 1 # as of today, only CRF can do multitype + bConjugate = False + def configureGraphClass(self, configuredClass=None): """ class method to set the graph class ONCE (subsequent calls are ignored) @@ -114,9 +110,11 @@ class method to set the graph class ONCE (subsequent calls are ignored) assert configuredClass is not None, "getConfiguredGraphClass returned None" self.cGraphClass = configuredClass + self.bConjugate = configuredClass.bConjugate assert self.cGraphClass is not None traceln("SETUP: Graph class is %s (graph mode %d)" % (self.cGraphClass, self.cGraphClass.getGraphMode())) + traceln("SETUP: Input format is '%s'" % (self.cGraphClass.getDocInputFormat())) return self.cGraphClass @@ -144,10 +142,6 @@ def __init__(self, sModelName, sModelDir self._mdl = None - # for the conjugate mode - self.bConjugate = False - self.nbEdgeClass = None - self._lBaselineModel = [] self.bVerbose = True @@ -173,10 +167,16 @@ def getVersion(cls): def standardDo(self, options): """ - do whatever is reuested by an option from the parsed command line + do whatever is requested by an option from the parsed command line return None """ + if bool(options.iServer): + self.load() + # run in server mode! 
+ self.serve_forever(options.iServer, options.bServerDebug, options=options) + return + if options.rm: self.rm() return @@ -252,7 +252,7 @@ def standardDo(self, options): # lsOutputFilename = self.runForExternalMLMethod(lRun, options.storeX, options.applyY, options.bRevertEdges) # else: self.load() - lsOutputFilename = self.predict(lRun, bGraph=options.bGraph) + lsOutputFilename = self.predict(lRun, bGraph=options.bGraph,bOutXML=options.bOutXML) traceln("Done, see in:\n %s"%lsOutputFilename) else: @@ -269,7 +269,113 @@ def __del__(self): del self.cFeatureDefinition del self.cModelClass + #--- SERVER MODE --------------------------------------------------------- + def serve_forever(self, iPort, bDebug=False, options={}): + self.sTime_start = datetime.datetime.now().isoformat() + self.sTime_load = self.sTime_start + + import socket + sURI = "http://%s:%d" % (socket.gethostbyaddr(socket.gethostname())[0], iPort) + sDescr = """ +- home page for humans: %s +- POST or GET on %s/predict with argument xml=... +""" % ( sURI, sURI) + traceln("SERVER MODE") + traceln(sDescr) + + from flask import Flask + from flask import request, abort + from flask import render_template_string #, render_template + from flask import redirect, url_for #, send_from_directory, send_file + + + # Create Flask app load app.config + app = Flask(self.__class__.__name__) + + @app.route('/') + def home_page(): + # String-based templates + return render_template_string(""" +DU_Task server +
+<br/>
+  • Server start time : {{ start_time }}
+<br/>
+  • Model load time   : {{ load_time }}
+<br/>
+  • Model : ({{ model_type }}) {{ model_spec }}
+<br/>
+<br/>
+<a href="/reload">reload the model</a>
+<br/>
+Provide some {{ input_format }} data and get PageXml output:
+<form method="POST" action="/predict">
+    <textarea name="data" rows="20" cols="80"></textarea>
+    <br/>
+    <input type="submit" value="predict"/>
+</form>
+<br/>
+
+<p>This server runs with those options: {{ sOptions }}</p>
+"""
+            , model_type=self.__class__.__name__
+            , model_spec=os.path.abspath(self.getModel().getModelFilename())
+            , input_format=self.getGraphClass().getDocInputFormat()
+            , start_time=self.sTime_start
+            , load_time=self.sTime_load
+            , sOptions=str(options))
+
+        @app.route('/predict', methods = ['POST'])
+        def predict():
+            try:
+                sData = request.form['data']
+                # skip the XML declaration, if any
+                if sData.startswith("<?xml"):
+                    sData = sData[sData.index("?>")+2:]
+
+                doc, lg = self._predict_file(self.getGraphClass(), [], StringIO(sData), bGraph=options.bGraph)
+
+                # if nothing to do, the method returns None...
+                if doc is None:
+                    # doc = etree.parse(StringIO(sXml))
+                    return sData
+                else:
+                    if not(isinstance(doc, etree._ElementTree)):
+                        traceln(" converting to PageXml...")
+                        doc = Graph_DOM.exportToDom(lg)
+                    return etree.tostring(doc.getroot(), encoding='UTF-8', xml_declaration=False)
+
+            except Exception as e:
+                traceln("----- predict exception -------------------------")
+                traceln(traceback.format_exc())
+                traceln("--------------------------------------------------")
+                abort(418, repr(e))
+
+        @app.route('/reload')
+        def reload():
+            """
+            Force to reload the model
+            """
+            self.load(bForce=True)
+            self.sTime_load = datetime.datetime.now().isoformat()
+            return redirect(url_for('home_page'))
+
+        @app.route('/stop')
+        def stop():
+            """
+            Force to exit
+            """
+            traceln("Exiting")
+            sys.exit(0)
+
+        # RUN THE SERVER !!
+        # CAUTION: TensorFlow incompatible with debug=True (double load => GPU issue)
+        app.run(host='0.0.0.0', port=iPort, debug=bDebug)
+        traceln("SERVER ENDING. BYE")
+
+        return
+
+
     #--- CONFIGURATION setters --------------------------------------------------------------------
     def getGraphClass(self):
         return self.cGraphClass
@@ -318,40 +424,15 @@ def getNbClass(self): #OK
         """
         return self.nbClass
 
-    def setConjugateMode(self
-                         , lEdgeLabel       = None   # list of labels (list of strings, or of int)
-                         , funEdgeLabel_get = None   # to compute the edge labels
-                         , funEdgeLabel_set = None   # to use the predicted edge labels
-                         ):
+    def setXmlFilenamePattern(self, sExt):
         """
-        to learn and predict on the conjugate graph instead of the usual graph.
-        1 - The usual graph is created as always
-        2 - the function is called on each edge to compute the edge label
-        3 - the conjugate is created and used for learning or predicting
-        4 - the function is called on each edge to exploit the edge predicted label
-
-        The prototype of the functions are:
-            funEdgeLabel_get(primal_graph, primal_X, primal_Y)  -> dual_Y
-            funEdgeLabel_set(primal_graph, nd_node, edge_matrix, dual_Y)  -> None
-
-        the dual_Y has as many rows as the primal Edge array, and in the same order
-        this order also corresponds to the lEdge attribute of the graph object
-
-        In case the graph has some pre-established settings, you can omit the parameters.
-        """
-        self.bConjugate = True
-        self.cModelClass.setConjugateMode()
-        self.cGraphClass.setConjugateMode(lEdgeLabel
-                                          , funEdgeLabel_get
-                                          , funEdgeLabel_set)
-        self.nbEdgeLabel = len(self.cGraphClass.getEdgeLabelNameList())
-
-        return self.bConjugate
+        Set the expected file extension of the input data
+        """
+        assert sExt, "Empty extension not allowed"
+        if not sExt.startswith("."): sExt = "." + sExt
+        self.sXmlFilenamePattern = "*" + sExt
+    
 
-    #----------------------------------------------------------------------------------------------------------
     #----------------------------------------------------------------------------------------------------------
     def setBaselineList(self, lMdl):
         """
@@ -432,7 +513,7 @@ def train_save_test(self, lsTrnColDir, lsTstColDir, lsVldColDir, bWarm=False, bP
         return a test report object
         """
         self.traceln("-"*50)
-        self.traceln("Model files of '%s' in folder '%s'"%(self.sModelName, self.sModelDir))
+        self.traceln("Model files of '%s' in folder '%s'"%(self.sModelName, os.path.abspath(self.sModelDir)))
         self.traceln("Training with collection(s):", lsTrnColDir)
         self.traceln("Testing with collection(s):", lsTstColDir)
         if lsVldColDir: self.traceln("Validating with collection(s):", lsVldColDir)
@@ -473,33 +554,23 @@ def test(self, lsTstColDir):
         if lPageConstraint:
             for dat in lPageConstraint: self.traceln("\t\t%s"%str(dat))
-        if True:
-            oReport = self._mdl.testFiles(lFilename_tst, lambda fn: DU_GraphClass.loadGraphs(self.cGraphClass, [fn], bDetach=True, bLabelled=True, iVerbose=1)
-                                          , self.getBaselineList() != [])
-        else:
-            self.traceln("- loading test graphs")
-            lGraph_tst = DU_GraphClass.loadGraphs(self.cGraphClass, lFilename_tst, bDetach=True, bLabelled=True, iVerbose=1)
-            if self.bConjugate:
-                for _g in lGraph_tst: _g.computeEdgeLabels()
-
-            self.traceln(" %d graphs loaded"%len(lGraph_tst))
-            oReport = self._mdl.test(lGraph_tst)
+        oReport = self._mdl.testFiles(lFilename_tst, lambda fn: DU_GraphClass.loadGraphs(self.cGraphClass, [fn], bDetach=True, bLabelled=True, iVerbose=1)
+                                      , self.getBaselineList() != [])
 
         return oReport
 
-    def predict(self, lsColDir, docid=None, bGraph=False):
+    def predict(self, lsColDir, docid=None, bGraph=False, bOutXML=True):
         """
         Return the list of produced files
         """
-        self.traceln("-"*50)
-        self.traceln("Predicting for collection(s):", lsColDir)
-        self.traceln("-"*50)
-
         if not self._mdl: raise Exception("The model must be loaded beforehand!")
 
         #list files
         if docid is None:
+            self.traceln("-"*50)
+            self.traceln("Predicting for collection(s):", lsColDir, " (%s)" % self.sXmlFilenamePattern)
+            self.traceln("-"*50)
             _ , lFilename = self.listMaxTimestampFile(lsColDir, self.sXmlFilenamePattern)
         # predict for this file only
         else:
@@ -517,123 +588,84 @@ def predict(self, lsColDir, docid=None, bGraph=False):
 
         chronoOn("predict")
         self.traceln("- loading collection as graphs, and processing each in turn. (%d files)"%len(lFilename))
-        du_postfix = "_du"+MultiPageXml.sEXT
         lsOutputFilename = []
         for sFilename in lFilename:
-            if sFilename.endswith(du_postfix): continue  #:)
-            chronoOn("predict_1")
-            lg = DU_GraphClass.loadGraphs(self.cGraphClass, [sFilename], bDetach=False, bLabelled=False, iVerbose=1)
-            #normally, we get one graph per file, but in case we load one graph per page, for instance, we have a list
-            if lg:
-                for i, g in enumerate(lg):
-                    doc = g.doc
-                    if lPageConstraint:
-                        self.traceln("\t- prediction with logical constraints: %s"%sFilename)
+            if DU_GraphClass.isOutputFilename(sFilename):
+                traceln(" - ignoring '%s' because of its extension" % sFilename)
+                continue
+
+            doc, lg = self._predict_file(DU_GraphClass, lPageConstraint, sFilename, bGraph=bGraph)
+
+            if doc is None:
+                self.traceln("\t- no prediction to do for: %s"%sFilename)
+                continue
+            else:
+                sCreator = self.sMetadata_Creator + " " + self.getVersion()
+                sComment = self.sMetadata_Comments                        \
+                           if bool(self.sMetadata_Comments)               \
+                           else "Model: %s  %s  (%s)" % (
+                               self.sModelName
+                             , self._mdl.__class__.__name__
+                             , os.path.abspath(self.sModelDir))
+                # which output format
+                if bOutXML:
+                    if DU_GraphClass == Graph_DOM:
+                        traceln(" ignoring export-to-DOM (already DOM output)")
+                        sDUFilename = DU_GraphClass.saveDoc(sFilename, doc, lg
+                                                            , sCreator=sCreator
+                                                            , sComment=sComment)
                     else:
-                        self.traceln("\t- prediction : %s"%sFilename)
+                        doc = Graph_DOM.exportToDom(lg)
+                        sDUFilename = Graph_DOM.saveDoc(sFilename, doc, lg, sCreator, sComment)
+                        traceln(" - exported as XML to ", sDUFilename)
+                else:
+                    sDUFilename = DU_GraphClass.saveDoc(sFilename, doc, lg
+                                                        , sCreator=sCreator
+                                                        , sComment=sComment)
 
-                    if self.bConjugate:
-                        Y = self._mdl.predict(g, bProba=True)
-                        g.exploitEdgeLabels(Y)
-                    else:
-                        Y = self._mdl.predict(g)
-                        g.setDomLabels(Y)
-                        if bGraph: g.addEdgeToDOM(Y)
-                    del Y
-
-                    MultiPageXml.setMetadata(doc, None, self.sMetadata_Creator, self.sMetadata_Comments)
-                    sDUFilename = sFilename[:-len(MultiPageXml.sEXT)] + du_postfix
-                    doc.write(sDUFilename,
-                              xml_declaration=True,
-                              encoding="utf-8",
-                              pretty_print=True
-                              #compression=0, #0 to 9
-                              )
                 del doc
                 del lg
-
-                lsOutputFilename.append(sDUFilename)
-            else:
-                self.traceln("\t- no prediction to do for: %s"%sFilename)
-
-            self.traceln("\t done [%.2fs]"%chronoOff("predict_1"))
+            lsOutputFilename.append(sDUFilename)
         self.traceln(" done [%.2fs]"%chronoOff("predict"))
-
         return lsOutputFilename
 
-    def runForExternalMLMethod(self, lsColDir, storeX, applyY, bRevertEdges=False):
+    def _predict_file(self, DU_GraphClass, lPageConstraint, sFilename, bGraph=False):
        """
-        HACK: to test new ML methods, not yet integrated in our SW: storeX=None, storeXY=None, applyY=None
-        Return the list of produced files
+        Return the doc (a DOM?, a JSON?, another?) and the list of graphs
+        Note: the doc can be None if there is no graph
        """
+        chronoOn("predict_1")
+        doc = None
+        lg = DU_GraphClass.loadGraphs(self.cGraphClass, [sFilename], bDetach=False, bLabelled=False, iVerbose=1)
 
-        self.traceln("-"*50)
-        if storeX: traceln("Loading data and storing [X]  (1 X per graph)")
-        if applyY: traceln("Loading data, loading Y, labelling data, storing annotated data")
-        self.traceln("-"*50)
+        #normally, we get one graph per file, but in case we load one graph per page, for instance, we have a list
+        for i, g in enumerate(lg):
+            if not g.lNode: continue  # no node... 
+ doc = g.doc + if lPageConstraint: + #self.traceln("\t- prediction with logical constraints: %s"%sFilename) + self.traceln("\t- page constraints IGNORED!!") + self.traceln("\t- prediction : %s"%sFilename) - if storeX and applyY: - raise ValueError("Either store X or applyY, not both") + self._predict_graph(g, lPageConstraint=lPageConstraint, bGraph=bGraph) + self.traceln("\t done [%.2fs]"%chronoOff("predict_1")) + return doc, lg - if not self._mdl: raise Exception("The model must be loaded beforehand!") - - #list files - _ , lFilename = self.listMaxTimestampFile(lsColDir, self.sXmlFilenamePattern) - - DU_GraphClass = self.getGraphClass() - - lPageConstraint = DU_GraphClass.getPageConstraint() - if lPageConstraint: - for dat in lPageConstraint: self.traceln("\t\t%s"%str(dat)) - - if applyY: - self.traceln("LOADING [Y] from %s"%applyY) - lY = self._mdl.gzip_cPickle_load(applyY) - if storeX: lX = [] - - chronoOn("predict") - self.traceln("- loading collection as graphs, and processing each in turn. (%d files)"%len(lFilename)) - du_postfix = "_du"+MultiPageXml.sEXT - lsOutputFilename = [] - for sFilename in lFilename: - if sFilename.endswith(du_postfix): continue #:) - chronoOn("predict_1") - lg = DU_GraphClass.loadGraphs(self.cGraphClass, [sFilename], bDetach=False, bLabelled=False, iVerbose=1) - #normally, we get one graph per file, but in case we load one graph per page, for instance, we have a list - if lg: - for g in lg: - if self.bConjugate: g.computeEdgeLabels() - doc = g.doc - if bRevertEdges: g.revertEdges() #revert the directions of the edges - if lPageConstraint: - self.traceln("\t- prediction with logical constraints: %s"%sFilename) - else: - self.traceln("\t- prediction : %s"%sFilename) - if storeX: - [X] = self._mdl.get_lX([g]) - lX.append(X) - else: - Y = lY.pop(0) - g.setDomLabels(Y) - del lg - - if applyY: - MultiPageXml.setMetadata(doc, None, self.sMetadata_Creator, self.sMetadata_Comments) - sDUFilename = sFilename[:-len(MultiPageXml.sEXT)]+du_postfix - doc.saveFormatFileEnc(sDUFilename, "utf-8", True) #True to indent the XML - doc.freeDoc() - lsOutputFilename.append(sDUFilename) - else: - self.traceln("\t- no prediction to do for: %s"%sFilename) - - self.traceln("\t done [%.2fs]"%chronoOff("predict_1")) - self.traceln(" done [%.2fs]"%chronoOff("predict")) - - if storeX: - self.traceln("STORING [X] in %s"%storeX) - self._mdl.gzip_cPickle_dump(storeX, lX) - - return lsOutputFilename + def _predict_graph(self, g, lPageConstraint=None, bGraph=False): + """ + predict for a graph + side effect on the graph g + return the graph + """ + try: + Y = self._mdl.predict(g, bProba=g.bConjugate) + g.setDocLabels(Y) + if bGraph and not Y is None: + if g.bConjugate: + g.addEdgeToDoc(Y) + else: + g.addEdgeToDoc() + del Y + except GraphModelNoEdgeException: + traceln("*** ERROR *** cannot predict due to absence of edge in graph") + return g def checkLabelCoverage(self, lY): #check that all classes are represented in the dataset @@ -882,7 +914,7 @@ def _train_save_test(self, sModelName, bWarm, lFilename_trn, ts_trn, lFilename_t #for this check, we load the Y once... 
         if self.bConjugate:
-            mdl.setNbClass(self.nbEdgeLabel)
+            mdl.setNbClass(len(self.cGraphClass.getEdgeLabelNameList()))
             for _g in lGraph_trn: _g.computeEdgeLabels()
             for _g in lGraph_vld: _g.computeEdgeLabels()
         else:
@@ -921,7 +953,8 @@ def _train_save_test(self, sModelName, bWarm, lFilename_trn, ts_trn, lFilename_t
         chronoOn("MdlTrn")
         mdl.train(lGraph_trn, lGraph_vld, True, ts_trn, verbose=1 if self.bVerbose else 0)
         mdl.save()
-        self.traceln(" done [%.1fs]"%chronoOff("MdlTrn"))
+        tTrn = chronoOff("MdlTrn")
+        self.traceln(" training done [%.1f s]  (%s)" % (tTrn, pretty_time_delta(tTrn)))
 
         # OK!!
         self._mdl = mdl
@@ -966,9 +999,7 @@ def listMaxTimestampFile(cls, lsDir, sPattern, bIgnoreDUFiles=True):
     listMaxTimestampFile = classmethod(listMaxTimestampFile)
 
 
-# ------------------------------------------------------------------------------------------------------------------------------
-
-
+# -----------------------------------------------------------------------------------------------------------------------------
 if __name__ == "__main__":
 
     usage, parser = DU_Task.getStandardOptionsParser(sys.argv[0])
diff --git a/TranskribusDU/tasks/DU_Task_Factory.py b/TranskribusDU/tasks/DU_Task_Factory.py
index c34491e..02fbed7 100644
--- a/TranskribusDU/tasks/DU_Task_Factory.py
+++ b/TranskribusDU/tasks/DU_Task_Factory.py
@@ -5,18 +5,7 @@
     Copyright NAVER(C) 2019 JL. Meunier
 
-    This program is free software: you can redistribute it and/or modify
-    it under the terms of the GNU General Public License as published by
-    the Free Software Foundation, either version 3 of the License, or
-    (at your option) any later version.
-
-    This program is distributed in the hope that it will be useful,
-    but WITHOUT ANY WARRANTY; without even the implied warranty of
-    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-    GNU General Public License for more details.
-
-    You should have received a copy of the GNU General Public License
-    along with this program.  If not, see <http://www.gnu.org/licenses/>.
+
 
     Developed for the EU project READ. The READ project has received funding
@@ -38,15 +27,16 @@
 from common.trace import traceln
 from graph.Graph import Graph
 from tasks.DU_CRF_Task import DU_CRF_Task
-from tasks.DU_ECN_Task import DU_ECN_Task
+from tasks.DU_ECN_Task import DU_ECN_Task, DU_Ensemble_ECN_Task
 from tasks.DU_GAT_Task import DU_GAT_Task
 
+
 class DU_Task_Factory:
     VERSION = "Factory_19"
 
     version          = None     # dynamically computed
 
-    l_CHILDREN_CLASS = [DU_CRF_Task, DU_ECN_Task, DU_GAT_Task]
+    l_CHILDREN_CLASS = [DU_CRF_Task, DU_ECN_Task, DU_Ensemble_ECN_Task, DU_GAT_Task]
     # faster load for debug...
     # l_CHILDREN_CLASS = [DU_CRF_Task]
 
     @classmethod
@@ -55,6 +45,7 @@ def getStandardOptionsParser(cls, sys_argv0=None):
 or for a cross-validation [--fold-init <N>] [--fold-run <n> [-w]] [--fold-finish] [--fold <n>]+
 [--pkl] [--g1|--g2]
+[--server <port>]
 
 For the named MODEL using the given FOLDER for storage:
 --rm : remove all model data from the folder
@@ -77,6 +68,9 @@ def getStandardOptionsParser(cls, sys_argv0=None):
 --graph : store the graph in the output XML
 --g1 : default mode (historical): edges created only to closest overlapping block (downward and rightward)
 --g2 : implements the line-of-sight edges (when in line of sight, then link by an edge)
+--server <port> : run in server mode, offering a predict method
+--server_debug : run the server in debug mode
+--outxml : output PageXML files, whatever the input format is.
"""%sys_argv0 #prepare for the parsing of the command line @@ -116,6 +110,14 @@ def getStandardOptionsParser(cls, sys_argv0=None): , help="default mode (historical): edges created only to closest overlapping block (downward and rightward)") parser.add_option("--g2", dest='bG2', action="store_true" , help="implements the line-of-sight edges (when in line of sight, then link the nodes by an edge)") + parser.add_option("--ext", dest='sExt', action="store", type="string" + , help="Expected extension of the data files, e.g. '.pxml'") + parser.add_option("--server", dest='iServer', action="store", type="int" + , help="run in server mode, offering a predict method, for the given model") + parser.add_option("--server_debug", dest='bServerDebug', action="store_true" + , help="run the server in debug mode (incompatible with TensorFLow)") + parser.add_option("--outxml", dest='bOutXML', action="store_true" + , help="output XML files, whatever the input format is.") # consolidate... @@ -139,13 +141,10 @@ def getVersion(cls): @classmethod def getDoer(cls, sModelDir, sModelName , options = None - , bCRF = None - , bECN = None - , bGAT = None , fun_getConfiguredGraphClass = None , sComment = None , cFeatureDefinition = None - , dFeatureConfig = {} + , dFeatureConfig = {} ): """ Create the requested doer object @@ -160,20 +159,25 @@ def getDoer(cls, sModelDir, sModelName if options.bG2: iGraphMode = 2 Graph.setGraphMode(iGraphMode) - bCRF = bCRF or (not(options is None) and options.bCRF) - bECN = bECN or (not(options is None) and options.bECN) - bGAT = bGAT or (not(options is None) and options.bGAT) - - assert (bCRF or bECN or bGAT) , "You must specify one learning method." - assert [bCRF, bECN, bGAT].count(True) == 1 , "You must specify only one learning method." +# bCRF = bCRF or (not(options is None) and options.bCRF) +# bECN = bECN or (not(options is None) and options.bECN) +# bECNEnsemble = bECNEnsemble or (not(options is None) and options.bECN) +# bGAT = bGAT or (not(options is None) and options.bGAT) + + assert (options.bCRF + or options.bECN or options.bECNEnsemble + or options.bGAT) , "You must specify one learning method." + assert [options.bCRF, options.bECN, options.bECNEnsemble, options.bGAT].count(True) == 1 , "You must specify only one learning method." - if bECN: + if options.bECN: c = DU_ECN_Task - elif bCRF: + elif options.bECNEnsemble: + c = DU_Ensemble_ECN_Task + elif options.bCRF: c = DU_CRF_Task - elif bGAT: + elif options.bGAT: c = DU_GAT_Task - + c.getConfiguredGraphClass = fun_getConfiguredGraphClass doer = c(sModelName, sModelDir @@ -181,6 +185,9 @@ def getDoer(cls, sModelDir, sModelName , cFeatureDefinition = cFeatureDefinition , dFeatureConfig = dFeatureConfig) + if options.sExt: + doer.setXmlFilenamePattern(options.sExt) + if options.seed is None: random.seed() traceln("SETUP: Randomizer initialized automatically") diff --git a/TranskribusDU/tasks/DU_Task_Features.py b/TranskribusDU/tasks/DU_Task_Features.py index 1ad3ef0..9f65699 100644 --- a/TranskribusDU/tasks/DU_Task_Features.py +++ b/TranskribusDU/tasks/DU_Task_Features.py @@ -5,18 +5,7 @@ Copyright NAVER(C) 2019 JL. Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. 
- - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. The READ project has received funding @@ -24,6 +13,7 @@ under grant agreement No 674943. """ +from sklearn.preprocessing.data import QuantileTransformer from graph.Edge import HorizontalEdge, VerticalEdge @@ -39,6 +29,14 @@ from graph.FeatureDefinition_Standard import EdgeClassShifter from graph.Transformer import Pipeline, FeatureUnion +from graph.pkg_GraphBinaryConjugateSegmenter.PageXmlSeparatorRegion import Separator_boolean, Separator_num + + +# EDGES +# which types of edge can we get?? +# It depends on the type of graph!! +lEdgeClass = [HorizontalEdge, VerticalEdge] + class Features_June19_Simple(FeatureDefinition): """ @@ -50,24 +48,29 @@ class Features_June19_Simple(FeatureDefinition): n_QUANTILES = 16 bShiftEdgeByClass = False - + bSeparator = False + def __init__(self): FeatureDefinition.__init__(self) # NODES - node_transformer = FeatureUnion([ \ + self.lNodeFeature = [ ("geometry" , Node_Geometry()) # one can set nQuantile=... - ]) + ] + node_transformer = FeatureUnion(self.lNodeFeature) # EDGES - # which types of edge can we get?? - # It depends on the type of graph!! - lEdgeClass = [HorizontalEdge, VerticalEdge] # standard set of features, including a constant 1 for CRF - edge_transformer = FeatureUnion([ \ + self.lEdgeFeature = [ ('1hot' , Edge_Type_1Hot(lEdgeClass=lEdgeClass)) # Edge class 1 hot encoded (PUT IT FIRST) , ('geom' , Edge_Geometry()) # one can set nQuantile=... - ]) + ] + if self.bSeparator: + self.lEdgeFeature = self.lEdgeFeature + [ + ('sprtr_bool', Separator_boolean()) + , ('sprtr_num' , Separator_num()) + ] + edge_transformer = FeatureUnion(self.lEdgeFeature) # OPTIONNALLY, you can have one range of features per type of edge. # the 1-hot encoding must be the first part of the union and it will determine @@ -101,12 +104,13 @@ class Features_June19_Full(FeatureDefinition): n_QUANTILES = 16 bShiftEdgeByClass = False + bSeparator = False def __init__(self): FeatureDefinition.__init__(self) # NODES - node_transformer = FeatureUnion([ \ + self.lNodeFeature = [ \ ("geometry" , Node_Geometry()) # one can set nQuantile=... , ("neighbor_count" , Node_Neighbour_Count()) # one can set nQuantile=... , ("text" , Node_Text_NGram( 'char' # character n-grams @@ -114,14 +118,15 @@ def __init__(self): , (2,3) # N , False # lowercase?)) )) - ]) - + ] + node_transformer = FeatureUnion(self.lNodeFeature) + # EDGES # which types of edge can we get?? # It depends on the type of graph!! lEdgeClass = [HorizontalEdge, VerticalEdge] # standard set of features, including a constant 1 for CRF - fu = FeatureUnion([ \ + self.lEdgeFeature = [ \ ('1hot' , Edge_Type_1Hot(lEdgeClass=lEdgeClass)) # Edge class 1 hot encoded (PUT IT FIRST) , ('1' , Edge_1()) # optional constant 1 for CRF , ('geom' , Edge_Geometry()) # one can set nQuantile=... @@ -135,7 +140,13 @@ def __init__(self): , (2,3) # N , False # lowercase?)) )) - ]) + ] + if self.bSeparator: + self.lEdgeFeature = self.lEdgeFeature + [ + ('sprtr_bool', Separator_boolean()) + , ('sprtr_num' , Separator_num()) + ] + fu = FeatureUnion(self.lEdgeFeature) # you can use directly this union of features! 
         edge_transformer = fu
 
@@ -161,3 +172,27 @@ class Features_June19_Full_Shift(Features_June19_Full):
     """
     bShiftEdgeByClass = True
 
+# --- Separator ------------------------------------------------------
+class Features_June19_Simple_Separator(Features_June19_Simple):
+    """
+    Same as Features_June19_Simple, with additional features on edges
+    """
+    bSeparator = True
+
+
+class Features_June19_Full_Separator(Features_June19_Full):
+    """
+    Same as Features_June19_Full, with additional features on edges
+    """
+    bSeparator = True
+
+
+# --- Separator Shifted ------------------------------------------------------
+class Features_June19_Simple_Separator_Shift(Features_June19_Simple_Separator
+                                             , Features_June19_Simple_Shift):
+    pass
+
+
+class Features_June19_Full_Separator_Shift(Features_June19_Full_Separator
+                                           , Features_June19_Full_Shift):
+    pass
diff --git a/TranskribusDU/tasks/DU_analyze_collection.py b/TranskribusDU/tasks/DU_analyze_collection.py
new file mode 100644
index 0000000..d9ba20b
--- /dev/null
+++ b/TranskribusDU/tasks/DU_analyze_collection.py
@@ -0,0 +1,405 @@
+# -*- coding: utf-8 -*-
+
+"""
+    Utility to compute statistics regarding a PageXml collection.
+
+    How many documents? pages? objects? labels?
+
+    The raw result is stored as a pickle file, and in a CSV file (in a future version!!!).
+    The statistics are reported on stdout.
+
+    Copyright Xerox(C) 2017 JL. Meunier
+
+
+
+
+    Developed for the EU project READ. The READ project has received funding
+    from the European Union's Horizon 2020 research and innovation programme
+    under grant agreement No 674943.
+
+"""
+
+
+
+
+import sys, os, collections, pickle, glob
+from lxml import etree
+import re
+import gc
+from optparse import OptionParser
+
+try: #to ease the use without proper Python installation
+    import TranskribusDU_version
+except ImportError:
+    sys.path.append( os.path.dirname(os.path.dirname( os.path.abspath(sys.argv[0]) )) )
+    import TranskribusDU_version
+
+from xml_formats.PageXml import PageXml
+
+# ===============================================================================================================
+#DEFINING THE CLASS OF GRAPH WE USE
+
+# ===============================================================================================================
+
+class DoubleHistogram:
+    """
+    Double keyed histogram
+    """
+    def __init__(self, name):
+        self.name = name
+        self.dCnt = collections.defaultdict(lambda : collections.defaultdict(int) )
+
+    def seenK1K2(self, k1, k2):
+        self.dCnt[k1][k2] += 1
+
+    #--- First Key
+    def addFirstKeys(self, lk1):
+        """
+        Make sure those keys are present in the histogram, possibly with a count of zero
+        """
+        for k1 in lk1: self.dCnt[k1]
+
+    def getFirstKeyList(self):
+        """
+        return the sorted list of first keys
+        """
+        l = list(self.dCnt.keys()); l.sort()
+        return l
+
+    #--- Second Key
+    def getAllSecondKeys(self):
+        setK = set()
+        for k in self.getFirstKeyList():
+            setK = setK.union( self.getSecondKeyList(k) )
+        return list(setK)
+
+    def getSecondKeyList(self, k):
+        """
+        return the sorted list of observed labels for this tag
+        """
+        l = list(self.dCnt[k].keys()); l.sort()
+        return l
+
+    def getSecondKeyCountList(self, k):
+        """
+        return the count of observed second keys, in same order as the second key list, for that first key
+        """
+        return [self.dCnt[k][v] for v in self.getSecondKeyList(k)]
+
+    def getCount(self, k1, k2): return self.dCnt[k1][k2]
+
+    #--- Sum
+    def getSumByFirstKey(self, k1):
+        """
+        return the sum of counts of observed second keys, for that first key
+        """
+        return sum( self.dCnt[k1][v] for v in self.getSecondKeyList(k1) )
+
+    def getSumBySecondKey(self, k2):
+        """
+        return the sum of counts of observed first keys, for that second key
+        """
+        cnt = 0
+        for k1 in self.getFirstKeyList():
+            if k2 in self.getSecondKeyList(k1): cnt += self.getCount(k1, k2)
+        return cnt
+
+class CollectionAnalyzer:
+    def __init__(self, lTag):
+        self.start()
+        self.lTag = lTag  #all tag names
+
+    def start(self):
+        """
+        reset any accumulated data
+        """
+        self.hPageCountPerDoc = DoubleHistogram("Page count stat")
+        self.hTagCountPerDoc  = DoubleHistogram("Tag stat per document")
+        self.hLblCountPerTag  = DoubleHistogram("Label stat per tag")
+
+        self.lDoc = None     #all doc names
+        self.lNbPage = None
+
+    def runPageXml(self, sDir):
+        """
+        process one folder per document
+        """
+        assert False, "Method must be specialized"
+
+    def runMultiPageXml(self, sDir):
+        """
+        process one PXML per document
+        """
+        assert False, "Method must be specialized"
+
+    def end(self):
+        """
+        Consolidate the gathered data
+        """
+        self.lDoc = self.hPageCountPerDoc.getFirstKeyList()  #all docs are listed here
+        self.hTagCountPerDoc.addFirstKeys(self.lDoc)  #to make sure we have all of them listed, even those without tags of interest
+        self.lObservedTag = self.hTagCountPerDoc.getAllSecondKeys()  #all tags of interest observed in dataset
+
+        self.lNbPage = list()
+        for doc in self.lDoc:
+            lNb = self.hPageCountPerDoc.getSecondKeyList(doc)
+            assert len(lNb) == 1
+            self.lNbPage.append(lNb[0])
+        #label list per tag: self.hLblCountPerTag.getSecondKeyList(tag)
+
+    def save(self, filename):
+        t = (self.hPageCountPerDoc, self.hTagCountPerDoc, self.hLblCountPerTag)
+        with open(filename, "wb") as fd: pickle.dump(t, fd)
+
+    def load(self, filename):
+        with open(filename, "rb") as fd:
+            self.hPageCountPerDoc, self.hTagCountPerDoc, self.hLblCountPerTag = pickle.load(fd)
+
+    def prcnt(self, num, totnum):
+        if totnum == 0:
+            return "n/a"
+        else:
+            f = num * 100.0 / totnum
+            if 0.0 < f and f < 2.0:
+                return "%.1f%%" % f
+            else:
+                return "%.0f%%" % f
+
+    def report(self):
+        """
+        report on accumulated data so far
+        """
+        print( "-"*60)
+
+        print( " ----- %d documents, %d pages" % (len(self.lDoc), sum(self.lNbPage)))
+        for doc, nb in zip(self.lDoc, self.lNbPage):
+            print( "\t---- %40s  %6d pages" % (doc, nb))
+
+        print()
+        print( " ----- %d objects of interest (%d observed): %s" % (len(self.lTag), len(self.lObservedTag), self.lTag))
+        for doc in self.lDoc:
+            print( "\t---- %s  %6d occurrences" % (doc, self.hTagCountPerDoc.getSumByFirstKey(doc)))
+            for tag in self.lObservedTag:
+                print( "\t\t--%20s  %6d occurrences" % (tag, self.hTagCountPerDoc.getCount(doc, tag)))
+        print()
+        for tag in self.lObservedTag:
+            print( "\t-- %s  %6d occurrences" % (tag, self.hTagCountPerDoc.getSumBySecondKey(tag)))
+            for doc in self.lDoc:
+                print( "\t\t---- %40s  %6d occurrences" % (doc, self.hTagCountPerDoc.getCount(doc, tag)))
+
+        print()
+        print( " ----- Label frequency for ALL %d objects of interest: %s" % (len(self.lTag), self.lTag))
+        for tag in self.lTag:
+            totnb = self.hTagCountPerDoc.getSumBySecondKey(tag)
+            totnblabeled = self.hLblCountPerTag.getSumByFirstKey(tag)
+            print( "\t-- %s  %6d occurrences  %d labelled" % (tag, totnb, totnblabeled))
+            for lbl in self.hLblCountPerTag.getSecondKeyList(tag):
+                nb = self.hLblCountPerTag.getCount(tag, lbl)
+                print( "\t\t- %20s  %6d occurrences\t(%5s) (%5s)" % (lbl,
+                                                                     nb,
+                                                                     self.prcnt(nb, totnb),
+                                                                     self.prcnt(nb, totnblabeled)))
+            nb = totnb - totnblabeled
+            lbl = ""
+            print( "\t\t- %20s  %6d occurrences\t(%5s)" % (lbl, nb, self.prcnt(nb, totnb)))
+
+        print( "-"*60)
+        return ""
+
+    def seenDocPageCount(self, doc, pagecnt):
+        self.hPageCountPerDoc.seenK1K2(doc, pagecnt)  #strange way to indicate the page count of a doc....
+    def seenDocTag(self, doc, tag):
+        self.hTagCountPerDoc.seenK1K2(doc, tag)
+    def seenTagLabel(self, tag, lbl):
+        self.hLblCountPerTag.seenK1K2(tag, lbl)
+
+class PageXmlCollectionAnalyzer(CollectionAnalyzer):
+    """
+    Analyse a collection of PageXml documents
+    """
+
+    dNS = {"pg":"http://schema.primaresearch.org/PAGE/gts/pagecontent/2013-07-15"}
+
+    def __init__(self, sDocPattern, sPagePattern, lTag, sCustom=None):
+        """
+        sRootDir is the root directory of the collection
+        sDocPattern is the pattern followed by folders, assuming one folder contains one document
+        sPagePattern is the pattern followed by each PageXml file, assuming one file contains one PageXml XML
+        ltTagAttr is the list of pairs of tag of interest and label attribute
+        """
+        CollectionAnalyzer.__init__(self, lTag)
+        self.sDocPattern  = sDocPattern
+        self.sPagePattern = sPagePattern
+        self.lTag = lTag
+        self.sCustom = sCustom
+        self.ltCRES = []  #list of tuples (cre, replacement-string)
+
+    def setLabelPattern(self, sRE, sRepl):
+        """
+        replace any occurrence of the pattern by the replacement string in a label
+        """
+        self.ltCRES.append( (re.compile(sRE), sRepl) )
+
+    def runPageXml(self, sRootDir):
+        lFolder = [os.path.basename(folder) for folder in glob.iglob(os.path.join(sRootDir, self.sDocPattern))
+                   if os.path.isdir(folder)]
+        lFolder.sort()
+        print( "Documents: ", lFolder)
+
+        for docdir in lFolder:
+            print( "Document ", docdir)
+            lPageFile = [os.path.basename(name) for name in glob.iglob(os.path.join(sRootDir, docdir, self.sPagePattern))
+                         if os.path.isfile(os.path.join(sRootDir, docdir, name))]
+            lPageFile.sort()
+            self.seenDocPageCount(docdir, len(lPageFile))
+            for sPageFile in lPageFile:
+                print(".", end="")
+                doc = etree.parse(os.path.join(sRootDir, docdir, sPageFile))
+                self.parsePage(doc, doc.getroot(), docdir)
+                doc = None
+                gc.collect()
+            print()
+            sys.stdout.flush()
+
+    def runMultiPageXml(self, sRootDir):
+        print( os.path.join(sRootDir, self.sDocPattern))
+        print( glob.glob(os.path.join(sRootDir, self.sDocPattern)))
+        lDocFile = [os.path.basename(filename) for filename in glob.iglob(os.path.join(sRootDir, self.sDocPattern))
+                    if os.path.isfile(filename)]
+        lDocFile.sort()
+        print( "Documents: ", lDocFile)
+
+        for docFile in lDocFile:
+            print( "Document ", docFile)
+            doc = etree.parse(os.path.join(sRootDir, docFile))
+            lNdPage = doc.getroot().xpath("//pg:Page",
+                                          namespaces=self.dNS)
+            self.seenDocPageCount(docFile, len(lNdPage))
+            for ndPage in lNdPage:
+                print(".", end="")
+                self.parsePage(doc, ndPage, docFile)
+            print()
+            sys.stdout.flush()
+
+    def parsePage(self, doc, ctxtNd, name):
+        for tag in self.lTag:
+            lNdTag = ctxtNd.xpath(".//pg:%s" % tag, namespaces=self.dNS)
+            for nd in lNdTag:
+                self.seenDocTag(name, tag)
+                if self.sCustom is not None:
+                    if self.sCustom == "":
+                        try:
+                            lbl = PageXml.getCustomAttr(nd, "structure", "type")
+                        except:
+                            lbl = ''
+                    else:
+                        lbl = nd.get(self.sCustom)
+                else:
+                    lbl = nd.get("type")
+
+                if lbl:
+                    for cre, sRepl in self.ltCRES: lbl = cre.sub(sRepl, lbl)  #pattern processing
+                    self.seenTagLabel(tag, lbl)
+
+
+def test_simple():
+    sTESTS_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
+                              "tests")
+
+    sDATA_DIR = os.path.join(sTESTS_DIR, "data")
+
+    doer = PageXmlCollectionAnalyzer("*.mpxml",
+                                     None,
+                                     ["Page", "TextRegion", "TextLine"],
+                                     #["type"],
+                                     sCustom="")
+    doer.start()
+    
doer.runMultiPageXml(os.path.join(sDATA_DIR, "abp_TABLE_9142_mpxml", "col"))
+    doer.end()
+    sReport = doer.report()
+    print( sReport)
+
+if __name__ == "__main__":
+
+    if False:
+        test_simple()
+
+    sUsage="""Usage: %s sRootDir sDocPattern [sPagePattern]
+For Multi-PageXml, only root directory and document pattern (2 arguments, e.g. 9142_GTRC/col '*.mpxml' )
+For PageXml, give also the Xml page pattern (3 arguments, e.g. 9142_GTRC/col '[0-9]+' '*.mpxml')
+""" % sys.argv[0]
+
+    #prepare for the parsing of the command line
+    parser = OptionParser(usage=sUsage)
+
+#     parser.add_option("--dir", dest='lTrn',  action="store", type="string"
+#                       , help="Train or continue previous training session using the given annotated collection.")
+#     parser.add_option("--tst", dest='lTst',  action="store", type="string"
+#                       , help="Test a model using the given annotated collection.")
+#     parser.add_option("--run", dest='lRun',  action="store", type="string"
+#                       , help="Run a model on the given non-annotated collection.")
+#     parser.add_option("-w", "--warm", dest='warm',  action="store_true"
+#                       , help="Attempt to warm-start the training")
+    parser.add_option("-c", "--custom", dest='custom', action="store", type="string"
+                      , help='With --custom="", it reads the @custom Xml attribute instead of @type; if you specify --custom=toto, it will read the @toto attribute.')
+    parser.add_option("--pattern", dest='pattern', action="store"
+                      , help="Replace the given pattern in the label by # (specific for BAR so far...)")
+
+    # ---
+#     bMODEUN = True
+
+    #parse the command line
+    (options, args) = parser.parse_args()
+    # ---
+    try:
+        try:
+            sRootDir, sDocPattern, sPagePattern = args[0:3]
+            bMultiPageXml = False
+        except:
+            sRootDir, sDocPattern = args[0:2]
+            bMultiPageXml = True
+            sPagePattern = None
+    except:
+        print(sUsage)
+        exit(1)
+
+    #all tags supporting the attribute type in PageXml 2003
+    lTag = ["Page", "TextRegion", "GraphicRegion", "CharRegion", "RelationType"]
+    #Pragmatism: don't think we will have annotated pages
+    lTag = ["TextRegion", "GraphicRegion", "CharRegion", "RelationType"]
+    #Pragmatism: we may also have tagged TextLine ... 
+ lTag.append("TextLine") + + print( sRootDir, sDocPattern, sPagePattern, lTag) + +# if bMODEUN: +# #all tag supporting the attribute type in PageXml 2003 +# ltTagAttr = [ (name, "type") for name in ["Page", "TextRegion", "GraphicRegion", "CharRegion", "RelationType"]] +# else: +# ls = args[3:] +# ltTagAttr = zip(ls[slice(0, len(ls), 2)], ls[slice(1, len(ls), 2)]) +# print( sRootDir, sDocPattern, sPagePattern, ltTagAttr) +# except: +# # if bMODEUN: +# # print( "Usage: %s sRootDir sDocPattern [sPagePattern]"%(sys.argv[0] )) +# # else: +# # print( "Usage: %s sRootDir sDocPattern [sPagePattern] [Tag Attr]+"%(sys.argv[0] )) +# exit(1) + + doer = PageXmlCollectionAnalyzer(sDocPattern, sPagePattern, lTag, sCustom=options.custom) + if options.pattern != None: + doer.setLabelPattern(options.pattern, "#") + + doer.start() + if bMultiPageXml: + print( "--- MultiPageXml ---") + doer.runMultiPageXml(sRootDir) + else: + print( "--- PageXml ---") + doer.runPageXml(sRootDir) + + doer.end() + sReport = doer.report() + + print( sReport) + diff --git a/TranskribusDU/tasks/DU_split_collection.py b/TranskribusDU/tasks/DU_split_collection.py index b9add87..2d781ae 100644 --- a/TranskribusDU/tasks/DU_split_collection.py +++ b/TranskribusDU/tasks/DU_split_collection.py @@ -3,20 +3,9 @@ """ DU task: split a collection in N equal parts, at random - Copyright NAVER(C) 2019 Jean-Luc Meunier + Copyright Xerox(C) 2019 Jean-Luc Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . Developed for the EU project READ. The READ project has received funding from the European Union's Horizon 2020 research and innovation programme @@ -26,6 +15,8 @@ import sys, os, random from shutil import copyfile +from optparse import OptionParser +import math try: #to ease the use without proper Python installation import TranskribusDU_version @@ -41,15 +32,27 @@ if __name__ == "__main__": # import better_exceptions # better_exceptions.MAX_LENGTH = None + sUsage = """ +USAGE: %s DIR ( N | p1,p2(,p)+ ) +Split in N folders + or +Split in folders following the proportions p1, ... 
pN + +The folders are named after the DIR folder by adding suffix_part_<1 to N> + +(Expecting to find a 'col' subfolder in DIR)""" % sys.argv[0] + + parser = OptionParser(usage=sUsage) + (options, args) = parser.parse_args() try: - sDir = sys.argv[1] - n = int(sys.argv[2]) + sDir, sN = args except: - print("USAGE: %s DIR N"%sys.argv[0]) + print(sUsage) exit(1) sColDir= os.path.join(sDir, "col") + assert os.path.isdir(sColDir), "%s is not a folder"%sColDir print("- looking at ", sColDir) lsFile = [] @@ -60,16 +63,36 @@ if not(_fnl.endswith(".mpxml") or _fnl.endswith(".pxml")): continue lsFile.append(_fn) - traceln(" %d files to split in %d parts" % (len(lsFile), n)) + + nbFile = len(lsFile) + try: + lP = [int(_s) for _s in sN.split(',')] + if len(lP) < 2: raise ValueError("want to run the except code") + lP = [p / sum(lP) for p in lP] + traceln(" %d files to split in %d parts with proportions %s" % ( + nbFile + , len(lP) + , ",".join("%.2f"%_p for _p in lP))) + lP.sort() + ld = [] + for i, p in enumerate(lP): + ld += [i] * math.ceil(p * nbFile) + ld = ld[:nbFile] + while len(ld) < nbFile: ld.append(len(lP)-1) + random.shuffle(ld) + except ValueError: + # Split in N parts + traceln(" %d files to split in %d parts" % (nbFile, int(sN))) + n = int(sN) - N = len(lsFile) - ld = getSplitIndexList(N, n, traceln) - assert len(ld) == N + ld = getSplitIndexList(nbFile, n, traceln) + assert len(ld) == nbFile - # *** SHUFFLING!! *** - random.shuffle(ld) + # *** SHUFFLING!! *** + random.shuffle(ld) + # ld [I] gives the folder index where to put the Ith file def get_sToColDir(sDir, d, bExistIsOk=False): """ @@ -92,10 +115,11 @@ def get_sToColDir(sDir, d, bExistIsOk=False): raise Exception("First remove the destination folders: ", (sToDir, sToColDir)) return sToColDir + assert len(ld) == len(lsFile) # make sure the folder are not already containing some stuff (from previous runs...) - for _d in range(1, n+1): - get_sToColDir(sDir, _d, bExistIsOk=False) + for _d in set(ld): + get_sToColDir(sDir, _d+1, bExistIsOk=False) ld = [1+d for d in ld] # convenience for d, sFilename in zip(ld, lsFile): diff --git a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableAnnotation.py b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableAnnotation.py index 2d37159..e115a5f 100644 --- a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableAnnotation.py +++ b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableAnnotation.py @@ -5,18 +5,7 @@ Copyright Xerox(C) 2017 H. Déjean - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. The READ project has received funding @@ -223,16 +212,16 @@ def annotateDocument(self,lsTrnColDir): ## TEXT for tl in lTextLine: try: - sLabel = tl.type.parseDomNodeLabel(tl.node) + sLabel = tl.type.parseDocNodeLabel(tl) # cls = DU_GRAPH._dClsByLabel[sLabel] #Here, if a node is not labelled, and no default label is set, then KeyError!!! 
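As a worked example of the proportional-split branch above — a standalone sketch in which the file count and the "1,1,2" proportions are made up:

    import math, random

    nbFile = 10                             # number of .mpxml/.pxml files found
    lP = [1, 1, 2]                          # proportions from the command line, e.g. "1,1,2"
    lP = [p / sum(lP) for p in lP]          # -> [0.25, 0.25, 0.5]
    lP.sort()
    ld = []
    for i, p in enumerate(lP):
        ld += [i] * math.ceil(p * nbFile)   # -> [0]*3 + [1]*3 + [2]*5  (11 entries)
    ld = ld[:nbFile]                        # truncate to exactly nbFile entries
    while len(ld) < nbFile: ld.append(len(lP) - 1)
    random.shuffle(ld)                      # ld[I] is the folder index of the Ith file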
# except KeyError: except ValueError: tl.node.setProp(tl.type.sLabelAttr,lLabels[4]) ## SEP for sep in lSeparator: -# sLabel = sep.type.parseDomNodeLabel(sep.node) +# sLabel = sep.type.parseDocNodeLabel(sep) try: - sLabel = sep.type.parseDomNodeLabel(sep.node) + sLabel = sep.type.parseDocNodeLabel(sep) # cls = DU_GRAPH._dClsByLabel[sLabel] #Here, if a node is not labelled, and no default label is set, then KeyError!!! except ValueError: sep.node.setProp(sep.type.sLabelAttr,lLabels[6]) diff --git a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableCutPredictor.py b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableCutPredictor.py index f1b45e4..82bf3a9 100644 --- a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableCutPredictor.py +++ b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableCutPredictor.py @@ -70,18 +70,7 @@ Copyright Naver Labs Europe 2018 JL Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. The READ project has received funding diff --git a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableGrid.py b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableGrid.py index faac880..4aa8ea5 100644 --- a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableGrid.py +++ b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableGrid.py @@ -6,18 +6,7 @@ Copyright Naver Labs Europe 2018 JL Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . Developed for the EU project READ. The READ project has received funding diff --git a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableH.py b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableH.py index b89f1b1..c309212 100644 --- a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableH.py +++ b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableH.py @@ -5,18 +5,7 @@ Copyright Naver Labs Europe(C) 2018 H. Déjean, JL Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. 
The READ project has received funding diff --git a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableR.py b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableR.py index 3c511d0..a8593ec 100644 --- a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableR.py +++ b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableR.py @@ -5,18 +5,7 @@ Copyright Naver Labs Europe(C) 2018 H. Déjean, JL Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. The READ project has received funding diff --git a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRC.py b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRC.py index eca280d..28d521f 100644 --- a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRC.py +++ b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRC.py @@ -5,18 +5,7 @@ Copyright Naver Labs Europe(C) 2018 H. Déjean - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . Developed for the EU project READ. The READ project has received funding diff --git a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRCAnnotation_checker.py b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRCAnnotation_checker.py index 5791ed1..eef79ab 100644 --- a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRCAnnotation_checker.py +++ b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRCAnnotation_checker.py @@ -6,18 +6,7 @@ Copyright Naver Labs Europe 2018 JL Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. 
The READ project has received funding diff --git a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRCut.py b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRCut.py index 583c0f7..ca22626 100644 --- a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRCut.py +++ b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRCut.py @@ -9,18 +9,7 @@ Copyright Naver Labs Europe(C) 2018 JL Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. The READ project has received funding @@ -85,7 +74,7 @@ def setClassicNodeTypeList(cls, lNodeType): """ cls._lClassicNodeType = lNodeType - def parseXmlFile(self, sFilename, iVerbose=0): + def parseDocFile(self, sFilename, iVerbose=0): """ Load that document as a CRF Graph. Also set the self.doc variable! @@ -119,7 +108,7 @@ def parseXmlFile(self, sFilename, iVerbose=0): lClassicType = [nt for nt in self.getNodeTypeList() if nt in self._lClassicNodeType] lSpecialType = [nt for nt in self.getNodeTypeList() if nt not in self._lClassicNodeType] - for (pnum, page, domNdPage) in self._iter_Page_DomNode(self.doc): + for (pnum, page, domNdPage) in self._iter_Page_DocNode(self.doc): #now that we have the page, let's create the node for each type! lClassicPageNode = [nd for nodeType in lClassicType for nd in nodeType._iter_GraphNode(self.doc, domNdPage, page) ] lSpecialPageNode = [nd for nodeType in lSpecialType for nd in nodeType._iter_GraphNode(self.doc, domNdPage, page) ] diff --git a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRCut1.py b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRCut1.py index edbf959..c55f925 100644 --- a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRCut1.py +++ b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRCut1.py @@ -9,18 +9,7 @@ Copyright Naver Labs Europe(C) 2018 JL Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. The READ project has received funding @@ -87,7 +76,7 @@ def setClassicNodeTypeList(cls, lNodeType): """ cls._lClassicNodeType = lNodeType - def parseXmlFile(self, sFilename, iVerbose=0): + def parseDocFile(self, sFilename, iVerbose=0): """ Load that document as a CRF Graph. Also set the self.doc variable! 
@@ -121,7 +110,7 @@ def parseXmlFile(self, sFilename, iVerbose=0): lClassicType = [nt for nt in self.getNodeTypeList() if nt in self._lClassicNodeType] lSpecialType = [nt for nt in self.getNodeTypeList() if nt not in self._lClassicNodeType] - for (pnum, page, domNdPage) in self._iter_Page_DomNode(self.doc): + for (pnum, page, domNdPage) in self._iter_Page_DocNode(self.doc): #now that we have the page, let's create the node for each type! lClassicPageNode = [nd for nodeType in lClassicType for nd in nodeType._iter_GraphNode(self.doc, domNdPage, page) ] lSpecialPageNode = [nd for nodeType in lSpecialType for nd in nodeType._iter_GraphNode(self.doc, domNdPage, page) ] @@ -648,7 +637,7 @@ def evalClusterByRow(self, sFilename): lSpecialType = [nt for nt in self.getNodeTypeList() if nt not in self._lClassicNodeType] #load the block nodes per page - for (pnum, page, domNdPage) in self._iter_Page_DomNode(self.doc): + for (pnum, page, domNdPage) in self._iter_Page_DocNode(self.doc): #now that we have the page, let's create the node for each type! lClassicPageNode = [nd for nodeType in lClassicType for nd in nodeType._iter_GraphNode(self.doc, domNdPage, page) ] lSpecialType = [nt for nt in self.getNodeTypeList() if nt not in self._lClassicNodeType] diff --git a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRCut1SIO.py b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRCut1SIO.py index 908bb32..bd55b3b 100644 --- a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRCut1SIO.py +++ b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRCut1SIO.py @@ -11,18 +11,7 @@ Copyright Naver Labs Europe(C) 2018 JL Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. The READ project has received funding @@ -89,7 +78,7 @@ def setClassicNodeTypeList(cls, lNodeType): """ cls._lClassicNodeType = lNodeType - def parseXmlFile(self, sFilename, iVerbose=0): + def parseDocFile(self, sFilename, iVerbose=0): """ Load that document as a CRF Graph. Also set the self.doc variable! @@ -125,7 +114,7 @@ def parseXmlFile(self, sFilename, iVerbose=0): lClassicType = [nt for nt in self.getNodeTypeList() if nt in self._lClassicNodeType] lSpecialType = [nt for nt in self.getNodeTypeList() if nt not in self._lClassicNodeType] - for (pnum, page, domNdPage) in self._iter_Page_DomNode(self.doc): + for (pnum, page, domNdPage) in self._iter_Page_DocNode(self.doc): #now that we have the page, let's create the node for each type! 
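+            # Note (sketch): "classic" node types get the standard block edges
+            # via Edge.computeEdges(), while the "special" types (the cut
+            # lines) are wired separately by computeSpecialEdges().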
lClassicPageNode = [nd for nodeType in lClassicType for nd in nodeType._iter_GraphNode(self.doc, domNdPage, page) ] lSpecialPageNode = [nd for nodeType in lSpecialType for nd in nodeType._iter_GraphNode(self.doc, domNdPage, page) ] @@ -652,7 +641,7 @@ def evalClusterByRow(self, sFilename): lSpecialType = [nt for nt in self.getNodeTypeList() if nt not in self._lClassicNodeType] #load the block nodes per page - for (pnum, page, domNdPage) in self._iter_Page_DomNode(self.doc): + for (pnum, page, domNdPage) in self._iter_Page_DocNode(self.doc): #now that we have the page, let's create the node for each type! lClassicPageNode = [nd for nodeType in lClassicType for nd in nodeType._iter_GraphNode(self.doc, domNdPage, page) ] lSpecialType = [nt for nt in self.getNodeTypeList() if nt not in self._lClassicNodeType] diff --git a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRCut2.py b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRCut2.py index aeafda4..19d910f 100644 --- a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRCut2.py +++ b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRCut2.py @@ -9,18 +9,7 @@ Copyright Naver Labs Europe(C) 2018 JL Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. The READ project has received funding @@ -85,7 +74,7 @@ def setClassicNodeTypeList(cls, lNodeType): """ cls._lClassicNodeType = lNodeType - def parseXmlFile(self, sFilename, iVerbose=0): + def parseDocFile(self, sFilename, iVerbose=0): """ Load that document as a CRF Graph. Also set the self.doc variable! @@ -119,7 +108,7 @@ def parseXmlFile(self, sFilename, iVerbose=0): lClassicType = [nt for nt in self.getNodeTypeList() if nt in self._lClassicNodeType] lSpecialType = [nt for nt in self.getNodeTypeList() if nt not in self._lClassicNodeType] - for (pnum, page, domNdPage) in self._iter_Page_DomNode(self.doc): + for (pnum, page, domNdPage) in self._iter_Page_DocNode(self.doc): #now that we have the page, let's create the node for each type! lClassicPageNode = [nd for nodeType in lClassicType for nd in nodeType._iter_GraphNode(self.doc, domNdPage, page) ] lSpecialPageNode = [nd for nodeType in lSpecialType for nd in nodeType._iter_GraphNode(self.doc, domNdPage, page) ] diff --git a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRG.py b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRG.py index 0dae259..6a96008 100644 --- a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRG.py +++ b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRG.py @@ -5,18 +5,7 @@ Copyright Naver Labs Europe(C) 2018 JL Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. The READ project has received funding @@ -83,7 +72,7 @@ def setClassicNodeTypeList(cls, lNodeType): """ cls._lClassicNodeType = lNodeType - def parseXmlFile(self, sFilename, iVerbose=0): + def parseDocFile(self, sFilename, iVerbose=0): """ Load that document as a CRF Graph. Also set the self.doc variable! @@ -108,7 +97,7 @@ def parseXmlFile(self, sFilename, iVerbose=0): lClassicType = [nt for nt in self.getNodeTypeList() if nt in self._lClassicNodeType] lSpecialType = [nt for nt in self.getNodeTypeList() if nt not in self._lClassicNodeType] - for pnum, page, domNdPage in self._iter_Page_DomNode(self.doc): + for pnum, page, domNdPage in self._iter_Page_DocNode(self.doc): #now that we have the page, let's create the node for each type! lClassicPageNode = [nd for nodeType in lClassicType for nd in nodeType._iter_GraphNode(self.doc, domNdPage, page) ] lSpecialPageNode = [nd for nodeType in lSpecialType for nd in nodeType._iter_GraphNode(self.doc, domNdPage, page) ] diff --git a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRG2.py b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRG2.py index 946c731..c3644ea 100644 --- a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRG2.py +++ b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRG2.py @@ -7,18 +7,7 @@ Copyright Naver Labs Europe(C) 2018 JL Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. The READ project has received funding diff --git a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRG3.py b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRG3.py index d011a0b..34ce0aa 100644 --- a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRG3.py +++ b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRG3.py @@ -7,18 +7,7 @@ Copyright Naver Labs Europe(C) 2018 JL Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. The READ project has received funding @@ -85,7 +74,7 @@ def setClassicNodeTypeList(cls, lNodeType): """ cls._lClassicNodeType = lNodeType - def parseXmlFile(self, sFilename, iVerbose=0): + def parseDocFile(self, sFilename, iVerbose=0): """ Load that document as a CRF Graph. Also set the self.doc variable! 
@@ -110,7 +99,7 @@ def parseXmlFile(self, sFilename, iVerbose=0): lClassicType = [nt for nt in self.getNodeTypeList() if nt in self._lClassicNodeType] lSpecialType = [nt for nt in self.getNodeTypeList() if nt not in self._lClassicNodeType] - for pnum, page, domNdPage in self._iter_Page_DomNode(self.doc): + for pnum, page, domNdPage in self._iter_Page_DocNode(self.doc): #now that we have the page, let's create the node for each type! lClassicPageNode = [nd for nodeType in lClassicType for nd in nodeType._iter_GraphNode(self.doc, domNdPage, page) ] lSpecialPageNode = [nd for nodeType in lSpecialType for nd in nodeType._iter_GraphNode(self.doc, domNdPage, page) ] diff --git a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRG4.py b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRG4.py index d1fd2ea..e3ace94 100644 --- a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRG4.py +++ b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRG4.py @@ -16,18 +16,7 @@ Copyright Naver Labs Europe(C) 2018 JL Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. The READ project has received funding @@ -95,7 +84,7 @@ def setClassicNodeTypeList(cls, lNodeType): """ cls._lClassicNodeType = lNodeType - def parseXmlFile(self, sFilename, iVerbose=0): + def parseDocFile(self, sFilename, iVerbose=0): """ Load that document as a CRF Graph. Also set the self.doc variable! @@ -128,7 +117,7 @@ def parseXmlFile(self, sFilename, iVerbose=0): lClassicType = [nt for nt in self.getNodeTypeList() if nt in self._lClassicNodeType] lSpecialType = [nt for nt in self.getNodeTypeList() if nt not in self._lClassicNodeType] - for pnum, page, domNdPage in self._iter_Page_DomNode(self.doc): + for pnum, page, domNdPage in self._iter_Page_DocNode(self.doc): #now that we have the page, let's create the node for each type! lClassicPageNode = [nd for nodeType in lClassicType for nd in nodeType._iter_GraphNode(self.doc, domNdPage, page) ] lSpecialPageNode = [nd for nodeType in lSpecialType for nd in nodeType._iter_GraphNode(self.doc, domNdPage, page) ] diff --git a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRG41.py b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRG41.py index 70b80cf..d6e93e4 100644 --- a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRG41.py +++ b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRG41.py @@ -16,18 +16,7 @@ Copyright Naver Labs Europe(C) 2018 JL Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. 
If not, see . + Developed for the EU project READ. The READ project has received funding @@ -95,7 +84,7 @@ def setClassicNodeTypeList(cls, lNodeType): """ cls._lClassicNodeType = lNodeType - def parseXmlFile(self, sFilename, iVerbose=0): + def parseDocFile(self, sFilename, iVerbose=0): """ Load that document as a CRF Graph. Also set the self.doc variable! @@ -128,7 +117,7 @@ def parseXmlFile(self, sFilename, iVerbose=0): lClassicType = [nt for nt in self.getNodeTypeList() if nt in self._lClassicNodeType] lSpecialType = [nt for nt in self.getNodeTypeList() if nt not in self._lClassicNodeType] - for pnum, page, domNdPage in self._iter_Page_DomNode(self.doc): + for pnum, page, domNdPage in self._iter_Page_DocNode(self.doc): #now that we have the page, let's create the node for each type! lClassicPageNode = [nd for nodeType in lClassicType for nd in nodeType._iter_GraphNode(self.doc, domNdPage, page) ] lSpecialPageNode = [nd for nodeType in lSpecialType for nd in nodeType._iter_GraphNode(self.doc, domNdPage, page) ] diff --git a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRG42.py b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRG42.py index bf8e500..fe38508 100644 --- a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRG42.py +++ b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRG42.py @@ -16,18 +16,7 @@ Copyright Naver Labs Europe(C) 2018 JL Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. The READ project has received funding @@ -96,7 +85,7 @@ def setClassicNodeTypeList(cls, lNodeType): """ cls._lClassicNodeType = lNodeType - def parseXmlFile(self, sFilename, iVerbose=0): + def parseDocFile(self, sFilename, iVerbose=0): """ Load that document as a CRF Graph. Also set the self.doc variable! @@ -129,7 +118,7 @@ def parseXmlFile(self, sFilename, iVerbose=0): lClassicType = [nt for nt in self.getNodeTypeList() if nt in self._lClassicNodeType] lSpecialType = [nt for nt in self.getNodeTypeList() if nt not in self._lClassicNodeType] - for pnum, page, domNdPage in self._iter_Page_DomNode(self.doc): + for pnum, page, domNdPage in self._iter_Page_DocNode(self.doc): #now that we have the page, let's create the node for each type! 
lClassicPageNode = [nd for nodeType in lClassicType for nd in nodeType._iter_GraphNode(self.doc, domNdPage, page) ] lSpecialPageNode = [nd for nodeType in lSpecialType for nd in nodeType._iter_GraphNode(self.doc, domNdPage, page) ] diff --git a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRGw.py b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRGw.py index d95be48..fce8476 100644 --- a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRGw.py +++ b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRGw.py @@ -18,18 +18,7 @@ Copyright Naver Labs Europe(C) 2018 JL Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. The READ project has received funding @@ -96,7 +85,7 @@ def setClassicNodeTypeList(cls, lNodeType): """ cls._lClassicNodeType = lNodeType - def parseXmlFile(self, sFilename, iVerbose=0): + def parseDocFile(self, sFilename, iVerbose=0): """ Load that document as a CRF Graph. Also set the self.doc variable! @@ -121,7 +110,7 @@ def parseXmlFile(self, sFilename, iVerbose=0): lClassicType = [nt for nt in self.getNodeTypeList() if nt in self._lClassicNodeType] lSpecialType = [nt for nt in self.getNodeTypeList() if nt not in self._lClassicNodeType] - for pnum, page, domNdPage in self._iter_Page_DomNode(self.doc): + for pnum, page, domNdPage in self._iter_Page_DocNode(self.doc): #now that we have the page, let's create the node for each type! lClassicPageNode = [nd for nodeType in lClassicType for nd in nodeType._iter_GraphNode(self.doc, domNdPage, page) ] lSpecialPageNode = [nd for nodeType in lSpecialType for nd in nodeType._iter_GraphNode(self.doc, domNdPage, page) ] diff --git a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRH.py b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRH.py index 1fb1c32..3ad9f73 100644 --- a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRH.py +++ b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRH.py @@ -5,18 +5,7 @@ Copyright Naver Labs Europe(C) 2018 H. Déjean, JL Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . Developed for the EU project READ. 
The READ project has received funding diff --git a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRHCut1SIO.py b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRHCut1SIO.py index bf32679..640a441 100644 --- a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRHCut1SIO.py +++ b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRHCut1SIO.py @@ -11,18 +11,7 @@ Copyright Naver Labs Europe(C) 2018 JL Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. The READ project has received funding @@ -112,7 +101,7 @@ def setFactoredClassicalType(cls, ntClassic, ntFactored): cls._dFactorialType[ntClassic] = ntFactored cls._lfactoredType.append(ntFactored) - def parseXmlFile(self, sFilename, iVerbose=0): + def parseDocFile(self, sFilename, iVerbose=0): """ Load that document as a CRF Graph. Also set the self.doc variable! @@ -154,7 +143,7 @@ def parseXmlFile(self, sFilename, iVerbose=0): assert len(lClassicType) == 1 assert len(lSpecialType) == 1 - for (pnum, page, domNdPage) in self._iter_Page_DomNode(self.doc): + for (pnum, page, domNdPage) in self._iter_Page_DocNode(self.doc): #now that we have the page, let's create the node for each type! lClassicPageNode = [nd for nodeType in lClassicType for nd in nodeType._iter_GraphNode(self.doc, domNdPage, page) ] lSpecialPageNode = [nd for nodeType in lSpecialType for nd in nodeType._iter_GraphNode(self.doc, domNdPage, page) ] @@ -202,7 +191,7 @@ def computeSpecialEdges(cls, lClassicPageNode, lSpecialPageNode): # ------------------------------------ - def parseDomLabels(self): + def parseDocLabels(self): """ Parse the label of the graph from the dataset, and set the node label return the set of observed class (set of integers in N+) @@ -213,13 +202,13 @@ def parseDomLabels(self): == ad-hoc graph == We also load the class of the factored classical nodes """ - setSeensLabels = Graph_MultiPageXml.parseDomLabels(self) + setSeensLabels = Graph_MultiPageXml.parseDocLabels(self) # and we go thru the classical node types to also load the factored label for nd in self.lNodeBlock: factoredType = self._dFactorialType[nd.type] try: - sFactoredLabel = factoredType.parseDomNodeLabel(nd.node) + sFactoredLabel = factoredType.parseDocNodeLabel(nd) except KeyError: raise ValueError("Page %d, unknown label in %s (Known labels are %s)"%(nd.pnum, str(nd.node), self._dClsByLabel)) factoredLabel = self._dClsByLabel[sFactoredLabel] @@ -228,7 +217,7 @@ def parseDomLabels(self): setSeensLabels.add(factoredLabel) return setSeensLabels - def setDomLabels(self, Y): + def setDocLabels(self, Y): """ Set the labels of the graph nodes from the Y matrix return the DOM @@ -245,18 +234,18 @@ def setDomLabels(self, Y): # Blocks for i, nd in enumerate(self.lNodeBlock): sLabel = self._dLabelByCls[ Y[i] ] - ntBlock.setDomNodeLabel(nd.node, sLabel) + ntBlock.setDocNodeLabel(nd, sLabel) # factored Blocks for i, nd in enumerate(self.lNodeBlock): sLabel = self._dLabelByCls[ Y[i+NB] ] - ntFactored.setDomNodeLabel(nd.node, 
sLabel) + ntFactored.setDocNodeLabel(nd, sLabel) # cut nodes Z = NB + NB for i, nd in enumerate(self.lNodeCutLine): sLabel = self._dLabelByCls[ Y[i+Z] ] - ntCut.setDomNodeLabel(nd.node, sLabel) + ntCut.setDocNodeLabel(nd, sLabel) return self.doc @@ -873,7 +862,7 @@ def evalClusterByRow(self, sFilename): lSpecialType = [nt for nt in self.getNodeTypeList() if nt not in self._lClassicNodeType] #load the block nodes per page - for (pnum, page, domNdPage) in self._iter_Page_DomNode(self.doc): + for (pnum, page, domNdPage) in self._iter_Page_DocNode(self.doc): #now that we have the page, let's create the node for each type! lClassicPageNode = [nd for nodeType in lClassicType for nd in nodeType._iter_GraphNode(self.doc, domNdPage, page) ] lSpecialType = [nt for nt in self.getNodeTypeList() if nt not in self._lClassicNodeType] diff --git a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRSIO.py b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRSIO.py index 665d0bd..8762bb3 100644 --- a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRSIO.py +++ b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRSIO.py @@ -5,18 +5,7 @@ Copyright Naver Labs Europe(C) 2018 H. Déjean, JL Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . Developed for the EU project READ. The READ project has received funding @@ -48,20 +37,21 @@ from tasks.DU_CRF_Task import DU_CRF_Task #from crf.FeatureDefinition_PageXml_std_noText import FeatureDefinition_PageXml_StandardOnes_noText -from crf.FeatureDefinition_PageXml_std_noText_v4 import FeatureDefinition_PageXml_StandardOnes_noText_v4 +from crf.FeatureDefinition_PageXml_std_noText import FeatureDefinition_PageXml_StandardOnes_noText class NodeType_BIESO_to_SIO(NodeType_PageXml_type_woText): """ Convert BIESO labeling to SIO """ - def parseDomNodeLabel(self, domnode, defaultCls=None): + def parseDocNodeLabel(self, graph_node, defaultCls=None): """ Parse and set the graph node label and return its class index raise a ValueError if the label is missing while bOther was not True, or if the label is neither a valid one nor an ignored one """ sLabel = self.sDefaultLabel - + domnode = graph_node.node + sXmlLabel = domnode.get(self.sLabelAttr) sXmlLabel = {'B':'S', @@ -155,7 +145,7 @@ def __init__(self, sModelName, sModelDir, sComment=None, C=None, tol=None, njobs } , sComment=sComment #,cFeatureDefinition=FeatureDefinition_PageXml_StandardOnes_noText - ,cFeatureDefinition=FeatureDefinition_PageXml_StandardOnes_noText_v4 + ,cFeatureDefinition=FeatureDefinition_PageXml_StandardOnes_noText ) #self.setNbClass(3) #so that we check if all classes are represented in the training set diff --git a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRSIOH.py b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRSIOH.py index 764c0ce..bbffab6 100644 --- a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRSIOH.py +++ b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRSIOH.py @@ -5,18 +5,7 @@ Copyright Naver Labs Europe(C) 2018 H. 
Déjean, JL Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . Developed for the EU project READ. The READ project has received funding @@ -56,13 +45,13 @@ class NodeType_BIESO_to_SIOH(NodeType_PageXml_type_woText): Convert BIESO labeling to SIO """ - def parseDomNodeLabel(self, domnode, defaultCls=None): + def parseDocNodeLabel(self, graph_node, defaultCls=None): """ Parse and set the graph node label and return its class index raise a ValueError if the label is missing while bOther was not True, or if the label is neither a valid one nor an ignored one """ sLabel = self.sDefaultLabel - + domnode = graph_node.node sXmlLabel = domnode.get("DU_header") if sXmlLabel != 'CH': diff --git a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRSIO_Cut1SIO.py b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRSIO_Cut1SIO.py index e90b5c1..11d50f8 100644 --- a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRSIO_Cut1SIO.py +++ b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRSIO_Cut1SIO.py @@ -11,18 +11,7 @@ Copyright Naver Labs Europe(C) 2018 JL Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. The READ project has received funding @@ -89,7 +78,7 @@ def setClassicNodeTypeList(cls, lNodeType): """ cls._lClassicNodeType = lNodeType - def parseXmlFile(self, sFilename, iVerbose=0): + def parseDocFile(self, sFilename, iVerbose=0): """ Load that document as a CRF Graph. Also set the self.doc variable! @@ -125,7 +114,7 @@ def parseXmlFile(self, sFilename, iVerbose=0): lClassicType = [nt for nt in self.getNodeTypeList() if nt in self._lClassicNodeType] lSpecialType = [nt for nt in self.getNodeTypeList() if nt not in self._lClassicNodeType] - for (pnum, page, domNdPage) in self._iter_Page_DomNode(self.doc): + for (pnum, page, domNdPage) in self._iter_Page_DocNode(self.doc): #now that we have the page, let's create the node for each type! 
lClassicPageNode = [nd for nodeType in lClassicType for nd in nodeType._iter_GraphNode(self.doc, domNdPage, page) ] lSpecialPageNode = [nd for nodeType in lSpecialType for nd in nodeType._iter_GraphNode(self.doc, domNdPage, page) ] @@ -537,13 +526,13 @@ class NodeType_BIESO_to_SIO(NodeType_PageXml_type_woText): Convert BIESO labeling to SIO """ - def parseDomNodeLabel(self, domnode, defaultCls=None): + def parseDocNodeLabel(self, graph_node, defaultCls=None): """ Parse and set the graph node label and return its class index raise a ValueError if the label is missing while bOther was not True, or if the label is neither a valid one nor an ignored one """ sLabel = self.sDefaultLabel - + domnode = graph_node.node sXmlLabel = domnode.get(self.sLabelAttr) sXmlLabel = {'B':'S', diff --git a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRSIO_H.py b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRSIO_H.py index 4075e21..6c12a9c 100644 --- a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRSIO_H.py +++ b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRSIO_H.py @@ -5,18 +5,7 @@ Copyright Naver Labs Europe(C) 2018 H. Déjean, JL Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . Developed for the EU project READ. The READ project has received funding @@ -58,12 +47,13 @@ class NodeType_BIESO_to_SIO_and_CHDO(NodeType_PageXml_type_woText): Convert BIESO labeling to SIO """ - def parseDomNodeLabel(self, domnode, defaultCls=None): + def parseDocNodeLabel(self, graph_node, defaultCls=None): """ Parse and set the graph node label and return its class index raise a ValueError if the label is missing while bOther was not True, or if the label is neither a valid one nor an ignored one """ sLabel = self.sDefaultLabel + domnode = graph_node.node sXmlLabel = domnode.get(self.sLabelAttr) diff --git a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRSIO_HCut1SIO.py b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRSIO_HCut1SIO.py index 5d16923..888f88a 100644 --- a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRSIO_HCut1SIO.py +++ b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRSIO_HCut1SIO.py @@ -11,18 +11,7 @@ Copyright Naver Labs Europe(C) 2018 JL Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. 
The READ project has received funding @@ -112,7 +101,7 @@ def setFactoredClassicalType(cls, ntClassic, ntFactored): cls._dFactorialType[ntClassic] = ntFactored cls._lfactoredType.append(ntFactored) - def parseXmlFile(self, sFilename, iVerbose=0): + def parseDocFile(self, sFilename, iVerbose=0): """ Load that document as a CRF Graph. Also set the self.doc variable! @@ -154,7 +143,7 @@ def parseXmlFile(self, sFilename, iVerbose=0): assert len(lClassicType) == 1 assert len(lSpecialType) == 1 - for (pnum, page, domNdPage) in self._iter_Page_DomNode(self.doc): + for (pnum, page, domNdPage) in self._iter_Page_DocNode(self.doc): #now that we have the page, let's create the node for each type! lClassicPageNode = [nd for nodeType in lClassicType for nd in nodeType._iter_GraphNode(self.doc, domNdPage, page) ] lSpecialPageNode = [nd for nodeType in lSpecialType for nd in nodeType._iter_GraphNode(self.doc, domNdPage, page) ] @@ -202,7 +191,7 @@ def computeSpecialEdges(cls, lClassicPageNode, lSpecialPageNode): # ------------------------------------ - def parseDomLabels(self): + def parseDocLabels(self): """ Parse the label of the graph from the dataset, and set the node label return the set of observed class (set of integers in N+) @@ -213,13 +202,13 @@ def parseDomLabels(self): == ad-hoc graph == We also load the class of the factored classical nodes """ - setSeensLabels = Graph_MultiPageXml.parseDomLabels(self) + setSeensLabels = Graph_MultiPageXml.parseDocLabels(self) # and we go thru the classical node types to also load the factored label for nd in self.lNodeBlock: factoredType = self._dFactorialType[nd.type] try: - sFactoredLabel = factoredType.parseDomNodeLabel(nd.node) + sFactoredLabel = factoredType.parseDocNodeLabel(nd) except KeyError: raise ValueError("Page %d, unknown label in %s (Known labels are %s)"%(nd.pnum, str(nd.node), self._dClsByLabel)) factoredLabel = self._dClsByLabel[sFactoredLabel] @@ -228,7 +217,7 @@ def parseDomLabels(self): setSeensLabels.add(factoredLabel) return setSeensLabels - def setDomLabels(self, Y): + def setDocLabels(self, Y): """ Set the labels of the graph nodes from the Y matrix return the DOM @@ -245,18 +234,18 @@ def setDomLabels(self, Y): # Blocks for i, nd in enumerate(self.lNodeBlock): sLabel = self._dLabelByCls[ Y[i] ] - ntBlock.setDomNodeLabel(nd.node, sLabel) + ntBlock.setDocNodeLabel(nd, sLabel) # factored Blocks for i, nd in enumerate(self.lNodeBlock): sLabel = self._dLabelByCls[ Y[i+NB] ] - ntFactored.setDomNodeLabel(nd.node, sLabel) + ntFactored.setDocNodeLabel(nd, sLabel) # cut nodes Z = NB + NB for i, nd in enumerate(self.lNodeCutLine): sLabel = self._dLabelByCls[ Y[i+Z] ] - ntCut.setDomNodeLabel(nd.node, sLabel) + ntCut.setDocNodeLabel(nd, sLabel) return self.doc @@ -738,12 +727,13 @@ class NodeType_BIESO_to_SIO_and_CHDO(NodeType_PageXml_type_woText): Convert BIESO labeling to SIO """ - def parseDomNodeLabel(self, domnode, defaultCls=None): + def parseDocNodeLabel(self, graph_node, defaultCls=None): """ Parse and set the graph node label and return its class index raise a ValueError if the label is missing while bOther was not True, or if the label is neither a valid one nor an ignored one """ sLabel = self.sDefaultLabel + domnode = graph_node.node sXmlLabel = domnode.get(self.sLabelAttr) @@ -911,7 +901,7 @@ def evalClusterByRow(self, sFilename): lSpecialType = [nt for nt in self.getNodeTypeList() if nt not in self._lClassicNodeType] #load the block nodes per page - for (pnum, page, domNdPage) in self._iter_Page_DomNode(self.doc): + for (pnum, 
page, domNdPage) in self._iter_Page_DocNode(self.doc): #now that we have the page, let's create the node for each type! lClassicPageNode = [nd for nodeType in lClassicType for nd in nodeType._iter_GraphNode(self.doc, domNdPage, page) ] lSpecialType = [nt for nt in self.getNodeTypeList() if nt not in self._lClassicNodeType] diff --git a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRSIO_H_v2.py b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRSIO_H_v2.py index b68fee2..44b97e8 100644 --- a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRSIO_H_v2.py +++ b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRSIO_H_v2.py @@ -5,18 +5,7 @@ Copyright Naver Labs Europe(C) 2018 H. Déjean, JL Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. The READ project has received funding @@ -139,12 +128,13 @@ class NodeType_BIESO_to_SIO_and_CHDO(NodeType_PageXml_type_woText): Convert BIESO labeling to SIO """ - def parseDomNodeLabel(self, domnode, defaultCls=None): + def parseDocNodeLabel(self, graph_node, defaultCls=None): """ Parse and set the graph node label and return its class index raise a ValueError if the label is missing while bOther was not True, or if the label is neither a valid one nor an ignored one """ sLabel = self.sDefaultLabel + domnode = graph_node.node sXmlLabel = domnode.get(self.sLabelAttr) diff --git a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRSIO_v2.py b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRSIO_v2.py index 2387de4..3306f60 100644 --- a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRSIO_v2.py +++ b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableRSIO_v2.py @@ -5,18 +5,7 @@ Copyright Naver Labs Europe(C) 2018 H. Déjean, JL Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. 
The READ project has received funding
@@ -140,13 +129,13 @@ class NodeType_BIESO_to_SIO(NodeType_PageXml_type_woText):
     Convert BIESO labeling to SIO
     """
 
-    def parseDomNodeLabel(self, domnode, defaultCls=None):
+    def parseDocNodeLabel(self, graph_node, defaultCls=None):
         """
         Parse and set the graph node label and return its class index
         raise a ValueError if the label is missing while bOther was not True
         , or if the label is neither a valid one nor an ignored one
         """
         sLabel = self.sDefaultLabel
-
+        domnode = graph_node.node
         sXmlLabel = domnode.get(self.sLabelAttr)
 
         sXmlLabel = {'B':'S',
diff --git a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableSkewed.py b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableSkewed.py
new file mode 100644
index 0000000..421c369
--- /dev/null
+++ b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableSkewed.py
@@ -0,0 +1,1051 @@
+# -*- coding: utf-8 -*-
+
+"""
+    DU task for ABP Table:
+    doing jointly row EIO and near-horizontal cuts SIO
+
+    block2line edges do not cross another block.
+
+    The cuts are based on baselines of text blocks, with some positive or negative inclination.
+
+    - the labels of cuts are SIO
+
+    Copyright Naver Labs Europe(C) 2018 JL Meunier
+
+
+
+
+    Developed for the EU project READ. The READ project has received funding
+    from the European Union's Horizon 2020 research and innovation programme
+    under grant agreement No 674943.
+
+"""
+
+
+
+
+import sys, os
+import math
+from lxml import etree
+from collections import Counter
+from ast import literal_eval
+
+import numpy as np
+import shapely.geometry as geom
+import shapely.ops
+
+from sklearn.pipeline import Pipeline, FeatureUnion
+from sklearn.feature_extraction.text import CountVectorizer
+
+try: #to ease the use without proper Python installation
+    import TranskribusDU_version
+except ImportError:
+    sys.path.append( os.path.dirname(os.path.dirname( os.path.abspath(sys.argv[0]) )) )
+    import TranskribusDU_version
+
+from common.trace import traceln
+from tasks import _checkFindColDir, _exit
+from tasks.DU_CRF_Task import DU_CRF_Task
+from tasks.DU_Table.DU_ABPTableSkewed_CutAnnotator import SkewedCutAnnotator
+
+from xml_formats.PageXml import MultiPageXml, PageXml
+
+import graph.GraphModel
+from graph.Block import Block
+from graph.Edge import Edge, SamePageEdge, HorizontalEdge, VerticalEdge
+from graph.Graph_MultiPageXml import Graph_MultiPageXml
+from graph.NodeType_PageXml import NodeType_PageXml_type
+
+#from graph.FeatureDefinition_PageXml_std_noText import FeatureDefinition_PageXml_StandardOnes_noText
+from graph.FeatureDefinition import FeatureDefinition
+from graph.Transformer import Transformer, TransformerListByType, SparseToDense
+from graph.Transformer import EmptySafe_QuantileTransformer as QuantileTransformer
+from graph.Transformer_PageXml import NodeTransformerXYWH_v2, NodeTransformerNeighbors, Node1HotFeatures_noText,\
+    NodeTransformerText, NodeTransformerTextLen, EdgeNumericalSelector_v2
+from graph.Transformer_PageXml import Edge1HotFeatures_noText, EdgeBooleanFeatures_v2, EdgeNumericalSelector_noText
+from graph.PageNumberSimpleSequenciality import PageNumberSimpleSequenciality
+
+from util.Shape import ShapeLoader
+
+class GraphSkewedCut(Graph_MultiPageXml):
+    """
+    We specialize the class of graph because the computation of edges is quite specific
+
+    Here we consider horizontal and near-horizontal cuts
+    """
+    bCutAbove = False   # the cut line is above the "support" text
+    lRadAngle = None
+
+    #Cut stuff
+    #iModulo = 1 # map the coordinate to this modulo
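+    # A minimal configuration sketch (illustrative values only; presumably the
+    # task driver sets these class attributes before parsing):
+    #     GraphSkewedCut.bCutAbove = True
+    #     GraphSkewedCut.lRadAngle = [math.radians(d) for d in (-1, 0, +1)]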
+    fMinPageCoverage = 0.5  # minimal coverage to consider a GT table separator
+    # fCutHeight = 25 # height of a cutting ribbon
+    # For NAF to get 91% GT recall with same recall on ABP 98% (moving from 105 to 108% cuts)
+    fCutHeight = 10 # height of a cutting ribbon
+
+    # BAAAAD iLineVisibility = 5 * 11 # a cut line sees other cut lines up to N pixels downward
+    iLineVisibility = 3700 // 7   # (528) a cut line sees other cut lines up to N pixels downward
+    iBlockVisibility = 3*7*13     # (273) a block sees neighbouring cut lines up to N pixels away
+
+    _lClassicNodeType = None
+
+    # when loading a text, we create a shapely shape using the function below.
+    shaper_fun = ShapeLoader.node_to_Point
+
+    @classmethod
+    def setClassicNodeTypeList(cls, lNodeType):
+        """
+        determine which type of node goes thru the classical way for determining
+        the edges (vertical or horizontal overlap, with occlusion, etc.)
+        """
+        cls._lClassicNodeType = lNodeType
+
+    def parseDocFile(self, sFilename, iVerbose=0):
+        """
+        Load that document as a CRF Graph.
+        Also set the self.doc variable!
+
+        CAUTION: DOES NOT WORK WITH MULTI-PAGE DOCUMENTS...
+
+        Return a CRF Graph object
+        """
+        traceln(" ----- FILE %s ------" % sFilename)
+        self.doc = etree.parse(sFilename)
+        self.lNode, self.lEdge = list(), list()
+        self.lNodeBlock   = []  # text nodes
+        self.lNodeCutLine = []  # cut line nodes
+
+        doer = SkewedCutAnnotator(self.bCutAbove, lAngle=self.lRadAngle)
+        domid = 0
+        for (pnum, page, domNdPage) in self._iter_Page_DocNode(self.doc):
+            traceln(" --- page %s - constructing separator candidates" % pnum)
+            #load the page objects and the GT partition (defined by the table) if any
+            loBaseline, dsetTableByRow = doer.loadPage(domNdPage, shaper_fun=self.shaper_fun)
+            traceln(" - found %d objects on page" % (len(loBaseline)))
+            if loBaseline: traceln("\t - shaped as %s" % type(loBaseline[0]))
+
+            # find almost-horizontal cuts and tag them if GT is available
+            loHCut = doer.findHCut(domNdPage, loBaseline, dsetTableByRow, self.fCutHeight, iVerbose)
+
+            #create DOM node reflecting the cuts
+            #first clean (just in case!)
+            n = doer.remove_cuts_from_dom(domNdPage)
+            if n > 0:
+                traceln(" - removed %d pre-existing cut lines" % n)
+
+            # if GT, then we have labelled cut lines in DOM
+            domid = doer.add_Hcut_to_Page(domNdPage, loHCut, domid)
+
+        lClassicType = [nt for nt in self.getNodeTypeList() if nt in self._lClassicNodeType]
+        lSpecialType = [nt for nt in self.getNodeTypeList() if nt not in self._lClassicNodeType]
+
+        for (pnum, page, domNdPage) in self._iter_Page_DocNode(self.doc):
+            traceln(" --- page %s - constructing the graph" % pnum)
+            #now that we have the page, let's create the node for each type!
+            lClassicPageNode = [nd for nodeType in lClassicType for nd in nodeType._iter_GraphNode(self.doc, domNdPage, page) ]
+            lSpecialPageNode = [nd for nodeType in lSpecialType for nd in nodeType._iter_GraphNode(self.doc, domNdPage, page) ]
+
+            self.lNode.extend(lClassicPageNode)    # e.g. the TextLine objects
+            self.lNodeBlock.extend(lClassicPageNode)
+
+            self.lNode.extend(lSpecialPageNode)    # e.g. the cut lines!
+            self.lNodeCutLine.extend(lSpecialPageNode)
+
+            #no previous page to consider (for cross-page links...) => None
+            lClassicPageEdge = Edge.computeEdges(None, lClassicPageNode, self.iGraphMode)
+            self.lEdge.extend(lClassicPageEdge)
+
+            # Now, compute edges between special and classic objects...
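+            # For orientation, the edge population built here is roughly
+            # (a sketch, not an exhaustive list):
+            #   block-block : HorizontalEdge / VerticalEdge  (Edge.computeEdges above)
+            #   block-cut   : Edge_BL                        (computeSpecialEdges below)
+            #   cut-cut     : Edge_LL                        (computeSpecialEdges below)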
+            lSpecialPageEdge = self.computeSpecialEdges(lClassicPageNode,
+                                                        lSpecialPageNode)
+            self.lEdge.extend(lSpecialPageEdge)
+
+            #if iVerbose>=2: traceln("\tPage %5d    %6d nodes    %7d edges"%(pnum, len(lPageNode), len(lPageEdge)))
+            if iVerbose>=2:
+                traceln("\tPage %5d"%(pnum))
+                traceln("\t   block: %6d nodes    %7d edges (to block)" %(len(lClassicPageNode), len(lClassicPageEdge)))
+                traceln("\t   line: %6d nodes    %7d edges (from block or line)"%(len(lSpecialPageNode), len(lSpecialPageEdge)))
+                c = Counter(type(o).__name__ for o in lSpecialPageEdge)
+                l = list(c.items())
+                l.sort()
+                traceln("\t\t", l)
+        if iVerbose: traceln("\t\t (%d nodes,  %d edges)"%(len(self.lNode), len(self.lEdge)) )
+
+        return self
+
+    def addParsedLabelToDom(self):
+        """
+        while parsing the pages, we may have updated the standard BIESO labels;
+        we store the possibly new label in the DOM
+        """
+        for nd in self.lNode:
+            nd.type.setDocNodeLabel(nd, self._dLabelByCls[ nd.cls ])
+
+    def addEdgeToDoc(self, Y=None):
+        """
+        To display the graph conveniently, we add new Edge elements
+        """
+        (pnum, page, ndPage) = next(self._iter_Page_DocNode(self.doc))
+        w = int(ndPage.get("imageWidth"))
+
+        nn = 1 + len([e for e in self.lEdge if type(e) not in [HorizontalEdge, VerticalEdge, Edge_BL]])
+        ii = 0
+        for edge in self.lEdge:
+            if type(edge) in [HorizontalEdge, VerticalEdge]:
+                A, B = edge.A.shape.centroid, edge.B.shape.centroid
+            elif type(edge) in [Edge_BL]:
+                A = edge.A.shape.centroid
+                # not readable  _pt, B = shapely.ops.nearest_points(A, edge.B.shape)
+                _pt, B = shapely.ops.nearest_points(edge.A.shape, edge.B.shape)
+            else:
+                ii += 1
+                x = 1 + ii * (w/nn)
+                pt = geom.Point(x, 0)
+                A, _ = shapely.ops.nearest_points(edge.A.shape, pt)
+                B, _ = shapely.ops.nearest_points(edge.B.shape, pt)
+            ndSep = MultiPageXml.createPageXmlNode("Edge")
+            ndSep.set("DU_type", type(edge).__name__)
+            ndPage.append(ndSep)
+            MultiPageXml.setPoints(ndSep, [(A.x, A.y), (B.x, B.y)])
+        return
+
+    @classmethod
+    def computeSpecialEdges(cls, lClassicPageNode, lSpecialPageNode):
+        """
+        return a list of edges
+        """
+        raise Exception("Specialize this method")
+
+
+
+class Edge_BL(Edge):
+    """Edge block-to-Line"""
+    pass
+
+class Edge_LL(Edge):
+    """Edge line-to-Line"""
+    pass
+
+class GraphSkewedCut_H(GraphSkewedCut):
+    """
+    Only horizontal cut lines
+    """
+
+    def __init__(self):
+        self.showClassParam()
+
+    @classmethod
+    def showClassParam(cls):
+        """
+        show the class parameters
+        return whether or not they were shown
+        """
+        try:
+            cls.bParamShownOnce
+            return False
+        except:
+            #traceln("  - iModulo : "           , cls.iModulo)
+            traceln("  - block_see_line : "     , cls.iBlockVisibility)
+            traceln("  - line_see_line : "      , cls.iLineVisibility)
+            traceln("  - cut height : "         , cls.fCutHeight)
+            traceln("  - cut above : "          , cls.bCutAbove)
+            traceln("  - angles : "             , [math.degrees(v) for v in cls.lRadAngle])
+            traceln("  - fMinPageCoverage : "   , cls.fMinPageCoverage)
+            traceln("  - Textual features : "   , cls.bTxt)
+            cls.bParamShownOnce = True
+            return True
+
+    def getNodeListByType(self, iTyp):
+        if iTyp == 0:
+            return self.lNodeBlock
+        else:
+            return self.lNodeCutLine
+
+    def getEdgeListByType(self, typA, typB):
+        if typA == 0:
+            if typB == 0:
+                return (e for e in self.lEdge if isinstance(e, SamePageEdge))
+            else:
+                return (e for e in self.lEdge if isinstance(e, Edge_BL))
+        else:
+            if typB == 0:
+                return []
+            else:
+                return (e for e in self.lEdge if isinstance(e, Edge_LL))
+
+
+    @classmethod
+    def computeSpecialEdges(self, lClassicPageNode, lSpecialPageNode):
+        """
+        Compute:
+ - edges between each block and the cut line above/across/below the block + - edges between cut lines + return a list of edges + """ + #augment the block with the coordinate of its baseline central point + for blk in lClassicPageNode: + try: + pt = blk.shape.centroid + blk.x_bslne = pt.x + blk.y_bslne = pt.y + except IndexError: + traceln("** WARNING: no Baseline in ", blk.domid) + traceln("** Using BB instead... :-/") + blk.x_bslne = (blk.x1+blk.x2) / 2 + blk.y_bslne = (blk.y1+blk.y2) / 2 + blk._in_edge_up = 0 # count of incoming edge from upper lines + blk._in_edge_down = 0 # count of incoming edge from downward lines + + #block to cut line edges + # no _type=0 because they are valid cut (never crossing any block) + lEdge = [] + for cutBlk in lSpecialPageNode: + #equation of the line + # y = A x + B + A = (cutBlk.y2 - cutBlk.y1) / (cutBlk.x2 - cutBlk.x1) + B = cutBlk.y1 - A * cutBlk.x1 + oCut = cutBlk.shape + for blk in lClassicPageNode: + dist = oCut.distance(blk.shape) + if dist <= self.iBlockVisibility: + edge = Edge_BL(blk, cutBlk) # Block _to_ Cut !! + # experiments show that abs helps + # edge.len = (blk.y_bslne - cutBlk.y1) / self.iBlockVisibility + edge.len = dist / self.iBlockVisibility + y = A * blk.x_bslne + B # y of the point on cut line + # edge._type = -1 if blk.y_bslne > y else (+1 if blk.y_bslne < y else 0) + # shapely can give as distance a very small number while y == 0 + edge._type = -1 if blk.y_bslne >= y else +1 + assert edge._type != 0, (str(oCut), list(blk.shape.coords), oCut.distance(blk.shape.centroid), str(blk.shape.centroid)) + lEdge.append(edge) + + #now filter those edges + n0 = len(lEdge) + #lEdge = self._filterBadEdge(lEdge, lClassicPageNode, lSpecialPageNode) + lEdge = self._filterBadEdge(lEdge, lSpecialPageNode) + + traceln(" - filtering: removed %d edges due to obstruction." % (n0-len(lEdge))) + + # add a counter of incoming edge to nodes, for features eng. + for edge in lEdge: + if edge._type > 0: + edge.A._in_edge_up += 1 + else: + edge.A._in_edge_down += 1 + + # Cut line to Cut line edges + n0 = len(lEdge) + if self.iLineVisibility > 0: + for i, A in enumerate(lSpecialPageNode): + for B in lSpecialPageNode[i+1:]: + dist = A.shape.distance(B.shape) + if dist <= self.iLineVisibility: + edge = Edge_LL(A, B) + edge.len = dist / self.iLineVisibility + lEdge.append(edge) + traceln(" - edge_LL: added %d edges." % (len(lEdge)-n0)) + + return lEdge + + + @classmethod + def _filterBadEdge(cls, lEdge, lCutLine, fRatio=0.25): + """ + We get + - a list of block2Line edges + - a sorted list of cut line + But some block should not be connected to a line due to obstruction by + another blocks. + We filter out those edges... + return a sub-list of lEdge + """ + lKeepEdge = [] + + + def isTargetLineVisible_X(edge, lEdge, fThreshold=0.9): + """ + can the source node of the edge see the target node line? + we say no if some other block obstructs half or more of the view + """ + a1, a2 = edge.A.x1, edge.A.x2 + w = a2 - a1 + minVisibility = w * fThreshold + for _edge in lEdge: + # we want a visibility window of at least 1/4 of the object A + b1, b2 = _edge.A.x1, _edge.A.x2 + vis = min(w, max(0, b1 - a1) + max(0, a2 - b2)) + if vis <= minVisibility: return False + return True + + #there are two ways for dealing with lines crossed by a block + # - either it prevents another block to link to the line (assuming an x-overlap) + # - or not (historical way) + # THIS IS THE "MODERN" way!! 
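For readers skimming this hunk: the obstruction test above boils down to interval arithmetic on the x-axis, where a candidate edge survives only if the source blocks already kept leave enough of its own x-extent uncovered. A minimal standalone sketch of that idea, with source blocks reduced to made-up (x1, x2) intervals and an illustrative 0.5 threshold (the code above parameterizes this as fThreshold and tests each kept edge separately, so this is the gist, not the actual _filterBadEdge logic):

    def keep_visible_edges(lInterval, fThreshold=0.5):
        """lInterval: source-block x-extents (x1, x2), sorted by distance to
        the cut line, closest first. Keep one only if the intervals kept so
        far do not cover too much of it (coverage summed naively)."""
        lKept = []
        for a1, a2 in lInterval:
            w = a2 - a1
            covered = sum(max(0.0, min(a2, b2) - max(a1, b1)) for b1, b2 in lKept)
            if covered < fThreshold * w:   # enough of the block still "sees" the line
                lKept.append((a1, a2))
        return lKept

    # closest block first; the second one is hidden behind the first
    assert keep_visible_edges([(10, 50), (15, 45), (60, 90)]) == [(10, 50), (60, 90)]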
+ + #take each line in turn + for ndLine in lCutLine: + #--- process downward edges + #TODO: index! + lDownwardAndXingEdge = [edge for edge in lEdge \ + if edge._type > 0 and edge.B == ndLine] + if lDownwardAndXingEdge: + #sort edge by source block from closest to line block to farthest + #lDownwardAndXingEdge.sort(key=lambda o: ndLine.y1 - o.A.y_bslne) + lDownwardAndXingEdge.sort(key=lambda o: ndLine.shape.distance(o.A.shape)) + + lKeepDownwardEdge = [lDownwardAndXingEdge.pop(0)] + + #now keep all edges whose source does not overlap vertically with + # the source of an edge that is kept + for edge in lDownwardAndXingEdge: + if isTargetLineVisible_X(edge, lKeepDownwardEdge): + lKeepDownwardEdge.append(edge) + lKeepEdge.extend(lKeepDownwardEdge) + + #--- process upward edges + #TODO: index! + lUpwarAndXingdEdge = [edge for edge in lEdge \ + if edge._type < 0 and edge.B == ndLine] + if lUpwarAndXingdEdge: + #sort edge by source block from closest to line -block to farthest + #lUpwarAndXingdEdge.sort(key=lambda o: o.A.y_bslne - ndLine.y2) + lUpwarAndXingdEdge.sort(key=lambda o: ndLine.shape.distance(o.A.shape)) + lKeepUpwardEdge = [lUpwarAndXingdEdge.pop(0)] + + #now keep all edges whose source does not overlap vertically with + # the source of an edge that is kept + for edge in lUpwarAndXingdEdge: + if isTargetLineVisible_X(edge, lKeepUpwardEdge): + lKeepUpwardEdge.append(edge) + + # now we keep only the edges, excluding the crossing ones + # (already included!!) + lKeepEdge.extend(edge for edge in lKeepUpwardEdge) + + #--- and include the crossing ones (that are discarded + return lKeepEdge + + +#------------------------------------------------------------------------------------------------------ +class SupportBlock_NodeTransformer(Transformer): + """ + aspects related to the "support" notion of a block versus a cut line + """ + def transform(self, lNode): +# a = np.empty( ( len(lNode), 5 ) , dtype=np.float64) +# for i, blk in enumerate(lNode): a[i, :] = [blk.x1, blk.y2, blk.x2-blk.x1, blk.y2-blk.y1, blk.fontsize] #--- 2 3 4 5 6 + a = np.empty( ( len(lNode), 2 ) , dtype=np.float64) + for i, blk in enumerate(lNode): + a[i, :] = (blk._in_edge_up, blk._in_edge_down) + return a + +#------------------------------------------------------------------------------------------------------ +class CutLine_NodeTransformer_v3(Transformer): + """ + features of a Cut line: + - horizontal or vertical. + """ + def transform(self, lNode): + #We allocate TWO more columns to store in it the tfidf and idf computed at document level. + #a = np.zeros( ( len(lNode), 10 ) , dtype=np.float64) # 4 possible orientations: 0, 1, 2, 3 + N = 6 + a = np.zeros( ( len(lNode), N ) , dtype=np.float64) # 4 possible orientations: 0, 1, 2, 3 + + for i, blk in enumerate(lNode): + page = blk.page + assert abs(blk.x2 - blk.x1) > abs(blk.y1 - blk.y2) + #horizontal + v = (blk.y1+blk.y2)/float(page.h) - 1 # to range -1, +1 + a[i,:] = (1.0, v, v*v + , blk.angle, blk.angle_freq, blk.angle_cumul_freq) +# else: +# #vertical +# v = 2*blk.x1/float(page.w) - 1 # to range -1, +1 +# a[i, N:] = (1.0, v, v*v +# ,blk.angle, blk.angle_freq, blk.angle_cumul_freq) + # traceln("CutLine_NodeTransformer_v3", a[:min(100, len(lNode)),]) + return a + +class CutLine_NodeTransformer_qty(Transformer): + """ + features of a Cut line: + - horizontal or vertical. + """ + def transform(self, lNode): + #We allocate TWO more columns to store in it the tfidf and idf computed at document level. 
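For reference, the (1.0, v, v*v) triple computed by CutLine_NodeTransformer_v3 above maps the cut line's vertical position into [-1, +1] and then takes constant, linear and quadratic terms. A tiny numeric illustration (page height and y coordinates are invented values):

    import numpy as np

    page_h = 2000.0
    y1, y2 = 480.0, 500.0               # endpoints of a near-horizontal cut line
    v = (y1 + y2) / page_h - 1          # -1 at page top, +1 at page bottom
    a = np.array([1.0, v, v * v])       # the positional part of the feature row
    print(a)                            # [ 1.     -0.51    0.2601]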
+ #a = np.zeros( ( len(lNode), 10 ) , dtype=np.float64) # 4 possible orientations: 0, 1, 2, 3 + N = 1 + a = np.zeros( ( len(lNode), 2*N ) , dtype=np.float64) # 4 possible orientations: 0, 1, 2, 3 + + for i, blk in enumerate(lNode): + assert abs(blk.x2 - blk.x1) > abs(blk.y1 - blk.y2) + #horizontal + a[i,:] = (len(blk.set_support)) + return a + + +#------------------------------------------------------------------------------------------------------ +class Block2CutLine_EdgeTransformer(Transformer): + """ + features of a block to Cut line edge: + - below, crossing, above + """ + def transform(self, lEdge): + N = 8 + a = np.zeros( ( len(lEdge), 2 * N) , dtype=np.float64) + for i, edge in enumerate(lEdge): + z = 0 if edge._type < 0 else N # _type is -1 or 1 + blk = edge.A + page = blk.page + w = float(page.w) # h = float(page.h) + x = (blk.x1 + blk.x2) / w - 1 # [-1, +1] + a[i, z:z+N] = (1.0 + , edge.len + , edge.len*edge.len + , edge.B.angle_freq + , edge.B.angle_cumul_freq + , 1.0 if edge.A.du_index in edge.B.set_support else 0.0 + , x, x * x + ) +# print(a[i,:].tolist()) + # traceln("Block2CutLine_EdgeTransformer", a[:min(100, len(lEdge)),]) + return a + +class Block2CutLine_EdgeTransformer_qtty(Transformer): + def transform(self, lEdge): + N = 3 + a = np.zeros( ( len(lEdge), 2 * N) , dtype=np.float64) + for i, edge in enumerate(lEdge): + z = 0 if edge._type < 0 else N # _type is -1 or 1 + a[i, z:z+N] = (len(edge.B.set_support) + , edge.A._in_edge_up + , edge.A._in_edge_down + ) +# print(a[i,:].tolist()) + # traceln("Block2CutLine_EdgeTransformer", a[:min(100, len(lEdge)),]) + return a + +class Block2CutLine_FakeEdgeTransformer(Transformer): + """ + a fake transformer that return as many features as the union of real ones above + """ + def transform(self, lEdge): + assert not(lEdge) + return np.zeros( ( len(lEdge), 2*8 + 2*3) , dtype=np.float64) + + +class CutLine2CutLine_EdgeTransformer(Transformer): # ***** USELESS ***** + """ + features of a block to Cut line edge: + - below, crossing, above + """ +# BEST SO FAR +# def transform(self, lEdge): +# a = np.zeros( ( len(lEdge), 4 ) , dtype=np.float64) +# for i, edge in enumerate(lEdge): +# a[i,:] = (1, edge.len, edge.len * edge.len, int(edge.len==0)) +# # traceln("CutLine2CutLine_EdgeTransformer", a[:min(100, len(lEdge)),]) +# return a + +# WORSE +# def transform(self, lEdge): +# a = np.zeros( ( len(lEdge), 12) , dtype=np.float64) +# for i, edge in enumerate(lEdge): +# dAngle = (edge.A.angle - edge.B.angle) / 5 # we won't go beyond +-5 degrees. +# iSameSupport = int(len(edge.B.set_support.intersection(edge.A.set_support)) > 0) +# iCrosses = int(edge.A.shape.crosses(edge.B.shape)) +# a[i,:] = (1 +# , edge.len, edge.len * edge.len, int(edge.len==0), int(edge.len < 5) +# , dAngle, dAngle * dAngle, int(abs(dAngle) < 0.1), int(abs(dAngle) < 0.1) +# , iSameSupport +# , iCrosses +# , (1-iSameSupport) * iCrosses # not same support but crossing each other +# ) +# return a + + def transform(self, lEdge): + a = np.zeros( ( len(lEdge), 7 ) , dtype=np.float64) + for i, edge in enumerate(lEdge): + dAngle = (edge.A.angle - edge.B.angle) / 5 # we won't go beyond +-5 degrees. 
+ iSameSupport = int(len(edge.B.set_support.intersection(edge.A.set_support)) > 0) + iCrosses = int(edge.A.shape.crosses(edge.B.shape)) + a[i,:] = (1, edge.len, edge.len * edge.len + , dAngle, dAngle * dAngle + , iSameSupport + , iCrosses + ) + # traceln("CutLine2CutLine_EdgeTransformer", a[:min(100, len(lEdge)),]) + return a + + + +class My_FeatureDefinition_v3_base(FeatureDefinition): + n_QUANTILES = 16 + n_QUANTILES_sml = 8 + + def __init__(self, **kwargs): + """ + set _node_transformer, _edge_transformer, tdifNodeTextVectorizer + """ + FeatureDefinition.__init__(self) + self._node_transformer = None + self._edge_transformer = None + self._node_text_vectorizer = None #tdifNodeTextVectorizer + + def fitTranformers(self, lGraph,lY=None): + """ + Fit the transformers using the graphs, but TYPE BY TYPE !!! + return True + """ + self._node_transformer[0].fit([nd for g in lGraph for nd in g.getNodeListByType(0)]) + self._node_transformer[1].fit([nd for g in lGraph for nd in g.getNodeListByType(1)]) + + self._edge_transformer[0].fit([e for g in lGraph for e in g.getEdgeListByType(0, 0)]) + self._edge_transformer[1].fit([e for g in lGraph for e in g.getEdgeListByType(0, 1)]) + self._edge_transformer[2].fit([e for g in lGraph for e in g.getEdgeListByType(1, 0)]) + self._edge_transformer[3].fit([e for g in lGraph for e in g.getEdgeListByType(1, 1)]) + + return True + +class My_FeatureDefinition_v3(My_FeatureDefinition_v3_base): + """ + Multitype version: + so the node_transformer actually is a list of node_transformer of length n_class + the edge_transformer actually is a list of node_transformer of length n_class^2 + + We also inherit from FeatureDefinition_T !!! + """ + + def __init__(self, **kwargs): + """ + set _node_transformer, _edge_transformer, tdifNodeTextVectorizer + """ + My_FeatureDefinition_v3_base.__init__(self) + + nbTypes = self._getTypeNumber(kwargs) + + block_transformer = FeatureUnion( [ #CAREFUL IF YOU CHANGE THIS - see cleanTransformers method!!!! + ("xywh", Pipeline([ + ('selector', NodeTransformerXYWH_v2()), + #v1 ('xywh', StandardScaler(copy=False, with_mean=True, with_std=True)) #use in-place scaling + ('xywh', QuantileTransformer(n_quantiles=self.n_QUANTILES, copy=False)) #use in-place scaling + ]) + ) + , ("edge_cnt", Pipeline([ + ('selector', SupportBlock_NodeTransformer()), + #v1 ('xywh', StandardScaler(copy=False, with_mean=True, with_std=True)) #use in-place scaling + ('edge_cnt', QuantileTransformer(n_quantiles=self.n_QUANTILES_sml, copy=False)) #use in-place scaling + ]) + ) + , ("neighbors", Pipeline([ + ('selector', NodeTransformerNeighbors()), + #v1 ('neighbors', StandardScaler(copy=False, with_mean=True, with_std=True)) #use in-place scaling + ('neighbors', QuantileTransformer(n_quantiles=self.n_QUANTILES, copy=False)) #use in-place scaling + ]) + ) + , ("1hot", Pipeline([ + ('1hot', Node1HotFeatures_noText()) #does the 1-hot encoding directly + ]) + ) + ]) + + Cut_line_transformer = FeatureUnion( [ + ("std", CutLine_NodeTransformer_v3()) + , ("qty", Pipeline([ + ('selector', CutLine_NodeTransformer_qty()), + ('quantile', QuantileTransformer(n_quantiles=self.n_QUANTILES_sml, copy=False)) #use in-place scaling + ]) + ) + ]) + + self._node_transformer = TransformerListByType([block_transformer, Cut_line_transformer]) + + edge_BB_transformer = FeatureUnion( [ #CAREFUL IF YOU CHANGE THIS - see cleanTransformers method!!!! 
+ ("1hot", Pipeline([ + ('1hot', Edge1HotFeatures_noText(PageNumberSimpleSequenciality())) + ]) + ) + , ("boolean", Pipeline([ + ('boolean', EdgeBooleanFeatures_v2()) + ]) + ) + , ("numerical", Pipeline([ + ('selector', EdgeNumericalSelector_noText()), + #v1 ('numerical', StandardScaler(copy=False, with_mean=True, with_std=True)) #use in-place scaling + ('numerical', QuantileTransformer(n_quantiles=self.n_QUANTILES, copy=False)) #use in-place scaling + ]) + ) + ] ) + #edge_BL_transformer = Block2CutLine_EdgeTransformer() + edge_BL_transformer = FeatureUnion( [ + ("std", Block2CutLine_EdgeTransformer()) + , ("qty", Pipeline([ + ('selector', Block2CutLine_EdgeTransformer_qtty()), + ('quantile', QuantileTransformer(n_quantiles=self.n_QUANTILES_sml, copy=False)) #use in-place scaling + ]) + ) + ]) + + edge_LL_transformer = CutLine2CutLine_EdgeTransformer() + self._edge_transformer = TransformerListByType([edge_BB_transformer, + edge_BL_transformer, + # edge_BL_transformer, # useless but required + Block2CutLine_FakeEdgeTransformer(), # fit is called with [], so the Pipeline explodes + edge_LL_transformer + ]) + + + +gTBL = str.maketrans("0123456789", "NNNNNNNNNN") +def My_FeatureDefinition_v3_txt_preprocess(s): + """ + Normalization of the etxt before extracting ngrams + """ + return s.lower().translate(gTBL) + + +class My_FeatureDefinition_v3_txt(My_FeatureDefinition_v3_base): + """ + Multitype version: + so the node_transformer actually is a list of node_transformer of length n_class + the edge_transformer actually is a list of node_transformer of length n_class^2 + + We also inherit from FeatureDefinition_T !!! + """ + t_ngrams_range = (2, 4) + n_ngrams = 1000 + + # pre-processing of text before extracting ngrams + def __init__(self, **kwargs): + """ + set _node_transformer, _edge_transformer, tdifNodeTextVectorizer + """ + My_FeatureDefinition_v3_base.__init__(self) + + nbTypes = self._getTypeNumber(kwargs) + + # since we have a preprocessor, lowercase and strip_accents options are disabled + self._node_text_vectorizer = CountVectorizer( analyzer = 'char' + # AttributeError: Can't pickle local object 'My_FeatureDefinition_v3_txt.__init__..' + # , preprocessor = lambda x: x.lower().translate(self.TBL) + , preprocessor = My_FeatureDefinition_v3_txt_preprocess + , max_features = self.n_ngrams + , ngram_range = self.t_ngrams_range #(2,6) + , dtype=np.float64) + + block_transformer = FeatureUnion( [ #CAREFUL IF YOU CHANGE THIS - see cleanTransformers method!!!! 
+ ("text", Pipeline([ + ('selector', NodeTransformerText()) + , ('vecto', self._node_text_vectorizer) #we can use it separately from the pipleline once fitted + , ('todense', SparseToDense()) #pystruct needs an array, not a sparse matrix + ]) + ) + , + ("textlen", Pipeline([ + ('selector', NodeTransformerTextLen()), + ('textlen', QuantileTransformer(n_quantiles=self.n_QUANTILES, copy=False)) #use in-place scaling + ]) + ) + , ("xywh", Pipeline([ + ('selector', NodeTransformerXYWH_v2()), + #v1 ('xywh', StandardScaler(copy=False, with_mean=True, with_std=True)) #use in-place scaling + ('xywh', QuantileTransformer(n_quantiles=self.n_QUANTILES, copy=False)) #use in-place scaling + ]) + ) + , ("edge_cnt", Pipeline([ + ('selector', SupportBlock_NodeTransformer()), + #v1 ('xywh', StandardScaler(copy=False, with_mean=True, with_std=True)) #use in-place scaling + ('edge_cnt', QuantileTransformer(n_quantiles=self.n_QUANTILES_sml, copy=False)) #use in-place scaling + ]) + ) + , ("neighbors", Pipeline([ + ('selector', NodeTransformerNeighbors()), + #v1 ('neighbors', StandardScaler(copy=False, with_mean=True, with_std=True)) #use in-place scaling + ('neighbors', QuantileTransformer(n_quantiles=self.n_QUANTILES, copy=False)) #use in-place scaling + ]) + ) + , ("1hot", Pipeline([ + ('1hot', Node1HotFeatures_noText()) #does the 1-hot encoding directly + ]) + ) + ]) + + Cut_line_transformer = FeatureUnion( [ + ("std", CutLine_NodeTransformer_v3()) + , ("qty", Pipeline([ + ('selector', CutLine_NodeTransformer_qty()), + ('quantile', QuantileTransformer(n_quantiles=self.n_QUANTILES_sml, copy=False)) #use in-place scaling + ]) + ) + ]) + + self._node_transformer = TransformerListByType([block_transformer, Cut_line_transformer]) + + edge_BB_transformer = FeatureUnion( [ #CAREFUL IF YOU CHANGE THIS - see cleanTransformers method!!!! + ("1hot", Pipeline([ + ('1hot', Edge1HotFeatures_noText(PageNumberSimpleSequenciality())) + ]) + ) + , ("boolean", Pipeline([ + ('boolean', EdgeBooleanFeatures_v2()) + ]) + ) + , ("numerical", Pipeline([ + ('selector', EdgeNumericalSelector_v2()), + #v1 ('numerical', StandardScaler(copy=False, with_mean=True, with_std=True)) #use in-place scaling + ('numerical', QuantileTransformer(n_quantiles=self.n_QUANTILES, copy=False)) #use in-place scaling + ]) + ) + ] ) + #edge_BL_transformer = Block2CutLine_EdgeTransformer() + edge_BL_transformer = FeatureUnion( [ + ("std", Block2CutLine_EdgeTransformer()) + , ("qty", Pipeline([ + ('selector', Block2CutLine_EdgeTransformer_qtty()), + ('quantile', QuantileTransformer(n_quantiles=self.n_QUANTILES_sml, copy=False)) #use in-place scaling + ]) + ) + ]) + + edge_LL_transformer = CutLine2CutLine_EdgeTransformer() + self._edge_transformer = TransformerListByType([edge_BB_transformer, + edge_BL_transformer, + # edge_BL_transformer, # useless but required + Block2CutLine_FakeEdgeTransformer(), # fit is called with [], so the Pipeline explodes + edge_LL_transformer + ]) + + + def cleanTransformers(self): + """ + the TFIDF transformers are keeping the stop words => huge pickled file!!! + + Here the fix is a bit rough. There are better ways.... + JL + """ + self._node_transformer[0].transformer_list[0][1].steps[1][1].stop_words_ = None #is 1st in the union... +# for i in [2, 3, 4, 5, 6, 7]: +# self._edge_transformer.transformer_list[i][1].steps[1][1].stop_words_ = None #are 3rd and 4th in the union.... 
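The module-level My_FeatureDefinition_v3_txt_preprocess function defined earlier exists precisely because of the pickling failure quoted in the comments ("Can't pickle local object"): pickle serializes functions by qualified name, so a lambda or local closure passed as CountVectorizer's preprocessor breaks model saving. A minimal standalone sketch of the difference (not the library code):

    import pickle

    gTBL = str.maketrans("0123456789", "NNNNNNNNNN")

    def preprocess(s):                    # module-level: picklable by reference
        return s.lower().translate(gTBL)

    pickle.dumps(preprocess)              # works

    try:
        pickle.dumps(lambda s: s.lower().translate(gTBL))
    except (pickle.PicklingError, AttributeError) as e:
        print("lambda cannot be pickled:", e)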
+ return self._node_transformer, self._edge_transformer + + +def test_preprocess(capsys): + + with capsys.disabled(): + print("toto") + tbl = str.maketrans("0123456789", "NNNNNNNNNN") + fun = lambda x: x.lower().translate(tbl) + assert "abc" == fun("abc") + assert "abc" == fun("ABC") + assert "abcdé" == fun("ABCdé") + assert "tüv" == fun("tÜv") + assert "tüv NN " == fun("tÜv 12 ") + assert "" == fun("") + assert "N" == fun("1") + assert "NN" == fun("23") + assert "j't'aime moi non plus. dites NN!!" == fun("J't'aime MOI NON PlUs. Dites 33!!") + assert "" == fun("") + assert "" == fun("") + assert "" == fun("") + + +class NodeType_PageXml_Cut_Shape(NodeType_PageXml_type): + """ + we specialize it because our cuts are near horizontal + """ + def _iter_GraphNode(self, doc, domNdPage, page): + """ + Get the DOM, the DOM page node, the page object + + iterator on the DOM, that returns nodes (of class Block) + """ + #--- XPATH contexts + assert self.sxpNode, "CONFIG ERROR: need an xpath expression to enumerate elements corresponding to graph nodes" + lNdBlock = domNdPage.xpath(self.sxpNode, namespaces=self.dNS) #all relevant nodes of the page + + for ndBlock in lNdBlock: + domid = ndBlock.get("id") + sText = "" + + #now we need to infer the bounding box of that object + (x1, y1), (x2, y2) = PageXml.getPointList(ndBlock) #the polygon + + orientation = 0 + classIndex = 0 #is computed later on + + #and create a Block + # we pass the coordinates, not x1,y1,w,h !! + cutBlk = Block(page, ((x1, y1), (x2, y2)), sText, orientation, classIndex, self, ndBlock, domid=domid) + + # Create the shapely shape + cutBlk.shape = geom.LineString([(x1, y1), (x2, y2)]) + cutBlk.angle = float(ndBlock.get("DU_angle")) + cutBlk.angle_freq = float(ndBlock.get("DU_angle_freq")) + cutBlk.angle_cumul_freq = float(ndBlock.get("DU_angle_cumul_freq")) + cutBlk.set_support = literal_eval(ndBlock.get("DU_set_support")) + + yield cutBlk + + return + + +# ---------------------------------------------------------------------------- + +def main(TableSkewedRowCut_CLASS, sModelDir, sModelName, options): + """ + TableSkewedRowCut_CLASS must be a class inheriting from DU_Graph_CRF + """ + lDegAngle = [float(s) for s in options.lsAngle.split(",")] + lRadAngle = [math.radians(v) for v in lDegAngle] + + doer = TableSkewedRowCut_CLASS(sModelName, sModelDir, + iBlockVisibility = options.iBlockVisibility, + iLineVisibility = options.iLineVisibility, + fCutHeight = options.fCutHeight, + bCutAbove = options.bCutAbove, + lRadAngle = lRadAngle, + bTxt = options.bTxt, + C = options.crf_C, + tol = options.crf_tol, + njobs = options.crf_njobs, + max_iter = options.max_iter, + inference_cache = options.crf_inference_cache) + + if options.rm: + doer.rm() + return + + lTrn, lTst, lRun, lFold = [_checkFindColDir(lsDir, bAbsolute=False) for lsDir in [options.lTrn, options.lTst, options.lRun, options.lFold]] +# if options.bAnnotate: +# doer.annotateDocument(lTrn) +# traceln('annotation done') +# sys.exit(0) + + + traceln("- classes: ", doer.getGraphClass().getLabelNameList()) + + ## use. 
a_mpxml files + #doer.sXmlFilenamePattern = doer.sLabeledXmlFilenamePattern + + + if options.iFoldInitNum or options.iFoldRunNum or options.bFoldFinish: + if options.iFoldInitNum: + """ + initialization of a cross-validation + """ + splitter, ts_trn, lFilename_trn = doer._nfold_Init(lFold, options.iFoldInitNum, bStoreOnDisk=True) + elif options.iFoldRunNum: + """ + Run one fold + """ + oReport = doer._nfold_RunFoldFromDisk(options.iFoldRunNum, options.warm, options.pkl) + traceln(oReport) + elif options.bFoldFinish: + tstReport = doer._nfold_Finish() + traceln(tstReport) + else: + assert False, "Internal error" + #no more processing!! + exit(0) + #------------------- + + if lFold: + loTstRpt = doer.nfold_Eval(lFold, 3, .25, None, options.pkl) + sReportPickleFilename = os.path.join(sModelDir, sModelName + "__report.txt") + traceln("Results are in %s"%sReportPickleFilename) + graph.GraphModel.GraphModel.gzip_cPickle_dump(sReportPickleFilename, loTstRpt) + elif lTrn: + doer.train_save_test(lTrn, lTst, options.warm, options.pkl) + try: traceln("Baseline best estimator: %s"%doer.bsln_mdl.best_params_) #for CutSearch + except: pass + traceln(" --- CRF Model ---") + traceln(doer.getModel().getModelInfo()) + elif lTst: + doer.load() + tstReport = doer.test(lTst) + traceln(tstReport) + if options.bDetailedReport: + traceln(tstReport.getDetailledReport()) + sReportPickleFilename = os.path.join(sModelDir, sModelName + "__detailled_report.txt") + graph.GraphModel.GraphModel.gzip_cPickle_dump(sReportPickleFilename, tstReport) + + if lRun: + if options.storeX or options.applyY: + try: doer.load() + except: pass #we only need the transformer + lsOutputFilename = doer.runForExternalMLMethod(lRun, options.storeX, options.applyY, options.bRevertEdges) + else: + doer.load() + lsOutputFilename = doer.predict(lRun) + + traceln("Done, see in:\n %s"%lsOutputFilename) + + +def main_command_line(TableSkewedRowCut_CLASS): + version = "v.01" + usage, description, parser = DU_CRF_Task.getBasicTrnTstRunOptionParser(sys.argv[0], version) +# parser.add_option("--annotate", dest='bAnnotate', action="store_true",default=False, help="Annotate the textlines with BIES labels") + + #FOR GCN + parser.add_option("--revertEdges", dest='bRevertEdges', action="store_true", help="Revert the direction of the edges") + parser.add_option("--detail", dest='bDetailedReport', action="store_true", default=False,help="Display detailed reporting (score per document)") + parser.add_option("--baseline", dest='bBaseline', action="store_true", default=False, help="report baseline method") + parser.add_option("--line_see_line", dest='iLineVisibility', action="store", + type=int, default=GraphSkewedCut.iLineVisibility, + help="seeline2line: how far in pixel can a line see another cut line?") + parser.add_option("--block_see_line", dest='iBlockVisibility', action="store", + type=int, default=GraphSkewedCut.iBlockVisibility, + help="seeblock2line: how far in pixel can a block see a cut line?") + parser.add_option("--height", dest="fCutHeight", default=GraphSkewedCut.fCutHeight + , action="store", type=float, help="Minimal height of a cut") + parser.add_option("--cut-above", dest='bCutAbove', action="store_true", default=False + ,help="Each object defines one or several cuts above it (instead of below as by default)") + parser.add_option("--angle", dest='lsAngle' + , action="store", type="string", default="-1,0,+1" + ,help="Allowed cutting angles, in degree, comma-separated") + + parser.add_option("--graph", dest='bGraph', action="store_true", 
help="Store the graph in the XML for displaying it") + + # --- + #parse the command line + (options, args) = parser.parse_args() + + if options.bGraph: + import os.path + # hack + TableSkewedRowCut_CLASS.bCutAbove = options.bCutAbove + traceln("\t%s.bCutAbove=" % TableSkewedRowCut_CLASS.__name__, TableSkewedRowCut_CLASS.bCutAbove) + TableSkewedRowCut_CLASS.lRadAngle = [math.radians(v) for v in [float(s) for s in options.lsAngle.split(",")]] + traceln("\t%s.lRadAngle=" % TableSkewedRowCut_CLASS.__name__, TableSkewedRowCut_CLASS.lRadAngle) + for sInputFilename in args: + sp, sf = os.path.split(sInputFilename) + sOutFilename = os.path.join(sp, "graph-" + sf) + doer = TableSkewedRowCut_CLASS("debug", "." + , iBlockVisibility=options.iBlockVisibility + , iLineVisibility=options.iLineVisibility + , fCutHeight=options.fCutHeight + , bCutAbove=options.bCutAbove + , lRadAngle=[math.radians(float(s)) for s in options.lsAngle.split(",")]) + o = doer.cGraphClass() + o.parseDocFile(sInputFilename, 9) + o.parseDocLabels() + o.addParsedLabelToDom() + o.addEdgaddEdgeToDoc print('Graph edges added to %s'%sOutFilename) + o.doc.write(sOutFilename, encoding='utf-8',pretty_print=True,xml_declaration=True) + SkewedCutAnnotator.gtStatReport() + exit(0) + + # --- + try: + sModelDir, sModelName = args + except Exception as e: + traceln("Specify a model folder and a model name!") + _exit(usage, 1, e) + + main(TableSkewedRowCut_CLASS, sModelDir, sModelName, options) + +# ---------------------------------------------------------------------------- +if __name__ == "__main__": + from tasks.DU_ABPTableSkewed_txtBIO_sepSIO import DU_ABPTableSkewedRowCut + main_command_line(DU_ABPTableSkewedRowCut) diff --git a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableSkewed_txtBIESO_sepSIO.py b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableSkewed_txtBIESO_sepSIO.py index 213ec3d..13aa064 100644 --- a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableSkewed_txtBIESO_sepSIO.py +++ b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableSkewed_txtBIESO_sepSIO.py @@ -12,18 +12,7 @@ Copyright Naver Labs Europe(C) 2018 JL Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. 
The READ project has received funding @@ -59,13 +48,13 @@ class NodeType_BIESO_Shape(NodeType_PageXml_type_woText): """ """ - def parseDomNodeLabel(self, domnode, defaultCls=None): + def parseDocNodeLabel(self, graph_node, defaultCls=None): """ Parse and set the graph node label and return its class index raise a ValueError if the label is missing while bOther was not True, or if the label is neither a valid one nor an ignored one """ sLabel = self.sDefaultLabel - + domnode = graph_node.node sXmlLabel = domnode.get(self.sLabelAttr) sXmlLabel = {'B':'B', diff --git a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableSkewed_txtBIESO_sepSIO_line_hack.py b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableSkewed_txtBIESO_sepSIO_line_hack.py index 8a388d5..03d7567 100644 --- a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableSkewed_txtBIESO_sepSIO_line_hack.py +++ b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableSkewed_txtBIESO_sepSIO_line_hack.py @@ -14,18 +14,7 @@ Copyright Naver Labs Europe(C) 2018 JL Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . Developed for the EU project READ. The READ project has received funding @@ -87,13 +76,13 @@ class NodeType_BIESO_Shape(NodeType_PageXml_type_woText): """ """ - def parseDomNodeLabel(self, domnode, defaultCls=None): + def parseDocNodeLabel(self, graph_node, defaultCls=None): """ Parse and set the graph node label and return its class index raise a ValueError if the label is missing while bOther was not True, or if the label is neither a valid one nor an ignored one """ sLabel = self.sDefaultLabel - + domnode = graph_node.node sXmlLabel = domnode.get(self.sLabelAttr) sXmlLabel = {'B':'B', @@ -142,12 +131,12 @@ class GraphSkewedCut_H_lines(GraphSkewedCut_H): shaper_fun = ShapeLoader.node_to_SingleLine - def addEdgeToDOM(self): + def addEdgeToDoc(self): """ To display the grpah conveniently we add new Edge elements Since we change the BAseline representation, we show the new one """ - super().addEdgeToDOM() + super().addEdgeToDoc() for blk in self.lNode: assert blk.type.name in ["row", "sepH"], blk.type.name diff --git a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableSkewed_txtBIOH_sepSIO_line.py b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableSkewed_txtBIOH_sepSIO_line.py new file mode 100644 index 0000000..68db140 --- /dev/null +++ b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableSkewed_txtBIOH_sepSIO_line.py @@ -0,0 +1,161 @@ +# -*- coding: utf-8 -*- + +""" + *** Same as DU_ABPTableSkewed_txtBIO_sepSIO_line, except that text have BIOH as labels + + DU task for ABP Table: + doing jointly row BIOH and near horizontal cuts SIO + + block2line edges do not cross another block. + + The cut are based on baselines of text blocks, with some positive or negative inclination. + + - the labels of cuts are SIO + + Copyright Naver Labs Europe(C) 2018 JL Meunier + + + + + Developed for the EU project READ. 
The READ project has received funding + from the European Union's Horizon 2020 research and innovation programme + under grant agreement No 674943. + +""" + + + + +import sys, os + +from lxml import etree + +try: #to ease the use without proper Python installation + import TranskribusDU_version +except ImportError: + sys.path.append( os.path.dirname(os.path.dirname( os.path.abspath(sys.argv[0]) )) ) + import TranskribusDU_version + +from common.trace import traceln +from tasks.DU_CRF_Task import DU_CRF_Task +from tasks.DU_ABPTableSkewed import My_FeatureDefinition_v3, NodeType_PageXml_Cut_Shape, main_command_line +from tasks.DU_ABPTableSkewed_txtBIO_sepSIO import NodeType_BIESO_to_BIO_Shape_txt +from tasks.DU_ABPTableSkewed_txtBIO_sepSIO_line import GraphSkewedCut_H_lines, DU_ABPTableSkewedRowCutLine + + +class NodeType_BIESO_to_BIOH_Shape_txt(NodeType_BIESO_to_BIO_Shape_txt): + """ + Convert BIESO labeling to SIO + """ + + def parseDocNodeLabel(self, graph_node, defaultCls=None): + """ + Parse and set the graph node label and return its class index + raise a ValueError if the label is missing while bOther was not True, or if the label is neither a valid one nor an ignored one + """ + sLabel = self.sDefaultLabel + domnode = graph_node.node + + sXmlLabel = domnode.get("DU_header") + if sXmlLabel != 'CH': + sXmlLabel = domnode.get(self.sLabelAttr) + + sXmlLabel = {'B':'B', + 'I':'I', + 'E':'I', + 'S':'B', + 'O':'O', + 'CH':'CH'}[sXmlLabel] + try: + sLabel = self.dXmlLabel2Label[sXmlLabel] + except KeyError: + #not a label of interest + try: + self.checkIsIgnored(sXmlLabel) + #if self.lsXmlIgnoredLabel and sXmlLabel not in self.lsXmlIgnoredLabel: + except: + raise ValueError("Invalid label '%s'" + " (from @%s or @%s) in node %s"%(sXmlLabel, + self.sLabelAttr, + self.sDefaultLabel, + etree.tostring(domnode))) + + return sLabel + + +class NodeType_BIESO_to_BIOH_Shape(NodeType_BIESO_to_BIOH_Shape_txt): + """ + without text + """ + def _get_GraphNodeText(self, doc, domNdPage, ndBlock, ctxt=None): + return u"" + + +class DU_ABPTableSkewedRowCutLine_BIOH(DU_ABPTableSkewedRowCutLine): + """ + We will do a CRF model for a DU task + , with the below labels + """ + + #=== CONFIGURATION ==================================================================== + @classmethod + def getConfiguredGraphClass(cls): + """ + In this class method, we must return a configured graph class + """ + + # Textline labels + # Begin Inside End Single Other + lLabels_BIOH_row = ['B', 'I', 'O', 'CH'] + + # Cut lines: + # Border Ignore Separator Outside + lLabels_SIO_Cut = ['S', 'I', 'O'] + + #DEFINING THE CLASS OF GRAPH WE USE + DU_GRAPH = GraphSkewedCut_H_lines + + DU_GRAPH.iBlockVisibility = cls.iBlockVisibility + DU_GRAPH.iLineVisibility = cls.iLineVisibility + DU_GRAPH.fCutHeight = cls.fCutHeight + DU_GRAPH.bCutAbove = cls.bCutAbove + DU_GRAPH.lRadAngle = cls.lRadAngle + DU_GRAPH.bTxt = cls.bTxt + + # ROW + ntR = ( NodeType_BIESO_to_BIOH_Shape_txt if cls.bTxt \ + else NodeType_BIESO_to_BIOH_Shape \ + )("row" + , lLabels_BIOH_row + , None + , False + , None + ) + ntR.setLabelAttribute("DU_row") + ntR.setXpathExpr( (".//pc:TextLine" #how to find the nodes + , "./pc:TextEquiv/pc:Unicode") #how to get their text + ) + DU_GRAPH.addNodeType(ntR) + + # CUT + ntCutH = NodeType_PageXml_Cut_Shape("sepH" + , lLabels_SIO_Cut + , None + , False + , None # equiv. 
to: BBoxDeltaFun=lambda _: 0 + ) + ntCutH.setLabelAttribute("DU_type") + ntCutH.setXpathExpr( ('.//pc:CutSeparator[@orient="0"]' #how to find the nodes + # the angle attribute give the true orientation (which is near 0) + , "./pc:TextEquiv") #how to get their text + ) + DU_GRAPH.addNodeType(ntCutH) + + DU_GRAPH.setClassicNodeTypeList( [ntR ]) + + return DU_GRAPH + + +# ---------------------------------------------------------------------------- +if __name__ == "__main__": + main_command_line(DU_ABPTableSkewedRowCutLine_BIOH) diff --git a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableSkewed_txtBIOStmb_sepSIO_line.py b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableSkewed_txtBIOStmb_sepSIO_line.py index e18c70f..9f5c119 100644 --- a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableSkewed_txtBIOStmb_sepSIO_line.py +++ b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableSkewed_txtBIOStmb_sepSIO_line.py @@ -12,18 +12,7 @@ Copyright Naver Labs Europe(C) 2018 JL Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. The READ project has received funding @@ -67,11 +56,12 @@ class NodeType_BIESO_to_SIOStSmSb_Shape(NodeType_BIESO_to_BIO_Shape): 'O':'O', 'CH':'CH'} - def parseDomNodeLabel(self, domnode, defaultCls=None): + def parseDocNodeLabel(self, graph_node, defaultCls=None): """ Parse and set the graph node label and return its class index raise a ValueError if the label is missing while bOther was not True, or if the label is neither a valid one nor an ignored one """ + domnode = graph_node.node sXmlLabel = domnode.get(self.sLabelAttr) # in case we also deal with column headers diff --git a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableSkewed_txtBIOStmb_sepSIO_line_hack.py b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableSkewed_txtBIOStmb_sepSIO_line_hack.py index fcc45ac..74e4f10 100644 --- a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableSkewed_txtBIOStmb_sepSIO_line_hack.py +++ b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableSkewed_txtBIOStmb_sepSIO_line_hack.py @@ -12,18 +12,7 @@ Copyright Naver Labs Europe(C) 2018 JL Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. 
The READ project has received funding @@ -96,11 +85,12 @@ class NodeType_BIESO_to_SIOStSmSb_Shape(NodeType_BIESO_to_BIO_Shape): 'O':'O', 'CH':'CH'} - def parseDomNodeLabel(self, domnode, defaultCls=None): + def parseDocNodeLabel(self, graph_node, defaultCls=None): """ Parse and set the graph node label and return its class index raise a ValueError if the label is missing while bOther was not True, or if the label is neither a valid one nor an ignored one """ + domnode = graph_node.node sXmlLabel = domnode.get(self.sLabelAttr) # in case we also deal with column headers diff --git a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableSkewed_txtBIO_sepSIO.py b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableSkewed_txtBIO_sepSIO.py new file mode 100644 index 0000000..1de999b --- /dev/null +++ b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableSkewed_txtBIO_sepSIO.py @@ -0,0 +1,222 @@ +# -*- coding: utf-8 -*- + +""" + DU task for ABP Table: + doing jointly row BIO and near horizontal cuts SIO + + block2line edges do not cross another block. + + The cut are based on baselines of text blocks, with some positive or negative inclination. + + - the labels of cuts are SIO + + Copyright Naver Labs Europe(C) 2018 JL Meunier + + + + + Developed for the EU project READ. The READ project has received funding + from the European Union's Horizon 2020 research and innovation programme + under grant agreement No 674943. + +""" + + + + +import sys, os +from lxml import etree + +try: #to ease the use without proper Python installation + import TranskribusDU_version +except ImportError: + sys.path.append( os.path.dirname(os.path.dirname( os.path.abspath(sys.argv[0]) )) ) + import TranskribusDU_version + +from common.trace import traceln +from tasks.DU_CRF_Task import DU_CRF_Task + +from util.Shape import ShapeLoader + +from tasks.DU_ABPTableSkewed import GraphSkewedCut_H, My_FeatureDefinition_v3, NodeType_PageXml_Cut_Shape, main_command_line +from graph.NodeType_PageXml import NodeType_PageXml_type + + + +# class NodeType_BIESO_to_BIO_Shape(NodeType_PageXml_type_woText): +class NodeType_BIESO_to_BIO_Shape_txt(NodeType_PageXml_type): + """ + Convert BIESO labeling to BIO + """ + + def parseDocNodeLabel(self, graph_node, defaultCls=None): + """ + Parse and set the graph node label and return its class index + raise a ValueError if the label is missing while bOther was not True, or if the label is neither a valid one nor an ignored one + """ + sLabel = self.sDefaultLabel + domnode = graph_node.node + sXmlLabel = domnode.get(self.sLabelAttr) + + sXmlLabel = {'B':'B', + 'I':'I', + 'E':'I', + 'S':'B', + 'O':'O'}[sXmlLabel] + try: + sLabel = self.dXmlLabel2Label[sXmlLabel] + except KeyError: + #not a label of interest + try: + self.checkIsIgnored(sXmlLabel) + #if self.lsXmlIgnoredLabel and sXmlLabel not in self.lsXmlIgnoredLabel: + except: + raise ValueError("Invalid label '%s'" + " (from @%s or @%s) in node %s"%(sXmlLabel, + self.sLabelAttr, + self.sDefaultLabel, + etree.tostring(domnode))) + + return sLabel + + def _iter_GraphNode(self, doc, domNdPage, page): + """ + to add the shape object reflecting the baseline + """ + for blk in super()._iter_GraphNode(doc, domNdPage, page): + try: + ndBaseline = blk.node.xpath(".//pc:Baseline", namespaces=self.dNS)[0] + try: + o = ShapeLoader.node_to_LineString(ndBaseline) + except ValueError: + traceln("SKIPPING INVALID Baseline: ", etree.tostring(ndBaseline)) + continue + blk.shape = o + blk.du_index = int(ndBaseline.get("du_index")) + yield blk + except: + pass + return + + +class 
NodeType_BIESO_to_BIO_Shape(NodeType_BIESO_to_BIO_Shape_txt): + """ + without text + """ + def _get_GraphNodeText(self, doc, domNdPage, ndBlock, ctxt=None): + return u"" + + +class DU_ABPTableSkewedRowCut(DU_CRF_Task): + """ + We will do a CRF model for a DU task + , with the below labels + """ + sXmlFilenamePattern = "*[0-9].mpxml" + + iBlockVisibility = None + iLineVisibility = None + fCutHeight = None + bCutAbove = None + lRadAngle = None + bTxt = None # use textual features? + + #=== CONFIGURATION ==================================================================== + @classmethod + def getConfiguredGraphClass(cls): + """ + In this class method, we must return a configured graph class + """ + + # Textline labels + # Begin Inside End Single Other + lLabels_BIO_row = ['B', 'I', 'O'] + + # Cut lines: + # Border Ignore Separator Outside + lLabels_SIO_Cut = ['S', 'I', 'O'] + + #DEFINING THE CLASS OF GRAPH WE USE + DU_GRAPH = GraphSkewedCut_H + + DU_GRAPH.iBlockVisibility = cls.iBlockVisibility + DU_GRAPH.iLineVisibility = cls.iLineVisibility + DU_GRAPH.fCutHeight = cls.fCutHeight + DU_GRAPH.bCutAbove = cls.bCutAbove + DU_GRAPH.lRadAngle = cls.lRadAngle + DU_GRAPH.bTxt = cls.bTxt + + # ROW + ntR = ( NodeType_BIESO_to_BIO_Shape_txt if cls.bTxt \ + else NodeType_BIESO_to_BIO_Shape \ + )("row" + , lLabels_BIO_row + , None + , False + , BBoxDeltaFun=lambda v: max(v * 0.066, min(5, v/3)) + ) + ntR.setLabelAttribute("DU_row") + ntR.setXpathExpr( (".//pc:TextLine" #how to find the nodes + , "./pc:TextEquiv/pc:Unicode") #how to get their text + ) + DU_GRAPH.addNodeType(ntR) + + # CUT + ntCutH = NodeType_PageXml_Cut_Shape("sepH" + , lLabels_SIO_Cut + , None + , False + , None # equiv. to: BBoxDeltaFun=lambda _: 0 + ) + ntCutH.setLabelAttribute("DU_type") + ntCutH.setXpathExpr( ('.//pc:CutSeparator[@orient="0"]' #how to find the nodes + # the angle attribute give the true orientation (which is near 0) + , "./pc:TextEquiv") #how to get their text + ) + DU_GRAPH.addNodeType(ntCutH) + + DU_GRAPH.setClassicNodeTypeList( [ntR ]) + + return DU_GRAPH + + def __init__(self, sModelName, sModelDir, + iBlockVisibility = None, + iLineVisibility = None, + fCutHeight = None, + bCutAbove = None, + lRadAngle = None, + bTxt = None, + sComment = None, + C=None, tol=None, njobs=None, max_iter=None, + inference_cache=None): + + DU_ABPTableSkewedRowCut.iBlockVisibility = iBlockVisibility + DU_ABPTableSkewedRowCut.iLineVisibility = iLineVisibility + DU_ABPTableSkewedRowCut.fCutHeight = fCutHeight + DU_ABPTableSkewedRowCut.bCutAbove = bCutAbove + DU_ABPTableSkewedRowCut.lRadAngle = lRadAngle + DU_ABPTableSkewedRowCut.bTxt = bTxt + + DU_CRF_Task.__init__(self + , sModelName, sModelDir + , dFeatureConfig = {'row_row':{}, 'row_sepH':{}, + 'sepH_row':{}, 'sepH_sepH':{}, + 'sepH':{}, 'row':{}} + , dLearnerConfig = { + 'C' : .1 if C is None else C + , 'njobs' : 4 if njobs is None else njobs + , 'inference_cache' : 50 if inference_cache is None else inference_cache + #, 'tol' : .1 + , 'tol' : .05 if tol is None else tol + , 'save_every' : 50 #save every 50 iterations,for warm start + , 'max_iter' : 10 if max_iter is None else max_iter + } + , sComment=sComment + #,cFeatureDefinition=FeatureDefinition_PageXml_StandardOnes_noText + ,cFeatureDefinition=My_FeatureDefinition_v3 + ) + + +# ---------------------------------------------------------------------------- +if __name__ == "__main__": + main_command_line(DU_ABPTableSkewedRowCut) diff --git a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableSkewed_txtBIO_sepSIO_line.py 
b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableSkewed_txtBIO_sepSIO_line.py
new file mode 100644
index 0000000..41d20b7
--- /dev/null
+++ b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableSkewed_txtBIO_sepSIO_line.py
@@ -0,0 +1,186 @@
+# -*- coding: utf-8 -*-
+
+"""
+    *** Same as its parent except that text baselines are reflected as a LineString (instead of its centroid)
+
+    DU task for ABP Table:
+    doing jointly row BIO and near horizontal cuts SIO
+
+    block2line edges do not cross another block.
+
+    The cuts are based on baselines of text blocks, with some positive or negative inclination.
+
+    - the labels of cuts are SIO
+
+    Copyright Naver Labs Europe(C) 2018 JL Meunier
+
+
+
+
+    Developed for the EU project READ. The READ project has received funding
+    from the European Union's Horizon 2020 research and innovation programme
+    under grant agreement No 674943.
+
+"""
+
+
+
+
+import sys, os
+
+try:  #to ease the use without proper Python installation
+    import TranskribusDU_version
+except ImportError:
+    sys.path.append( os.path.dirname(os.path.dirname( os.path.abspath(sys.argv[0]) )) )
+    import TranskribusDU_version
+
+from common.trace import traceln
+from xml_formats.PageXml import MultiPageXml
+from util.Shape import ShapeLoader
+
+from tasks.DU_CRF_Task import DU_CRF_Task
+from tasks.DU_ABPTableSkewed import GraphSkewedCut_H, My_FeatureDefinition_v3, NodeType_PageXml_Cut_Shape, main_command_line,\
+    My_FeatureDefinition_v3_txt
+from tasks.DU_ABPTableSkewed_txtBIO_sepSIO import NodeType_BIESO_to_BIO_Shape, NodeType_BIESO_to_BIO_Shape_txt
+
+
+class GraphSkewedCut_H_lines(GraphSkewedCut_H):
+
+    # reflecting text baseline as a LineString
+    shaper_fun = ShapeLoader.node_to_SingleLine
+
+
+    def addEdgeToDoc(self, Y=None):
+        """
+        To display the graph conveniently we add new Edge elements
+        Since we change the Baseline representation, we show the new one
+        """
+        super().addEdgeToDoc()
+
+        for blk in self.lNode:
+            assert blk.type.name in ["row", "sepH"], blk.type.name
+
+            if blk.type.name == "row":
+                ndBaseline = blk.node.xpath(".//pc:Baseline", namespaces=self.dNS)[0]
+                o = self.shaper_fun(ndBaseline)
+                MultiPageXml.setPoints(ndBaseline, list(o.coords))
+
+        return
+
+
+class DU_ABPTableSkewedRowCutLine(DU_CRF_Task):
+    """
+    We will do a CRF model for a DU task
+    , with the below labels
+    """
+    sXmlFilenamePattern = "*.mpxml"
+    #sXmlFilenamePattern = "*.pxml"
+
+    iBlockVisibility = None
+    iLineVisibility = None
+    fCutHeight = None
+    bCutAbove = None
+    lRadAngle = None
+    bTxt = None  # use textual features?
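A note on the box-shrinking hook used by getConfiguredGraphClass below: the "row" node type is built with BBoxDeltaFun=lambda v: max(v * 0.066, min(5, v/3)), which for a bounding-box side of length v returns one third of v for small boxes, caps the shrink at 5 pixels, and switches to 6.6% of v for large boxes. A quick numeric look (sample values chosen arbitrarily):

    f = lambda v: max(v * 0.066, min(5, v / 3))
    for v in (6, 15, 60, 300):
        print(v, "->", f(v))
    # 6   -> 2.0    (v/3 for small boxes)
    # 15  -> 5      (the 5-pixel cap kicks in)
    # 60  -> 5      (cap still larger than 6.6% = 3.96)
    # 300 -> 19.8   (6.6% of v takes over for large boxes)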
+ + #=== CONFIGURATION ==================================================================== + @classmethod + def getConfiguredGraphClass(cls): + """ + In this class method, we must return a configured graph class + """ + + # Textline labels + # Begin Inside End Single Other + lLabels_BIO_row = ['B', 'I', 'O'] + + # Cut lines: + # Border Ignore Separator Outside + lLabels_SIO_Cut = ['S', 'I', 'O'] + + #DEFINING THE CLASS OF GRAPH WE USE + DU_GRAPH = GraphSkewedCut_H_lines + + DU_GRAPH.iBlockVisibility = cls.iBlockVisibility + DU_GRAPH.iLineVisibility = cls.iLineVisibility + DU_GRAPH.fCutHeight = cls.fCutHeight + DU_GRAPH.bCutAbove = cls.bCutAbove + DU_GRAPH.lRadAngle = cls.lRadAngle + DU_GRAPH.bTxt = cls.bTxt + + # ROW + ntR = ( NodeType_BIESO_to_BIO_Shape_txt if cls.bTxt \ + else NodeType_BIESO_to_BIO_Shape \ + )("row" + , lLabels_BIO_row + , None + , False + , BBoxDeltaFun=lambda v: max(v * 0.066, min(5, v/3)) + ) + ntR.setLabelAttribute("DU_row") + ntR.setXpathExpr( (".//pc:TextLine" #how to find the nodes + , "./pc:TextEquiv/pc:Unicode") #how to get their text + ) + DU_GRAPH.addNodeType(ntR) + + # CUT + ntCutH = NodeType_PageXml_Cut_Shape("sepH" + , lLabels_SIO_Cut + , None + , False + , None # equiv. to: BBoxDeltaFun=lambda _: 0 + ) + ntCutH.setLabelAttribute("DU_type") + ntCutH.setXpathExpr( ('.//pc:CutSeparator[@orient="0"]' #how to find the nodes + # the angle attribute give the true orientation (which is near 0) + , "./pc:TextEquiv") #how to get their text + ) + DU_GRAPH.addNodeType(ntCutH) + + DU_GRAPH.setClassicNodeTypeList( [ntR ]) + + return DU_GRAPH + + def __init__(self, sModelName, sModelDir + , iBlockVisibility = None + , iLineVisibility = None + , fCutHeight = None + , bCutAbove = None + , lRadAngle = None + , bTxt = None + , sComment = None + , cFeatureDefinition = None + , dFeatureConfig = {} + , C=None, tol=None, njobs=None, max_iter=None + , inference_cache=None): + + DU_ABPTableSkewedRowCutLine.iBlockVisibility = iBlockVisibility + DU_ABPTableSkewedRowCutLine.iLineVisibility = iLineVisibility + DU_ABPTableSkewedRowCutLine.fCutHeight = fCutHeight + DU_ABPTableSkewedRowCutLine.bCutAbove = bCutAbove + DU_ABPTableSkewedRowCutLine.lRadAngle = lRadAngle + DU_ABPTableSkewedRowCutLine.bTxt = bTxt + + DU_CRF_Task.__init__(self + , sModelName, sModelDir + , dFeatureConfig = {'row_row':{}, 'row_sepH':{}, + 'sepH_row':{}, 'sepH_sepH':{}, + 'sepH':{}, 'row':{}} + , dLearnerConfig = { + 'C' : .1 if C is None else C + , 'njobs' : 4 if njobs is None else njobs + , 'inference_cache' : 50 if inference_cache is None else inference_cache + #, 'tol' : .1 + , 'tol' : .05 if tol is None else tol + , 'save_every' : 50 #save every 50 iterations,for warm start + , 'max_iter' : 10 if max_iter is None else max_iter + } + , sComment=sComment + #,cFeatureDefinition=FeatureDefinition_PageXml_StandardOnes_noText + , cFeatureDefinition= My_FeatureDefinition_v3_txt if self.bTxt else My_FeatureDefinition_v3 + ) + + +# ---------------------------------------------------------------------------- +if __name__ == "__main__": + main_command_line(DU_ABPTableSkewedRowCutLine) diff --git a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableSkewed_txtBIO_sepSIO_line_weighted.py b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableSkewed_txtBIO_sepSIO_line_weighted.py index 1b38f84..cfc4239 100644 --- a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableSkewed_txtBIO_sepSIO_line_weighted.py +++ b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableSkewed_txtBIO_sepSIO_line_weighted.py @@ -14,18 +14,7 @@ Copyright Naver Labs 
Europe(C) 2018 JL Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. The READ project has received funding diff --git a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableSkewed_txtBIO_sepSIO_weighted.py b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableSkewed_txtBIO_sepSIO_weighted.py index e538721..ca1b4af 100644 --- a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableSkewed_txtBIO_sepSIO_weighted.py +++ b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableSkewed_txtBIO_sepSIO_weighted.py @@ -12,18 +12,7 @@ Copyright Naver Labs Europe(C) 2018 JL Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. The READ project has received funding @@ -62,13 +51,13 @@ class NodeType_BIESO_to_BIO_Shape(NodeType_PageXml_type_woText): Convert BIESO labeling to BIO """ - def parseDomNodeLabel(self, domnode, defaultCls=None): + def parseDocNodeLabel(self, graph_node, defaultCls=None): """ Parse and set the graph node label and return its class index raise a ValueError if the label is missing while bOther was not True, or if the label is neither a valid one nor an ignored one """ sLabel = self.sDefaultLabel - + domnode = graph_node.node sXmlLabel = domnode.get(self.sLabelAttr) sXmlLabel = {'B':'B', diff --git a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableSkewed_txtBISO_sepSIO_line.py b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableSkewed_txtBISO_sepSIO_line.py index 885d7dc..b7d2511 100644 --- a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableSkewed_txtBISO_sepSIO_line.py +++ b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableSkewed_txtBISO_sepSIO_line.py @@ -14,18 +14,7 @@ Copyright Naver Labs Europe(C) 2018 JL Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. 
The READ project has received funding @@ -60,12 +49,13 @@ class NodeType_BIESO_to_BISO_Shape(NodeType_PageXml_type_woText): Convert BIESO labeling to BIO """ - def parseDomNodeLabel(self, domnode, defaultCls=None): + def parseDocNodeLabel(self, graph_node, defaultCls=None): """ Parse and set the graph node label and return its class index raise a ValueError if the label is missing while bOther was not True, or if the label is neither a valid one nor an ignored one """ sLabel = self.sDefaultLabel + domnode = graph_node.node sXmlLabel = domnode.get(self.sLabelAttr) @@ -116,12 +106,12 @@ class GraphSkewedCut_H_lines(GraphSkewedCut_H): shaper_fun = ShapeLoader.node_to_SingleLine - def addEdgeToDOM(self): + def addEdgeToDoc(self): """ To display the grpah conveniently we add new Edge elements Since we change the BAseline representation, we show the new one """ - super().addEdgeToDOM() + super().addEdgeToDoc() for blk in self.lNode: assert blk.type.name in ["row", "sepH"], blk.type.name diff --git a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableSkewed_txtBISO_sepSIO_line_hack.py b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableSkewed_txtBISO_sepSIO_line_hack.py index e683e04..7148305 100644 --- a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableSkewed_txtBISO_sepSIO_line_hack.py +++ b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableSkewed_txtBISO_sepSIO_line_hack.py @@ -14,18 +14,7 @@ Copyright Naver Labs Europe(C) 2018 JL Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . Developed for the EU project READ. 
The READ project has received funding @@ -87,11 +76,12 @@ class NodeType_BISO_Shape(NodeType_PageXml_type_woText): """ """ - def parseDomNodeLabel(self, domnode, defaultCls=None): + def parseDocNodeLabel(self, graph_node, defaultCls=None): """ Parse and set the graph node label and return its class index raise a ValueError if the label is missing while bOther was not True, or if the label is neither a valid one nor an ignored one """ + domnode = graph_node.node sLabel = self.sDefaultLabel sXmlLabel = domnode.get(self.sLabelAttr) @@ -142,12 +132,12 @@ class GraphSkewedCut_H_lines(GraphSkewedCut_H): shaper_fun = ShapeLoader.node_to_SingleLine - def addEdgeToDOM(self): + def addEdgeToDoc(self): """ To display the grpah conveniently we add new Edge elements Since we change the BAseline representation, we show the new one """ - super().addEdgeToDOM() + super().addEdgeToDoc() for blk in self.lNode: assert blk.type.name in ["row", "sepH"], blk.type.name diff --git a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableSkewed_txtEIO_sepSIO.py b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableSkewed_txtEIO_sepSIO.py index e252c9d..f253fe5 100644 --- a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableSkewed_txtEIO_sepSIO.py +++ b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableSkewed_txtEIO_sepSIO.py @@ -12,18 +12,7 @@ Copyright Naver Labs Europe(C) 2018 JL Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. The READ project has received funding @@ -60,12 +49,13 @@ class NodeType_BIESO_to_EIO_Shape(NodeType_PageXml_type_woText): Convert BIESO labeling to EIO """ - def parseDomNodeLabel(self, domnode, defaultCls=None): + def parseDocNodeLabel(self, graph_node, defaultCls=None): """ Parse and set the graph node label and return its class index raise a ValueError if the label is missing while bOther was not True, or if the label is neither a valid one nor an ignored one """ sLabel = self.sDefaultLabel + domnode = graph_node.node sXmlLabel = domnode.get(self.sLabelAttr) diff --git a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableSkewed_txtTOMBS_sepSIO_line.py b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableSkewed_txtTOMBS_sepSIO_line.py index 0c18258..80cf9b7 100644 --- a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableSkewed_txtTOMBS_sepSIO_line.py +++ b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableSkewed_txtTOMBS_sepSIO_line.py @@ -10,18 +10,7 @@ Copyright Naver Labs Europe(C) 2018 JL Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. 
- - You should have received a copy of the GNU General Public License - along with this program. If not, see . Developed for the EU project READ. The READ project has received funding @@ -112,12 +101,12 @@ def showClassParam(cls): traceln(" - iCutCloseDistanceTop : " , cls.iCutCloseDistanceTop) traceln(" - iCutCloseDistanceBot : " , cls.iCutCloseDistanceBot) - def addEdgeToDOM(self): + def addEdgeToDoc(self): """ To display the grpah conveniently we add new Edge elements Since we change the BAseline representation, we show the new one """ - super().addEdgeToDOM() + super().addEdgeToDoc() for blk in self.lNode: assert blk.type.name in ["row", "sepH"], blk.type.name @@ -132,13 +121,13 @@ def addEdgeToDOM(self): """ To compute TOMBS labels, it is better to use the built graph... """ - def parseDomLabels(self): + def parseDocLabels(self): """ Parse the label of the graph from the dataset, and set the node label return the set of observed class (set of integers in N+) """ # WE expect I or O for text blocks!! - setSeensLabels = super().parseDomLabels() + setSeensLabels = super().parseDocLabels() # now look at edges to compute T M B S # REMEMBER, we did: edge.len = dist / self.iBlockVisibility @@ -193,11 +182,12 @@ class NodeType_BIESO_to_TOMBS_Shape(NodeType_BIESO_to_BIO_Shape): 'O':'O', 'CH':'CH'} - def parseDomNodeLabel(self, domnode, defaultCls=None): + def parseDocNodeLabel(self, graph_node, defaultCls=None): """ Parse and set the graph node label and return its class index raise a ValueError if the label is missing while bOther was not True, or if the label is neither a valid one nor an ignored one """ + domnode = graph_node.node sXmlLabel = domnode.get(self.sLabelAttr) # in case we also deal with column headers diff --git a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableSkewed_txtTOMBS_sepSIO_line_hack.py b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableSkewed_txtTOMBS_sepSIO_line_hack.py index 5400c20..faf80f4 100644 --- a/TranskribusDU/tasks/TablePrototypes/DU_ABPTableSkewed_txtTOMBS_sepSIO_line_hack.py +++ b/TranskribusDU/tasks/TablePrototypes/DU_ABPTableSkewed_txtTOMBS_sepSIO_line_hack.py @@ -10,18 +10,7 @@ Copyright Naver Labs Europe(C) 2018 JL Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. The READ project has received funding diff --git a/TranskribusDU/tasks/TablePrototypes/DU_ABPTable_Quantile_NoEF.py b/TranskribusDU/tasks/TablePrototypes/DU_ABPTable_Quantile_NoEF.py index 74c1bff..d9512bb 100644 --- a/TranskribusDU/tasks/TablePrototypes/DU_ABPTable_Quantile_NoEF.py +++ b/TranskribusDU/tasks/TablePrototypes/DU_ABPTable_Quantile_NoEF.py @@ -5,18 +5,7 @@ Copyright Xerox(C) 2017 H. Déjean - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. 
- - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. The READ project has received funding diff --git a/TranskribusDU/tasks/TablePrototypes/DU_ABPTable_Quantile_NoNF.py b/TranskribusDU/tasks/TablePrototypes/DU_ABPTable_Quantile_NoNF.py index 014828e..fd28ea0 100644 --- a/TranskribusDU/tasks/TablePrototypes/DU_ABPTable_Quantile_NoNF.py +++ b/TranskribusDU/tasks/TablePrototypes/DU_ABPTable_Quantile_NoNF.py @@ -5,18 +5,7 @@ Copyright Xerox(C) 2017 H. Déjean - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. The READ project has received funding diff --git a/TranskribusDU/tasks/TablePrototypes/DU_Table_BIO.py b/TranskribusDU/tasks/TablePrototypes/DU_Table_BIO.py new file mode 100644 index 0000000..b211bf1 --- /dev/null +++ b/TranskribusDU/tasks/TablePrototypes/DU_Table_BIO.py @@ -0,0 +1,159 @@ +# -*- coding: utf-8 -*- + +""" + DU task for segmenting text in rows using a BIO scheme + + Example of code after April SW re-engineering by JLM + + Copyright NAVER(C) 2019 Jean-Luc Meunier + + + + + Developed for the EU project READ. The READ project has received funding + from the European Union's Horizon 2020 research and innovation programme + under grant agreement No 674943. 
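+    Editor's note (illustrative sketch, not in the original file): the node
+    type below collapses BIESO labels into BIO with the mapping
+    B->B, I->I, E->I, S->B, O->O. A quick doctest-style check of that mapping:
+
+        >>> m = {'B':'B', 'I':'I', 'E':'I', 'S':'B', 'O':'O'}
+        >>> [m[c] for c in "BIESO"]
+        ['B', 'I', 'I', 'B', 'O']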
+ +""" + +import sys, os +import lxml.etree as etree + +try: #to ease the use without proper Python installation + import TranskribusDU_version +except ImportError: + sys.path.append( os.path.dirname(os.path.dirname( os.path.abspath(sys.argv[0]) )) ) + import TranskribusDU_version +TranskribusDU_version + +from common.trace import traceln +from graph.Graph_Multi_SinglePageXml import Graph_MultiSinglePageXml +from graph.NodeType_PageXml import NodeType_PageXml_type_woText +from graph.FeatureDefinition_PageXml_std_noText import FeatureDefinition_PageXml_StandardOnes_noText +from tasks.DU_Task_Factory import DU_Task_Factory + + + +# to convert from BIESO to BIO we create our own NodeType by inheritance +# class NodeType_BIESO_to_BIO_Shape(NodeType_PageXml_type_woText): +class NodeType_PageXml_type_woText_BIESO_to_BIO(NodeType_PageXml_type_woText): + """ + Convert BIESO labeling to BIO + """ + + def parseDocNodeLabel(self, graph_node, defaultCls=None): + """ + Parse and set the graph node label and return its class index + raise a ValueError if the label is missing while bOther was not True, or if the label is neither a valid one nor an ignored one + """ + sLabel = self.sDefaultLabel + domnode = graph_node.node + + sXmlLabel = domnode.get(self.sLabelAttr) + + sXmlLabel = {'B':'B', + 'I':'I', + 'E':'I', + 'S':'B', + 'O':'O'}[sXmlLabel] + try: + sLabel = self.dXmlLabel2Label[sXmlLabel] + except KeyError: + #not a label of interest + try: + self.checkIsIgnored(sXmlLabel) + #if self.lsXmlIgnoredLabel and sXmlLabel not in self.lsXmlIgnoredLabel: + except: + raise ValueError("Invalid label '%s'" + " (from @%s or @%s) in node %s"%(sXmlLabel, + self.sLabelAttr, + self.sDefaultLabel, + etree.tostring(domnode))) + + return sLabel + + +def getConfiguredGraphClass(doer): + """ + In this function, we return a configured graph.Graph subclass + + doer is a tasks.DU_task object created by tasks.DU_Task_Factory + """ + #DEFINING THE CLASS OF GRAPH WE USE + DU_GRAPH = Graph_MultiSinglePageXml + + lLabels = ['B', 'I', 'O'] + + lIgnoredLabels = [] + + """ + if you play with a toy collection, which does not have all expected classes, you can reduce those. 
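+    For instance (editor's illustration), with lLabels = ['B', 'I', 'O'] and
+    lActuallySeen = [0, 2], the code below keeps lLabels = ['B', 'O'] and
+    moves 'I' into lIgnoredLabels.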
+ """ + + lActuallySeen = None + if lActuallySeen: + print( "REDUCING THE CLASSES TO THOSE SEEN IN TRAINING") + lIgnoredLabels = [lLabels[i] for i in range(len(lLabels)) if i not in lActuallySeen] + lLabels = [lLabels[i] for i in lActuallySeen ] + print( len(lLabels) , lLabels) + print( len(lIgnoredLabels) , lIgnoredLabels) + + nt = NodeType_PageXml_type_woText_BIESO_to_BIO( + "abp" #some short prefix because labels below are prefixed with it + , lLabels + , lIgnoredLabels + , False #no label means OTHER + , BBoxDeltaFun=lambda v: max(v * 0.066, min(5, v/3)) #we reduce overlap in this way + ) + nt.setLabelAttribute("DU_row") + + nt.setXpathExpr( (".//pc:TextLine" #how to find the nodes + , "./pc:TextEquiv") #how to get their text + ) + + # ntA.setXpathExpr( (".//pc:TextLine | .//pc:TextRegion" #how to find the nodes + # , "./pc:TextEquiv") #how to get their text + # ) + DU_GRAPH.addNodeType(nt) + + return DU_GRAPH + + +if __name__ == "__main__": + # import better_exceptions + # better_exceptions.MAX_LENGTH = None + + # standard command line options for CRF- ECN- GAT-based methods + usage, parser = DU_Task_Factory.getStandardOptionsParser(sys.argv[0]) + + traceln("VERSION: %s" % DU_Task_Factory.getVersion()) + + # --- + #parse the command line + (options, args) = parser.parse_args() + + try: + sModelDir, sModelName = args + except Exception as e: + traceln("Specify a model folder and a model name!") + DU_Task_Factory.exit(usage, 1, e) + + doer = DU_Task_Factory.getDoer(sModelDir, sModelName + , options = options + , fun_getConfiguredGraphClass= getConfiguredGraphClass + , cFeatureDefinition = FeatureDefinition_PageXml_StandardOnes_noText + , dFeatureConfig = {} + ) + + # setting the learner configuration, in a standard way + # (from command line options, or from a JSON configuration file) + dLearnerConfig = doer.getStandardLearnerConfig(options) + # of course, you can put yours here instead. + doer.setLearnerConfiguration(dLearnerConfig) + + + # act as per specified in the command line (--trn , --fold-run, ...) + doer.standardDo(options) + + del doer + diff --git a/TranskribusDU/tasks/TablePrototypes/DU_Table_Col.py b/TranskribusDU/tasks/TablePrototypes/DU_Table_Col.py new file mode 100644 index 0000000..8813232 --- /dev/null +++ b/TranskribusDU/tasks/TablePrototypes/DU_Table_Col.py @@ -0,0 +1,148 @@ +# -*- coding: utf-8 -*- + +""" + Create column segmenters + + Copyright Naver Labs Europe(C) 2018 JL Meunier + + + + + Developed for the EU project READ. The READ project has received funding + from the European Union's Horizon 2020 research and innovation programme + under grant agreement No 674943. + +""" + + + + +import sys, os +from optparse import OptionParser + +from lxml import etree + + +try: #to ease the use without proper Python installation + import TranskribusDU_version +except ImportError: + sys.path.append( os.path.dirname(os.path.dirname( os.path.abspath(sys.argv[0]) )) ) + import TranskribusDU_version + +from common.trace import traceln +from tasks import _exit +from tasks.DU_ABPTableCutAnnotator import CutAnnotator + + +def main(sFilename, sOutFilename + , fRatio, fMinHLen + , fMinHorizProjection, fMinVertiProjection + ): + + traceln("- cutting: %s --> %s"%(sFilename, sOutFilename)) + + #for the pretty printer to format better... + parser = etree.XMLParser(remove_blank_text=True) + doc = etree.parse(sFilename, parser) + root=doc.getroot() + + doer = CutAnnotator() + + # # Some grid line will be O or I simply because they are too short. 
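+    # Editor's note (illustrative arithmetic, not in the original code):
+    # fMinHorizProjection / fMinVertiProjection are ratios of the page size,
+    # so with the default 0.05 on a 2000-pixel-wide page, horizontal profile
+    # values below 0.05 * 2000 = 100 pixels are ignored when proposing cuts.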
+ # fMinPageCoverage = 0.5 # minimum proportion of the page crossed by a grid line + # # we want to ignore col- and row- spans + #map the groundtruth table separators to our grid, per page (1 in tABP) + # ltlYlX = doer.get_separator_YX_from_DOM(root, fMinPageCoverage) + + # Find cuts and map them to GT + # + doer.add_cut_to_DOM(root + #, ltlYlX=ltlYlX + , fMinHorizProjection=fMinHorizProjection + , fMinVertiProjection=fMinVertiProjection + , fRatio=fRatio + , fMinHLen=fMinHLen) + + #l_DU_row_Y, l_DU_row_GT = doer.predict(root) + + doc.write(sOutFilename, encoding='utf-8',pretty_print=True,xml_declaration=True) + traceln('Annotated cut separators added into %s'%sOutFilename) + + del doc + + +# ---------------------------------------------------------------------------- +if __name__ == "__main__": + usage = """+| """ + version = "v.01" + parser = OptionParser(usage=usage, version="0.1") + parser.add_option("--ratio", dest='fRatio', action="store" + , type=float + , help="Apply this ratio to the bounding box" + , default=CutAnnotator.fRATIO) + parser.add_option("--fMinHLen", dest='fMinHLen', action="store" + , type=float + , help="Do not scale horizontally a bounding box with width lower than this" + , default=75) + + parser.add_option("--fHorizRatio", dest='fMinHorizProjection', action="store" + , type=float + , help="On the horizontal projection profile, it ignores profile lower than this ratio of the page width" + , default=0.05) + parser.add_option("--fVertRatio", dest='fMinVertiProjection', action="store" + , type=float + , help="On the vertical projection profile, it ignores profile lower than this ratio of the page height" + , default=0.05) +# parser.add_option("--SIO" , dest='bSIO' , action="store_true", help="SIO labels") +# parser.add_option("--annotate", dest='bAnnotate', action="store_true",default=False, help="Annotate the textlines with BIES labels") + +# parser.add_option("--detail", dest='bDetailedReport', action="store_true", default=False,help="Display detailed reporting (score per document)") +# parser.add_option("--baseline", dest='bBaseline', action="store_true", default=False, help="report baseline method") +# parser.add_option("--line_see_line", dest='iLineVisibility', action="store", +# type=int, default=GraphSkewedCut.iLineVisibility, +# help="seeline2line: how far in pixel can a line see another cut line?") +# parser.add_option("--block_see_line", dest='iBlockVisibility', action="store", +# type=int, default=GraphSkewedCut.iBlockVisibility, +# help="seeblock2line: how far in pixel can a block see a cut line?") +# parser.add_option("--height", dest="fCutHeight", default=GraphSkewedCut.fCutHeight +# , action="store", type=float, help="Minimal height of a cut") +# # parser.add_option("--cut-above", dest='bCutAbove', action="store_true", default=False +# # ,help="Each object defines one or several cuts above it (instead of below as by default)") +# parser.add_option("--angle", dest='lsAngle' +# , action="store", type="string", default="-1,0,+1" +# ,help="Allowed cutting angles, in degree, comma-separated") +# +# parser.add_option("--graph", dest='bGraph', action="store_true", help="Store the graph in the XML for displaying it") +# parser.add_option("--bioh", "--BIOH", dest='bBIOH', action="store_true", help="Text are categorised along BIOH instead of BIO") + + # --- + #parse the command line + (options, args) = parser.parse_args() + + traceln(options) + + if len(args) == 2 and os.path.isdir(args[0]) and os.path.isdir(args[1]): + # ok, let's work differently... 
+        sFromDir,sToDir = args
+        for s in os.listdir(sFromDir):
+            # process PageXml files only
+            if not s.lower().endswith("pxml"): continue
+            sFilename = sFromDir + "/" + s
+            sp, sf = os.path.split(s)
+            sOutFilename = sToDir + "/" + "cut-" + sf
+            traceln(sFilename," --> ", sOutFilename)
+            main(sFilename, sOutFilename
+                 , options.fRatio, fMinHLen=options.fMinHLen
+                 , fMinHorizProjection=options.fMinHorizProjection
+                 , fMinVertiProjection=options.fMinVertiProjection
+                 )
+    else:
+        for sFilename in args:
+            sp, sf = os.path.split(sFilename)
+            sOutFilename = os.path.join(sp, "cut-" + sf)
+            traceln(sFilename," --> ", sOutFilename)
+            main(sFilename, sOutFilename
+                 , options.fRatio, fMinHLen=options.fMinHLen
+                 , fMinHorizProjection=options.fMinHorizProjection
+                 , fMinVertiProjection=options.fMinVertiProjection
+                 )
+
diff --git a/TranskribusDU/tasks/TablePrototypes/DU_Table_Row.py b/TranskribusDU/tasks/TablePrototypes/DU_Table_Row.py
new file mode 100644
index 0000000..c69734c
--- /dev/null
+++ b/TranskribusDU/tasks/TablePrototypes/DU_Table_Row.py
@@ -0,0 +1,120 @@
+# -*- coding: utf-8 -*-
+
+"""
+    *** Same as its parent, except that text baselines are reflected as a
+    LineString (instead of their centroid)
+
+    DU task for ABP Table:
+        doing jointly row BIO and near-horizontal cuts SIO
+
+    block2line edges do not cross another block.
+
+    The cuts are based on baselines of text blocks, with some positive or
+    negative inclination.
+
+    - the labels of cuts are SIO
+
+    Copyright Naver Labs Europe(C) 2018 JL Meunier
+
+
+
+
+    Developed for the EU project READ. The READ project has received funding
+    from the European Union's Horizon 2020 research and innovation programme
+    under grant agreement No 674943.
+
+"""
+
+
+
+
+import sys, os
+
+import math
+
+try:  #to ease the use without proper Python installation
+    import TranskribusDU_version
+except ImportError:
+    sys.path.append( os.path.dirname(os.path.dirname( os.path.abspath(sys.argv[0]) )) )
+    import TranskribusDU_version
+TranskribusDU_version
+
+from common.trace import traceln
+from tasks import _exit
+from tasks.DU_CRF_Task import DU_CRF_Task
+from tasks.DU_Table.DU_ABPTableSkewed import GraphSkewedCut, main
+from tasks.DU_Table.DU_ABPTableSkewed_CutAnnotator import SkewedCutAnnotator
+from tasks.DU_Table.DU_ABPTableSkewed_txtBIO_sepSIO_line import DU_ABPTableSkewedRowCutLine
+from tasks.DU_Table.DU_ABPTableSkewed_txtBIOH_sepSIO_line import DU_ABPTableSkewedRowCutLine_BIOH
+
+# ----------------------------------------------------------------------------
+if __name__ == "__main__":
+
+    version = "v.01"
+    usage, description, parser = DU_CRF_Task.getBasicTrnTstRunOptionParser(sys.argv[0], version)
+#    parser.add_option("--annotate", dest='bAnnotate', action="store_true",default=False, help="Annotate the textlines with BIES labels")
+
+    #FOR GCN
+    # parser.add_option("--revertEdges", dest='bRevertEdges', action="store_true", help="Revert the direction of the edges")
+    parser.add_option("--detail", dest='bDetailedReport', action="store_true", default=False,help="Display detailed reporting (score per document)")
+    parser.add_option("--baseline", dest='bBaseline', action="store_true", default=False, help="report baseline method")
+    parser.add_option("--line_see_line", dest='iLineVisibility', action="store",
+                      type=int, default=GraphSkewedCut.iLineVisibility,
+                      help="seeline2line: how far in pixel can a line see another cut line?")
+    parser.add_option("--block_see_line", dest='iBlockVisibility', action="store",
+                      type=int, default=GraphSkewedCut.iBlockVisibility,
+                      help="seeblock2line: how far in pixel
can a block see a cut line?") + parser.add_option("--height", dest="fCutHeight", default=GraphSkewedCut.fCutHeight + , action="store", type=float, help="Minimal height of a cut") + # parser.add_option("--cut-above", dest='bCutAbove', action="store_true", default=False + # ,help="Each object defines one or several cuts above it (instead of below as by default)") + parser.add_option("--angle", dest='lsAngle' + , action="store", type="string", default="-1,0,+1" + ,help="Allowed cutting angles, in degree, comma-separated") + + parser.add_option("--graph", dest='bGraph', action="store_true", help="Store the graph in the XML for displaying it") + parser.add_option("--bioh", "--BIOH", dest='bBIOH', action="store_true", help="Text are categorised along BIOH instead of BIO") + parser.add_option("--text", "--txt", dest='bTxt', action="store_true", help="Use textual features.") + + # --- + #parse the command line + (options, args) = parser.parse_args() + + options.bCutAbove = True # Forcing this! + + if options.bBIOH: + DU_CLASS = DU_ABPTableSkewedRowCutLine_BIOH + else: + DU_CLASS = DU_ABPTableSkewedRowCutLine + + if options.bGraph: + import os.path + # hack + DU_CLASS.bCutAbove = options.bCutAbove + traceln("\t%s.bCutAbove=" % DU_CLASS.__name__, DU_CLASS.bCutAbove) + DU_CLASS.lRadAngle = [math.radians(v) for v in [float(s) for s in options.lsAngle.split(",")]] + traceln("\t%s.lRadAngle=" % DU_CLASS.__name__, DU_CLASS.lRadAngle) + for sInputFilename in args: + sp, sf = os.path.split(sInputFilename) + sOutFilename = os.path.join(sp, "graph-" + sf) + doer = DU_CLASS("debug", "." + , iBlockVisibility=options.iBlockVisibility + , iLineVisibility=options.iLineVisibility + , fCutHeight=options.fCutHeight + , bCutAbove=options.bCutAbove + , lRadAngle=[math.radians(float(s)) for s in options.lsAngle.split(",")] + , bTxt=options.bTxt) + o = doer.cGraphClass() + o.parseDocFile(sInputFilename, 9) + o.addEdgeToDoc() + print('Graph edges added to %s'%sOutFilename) + o.doc.write(sOutFilename, encoding='utf-8',pretty_print=True,xml_declaration=True) + SkewedCutAnnotator.gtStatReport() + exit(0) + + # --- + try: + sModelDir, sModelName = args + except Exception as e: + traceln("Specify a model folder and a model name!") + _exit(usage, 1, e) + + main(DU_CLASS, sModelDir, sModelName, options) diff --git a/TranskribusDU/tasks/case_BAR/DU_BAR.py b/TranskribusDU/tasks/case_BAR/DU_BAR.py new file mode 100644 index 0000000..ed4c150 --- /dev/null +++ b/TranskribusDU/tasks/case_BAR/DU_BAR.py @@ -0,0 +1,120 @@ +# -*- coding: utf-8 -*- + +""" + DU task for BAR - see https://read02.uibk.ac.at/wiki/index.php/Document_Understanding_BAR + + Copyright Xerox(C) 2017 JL Meunier + + + + + Developed for the EU project READ. The READ project has received funding + from the European Union's Horizon 2020 research and innovation programme + under grant agreement No 674943. 
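+    Editor's note (sketch, not in the original file): this module is an
+    abstract driver (see the raise Exception at the bottom); concrete
+    subclasses are run with a model folder and a model name, e.g. a
+    hypothetical invocation (option names inferred from options.lTrn/lTst):
+
+        python DU_BAR_sem.py --trn <colDir> --tst <colDir>  <modelDir> <modelName>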
+ +""" +import sys, os + +try: #to ease the use without proper Python installation + import TranskribusDU_version +except ImportError: + sys.path.append( os.path.dirname(os.path.dirname( os.path.abspath(sys.argv[0]) )) ) + import TranskribusDU_version + +from common.trace import traceln +from tasks import _checkFindColDir, _exit + +from tasks.DU_CRF_Task import DU_CRF_Task + + +def main(DU_BAR): + version = "v.01" + usage, description, parser = DU_CRF_Task.getBasicTrnTstRunOptionParser(sys.argv[0], version) + parser.add_option("--docid", dest='docid', action="store",default=None, help="only process docid") + # --- + #parse the command line + (options, args) = parser.parse_args() + + # --- + try: + sModelDir, sModelName = args + except Exception as e: + traceln("Specify a model folder and a model name!") + _exit(usage, 1, e) + + doer = DU_BAR(sModelName, sModelDir, + C = options.crf_C, + tol = options.crf_tol, + njobs = options.crf_njobs, + max_iter = options.max_iter, + inference_cache = options.crf_inference_cache) + + + if options.docid: + sDocId=options.docid + else: + sDocId=None + if options.rm: + doer.rm() + sys.exit(0) + + lTrn, lTst, lRun, lFold = [_checkFindColDir(lsDir) for lsDir in [options.lTrn, options.lTst, options.lRun, options.lFold]] +# if options.bAnnotate: +# doer.annotateDocument(lTrn) +# traceln('annotation done') +# sys.exit(0) + + ## use. a_mpxml files + doer.sXmlFilenamePattern = doer.sLabeledXmlFilenamePattern + + + if options.iFoldInitNum or options.iFoldRunNum or options.bFoldFinish: + if options.iFoldInitNum: + """ + initialization of a cross-validation + """ + splitter, ts_trn, lFilename_trn = doer._nfold_Init(lFold, options.iFoldInitNum, bStoreOnDisk=True) + elif options.iFoldRunNum: + """ + Run one fold + """ + oReport = doer._nfold_RunFoldFromDisk(options.iFoldRunNum, options.warm) + traceln(oReport) + elif options.bFoldFinish: + tstReport = doer._nfold_Finish() + traceln(tstReport) + else: + assert False, "Internal error" + #no more processing!! 
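+        # Editor's note on the fold protocol above (a sketch inferred from the
+        # three branches, not authoritative):
+        #   1) iFoldInitNum -> _nfold_Init: write the fold file-lists to disk
+        #   2) iFoldRunNum  -> _nfold_RunFoldFromDisk: train/test one fold
+        #      (separate folds can then run in parallel)
+        #   3) bFoldFinish  -> _nfold_Finish: merge the per-fold test reports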
+ exit(0) + #------------------- + + if lFold: + loTstRpt = doer.nfold_Eval(lFold, 3, .25, None, options.pkl) + import graph.GraphModel + sReportPickleFilename = os.path.join(sModelDir, sModelName + "__report.txt") + traceln("Results are in %s"%sReportPickleFilename) + graph.GraphModel.GraphModel.gzip_cPickle_dump(sReportPickleFilename, loTstRpt) + elif lTrn: + doer.train_save_test(lTrn, lTst, options.warm, options.pkl) + try: traceln("Baseline best estimator: %s"%doer.bsln_mdl.best_params_) #for GridSearch + except: pass + traceln(" --- CRF Model ---") + traceln(doer.getModel().getModelInfo()) + elif lTst: + doer.load() + tstReport = doer.test(lTst) + traceln(tstReport) + + if lRun: + if options.storeX or options.applyY: + try: doer.load() + except: pass #we only need the transformer + lsOutputFilename = doer.runForExternalMLMethod(lRun, options.storeX, options.applyY) + else: + doer.load() + lsOutputFilename = doer.predict(lRun) + traceln("Done, see in:\n %s"%lsOutputFilename) + +if __name__ == "__main__": + raise Exception("This is an abstract module.") \ No newline at end of file diff --git a/TranskribusDU/tasks/case_BAR/DU_BAR_ConvertGTAnnotation.py b/TranskribusDU/tasks/case_BAR/DU_BAR_ConvertGTAnnotation.py new file mode 100644 index 0000000..18c8f9f --- /dev/null +++ b/TranskribusDU/tasks/case_BAR/DU_BAR_ConvertGTAnnotation.py @@ -0,0 +1,460 @@ +# -*- coding: utf-8 -*- + +""" + DU task for BAR documents - see https://read02.uibk.ac.at/wiki/index.php/Document_Understanding_BAR + + Here we convert the human annotation into 2 kinds of annotations: + - a semantic one: header, heading, page-number, resolution-marginalia, resolution-number, resolution-paragraph (we ignore Marginalia because only 2 occurences) + - a segmentation one: 2 complementary labels. We call them Heigh Ho. Could have been Yin Yang as well... + - also, we store the resolution number in @DU_num + + These new annotations are stored in @DU_sem , @DU_sgm , @DU_num + + Copyright Naver Labs(C) 2017 JL Meunier + + + + + Developed for the EU project READ. The READ project has received funding + from the European Union's Horizon 2020 research and innovation programme + under grant agreement No 674943. + +""" + + + + +import sys, os, re + +from lxml import etree +try: #to ease the use without proper Python installation + import TranskribusDU_version +except ImportError: + sys.path.append( os.path.dirname(os.path.dirname( os.path.abspath(sys.argv[0]) )) ) + import TranskribusDU_version + + +from xml_formats.PageXml import PageXml, MultiPageXml, PageXmlException +from crf.Graph_MultiPageXml import Graph_MultiPageXml +from util.Polygon import Polygon + + +class DU_BAR_Convert: + """ + Here we convert the human annotation into 2 kinds of annotations: + - a semantic one: header, heading, page-number, resolution-marginalia, resolution-number, resolution-paragraph (we ignore Marginalia because only 2 occurences) + - a segmentation one: 2 complementary labels. We call them Heigh Ho. Could have been Yin Yang as well... + + These new annotations are store in @DU_sem and @DU_sgm + """ + sXml_HumanAnnotation_Extension = ".mpxml" + sXml_MachineAnnotation_Extension = ".du_mpxml" + + sMetadata_Creator = "TranskribusDU/usecases/BAR/DU_ConvertGTAnnotation.py" + sMetadata_Comments = "Converted human annotation into semantic and segmentation annotation. See attributes @DU_sem and @DU_sgm." 
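+    # Editor's illustration of the Heigh/Ho segmentation scheme (not in the
+    # original): consecutive resolutions alternate labels, e.g. text blocks of
+    # resolutions 12,12,13,13,14 get heigh,heigh,ho,ho,heigh, so every label
+    # change marks a resolution boundary.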
+ + dNS = {"pc":PageXml.NS_PAGE_XML} + sxpNode = ".//pc:TextRegion" + + #Name of attributes for semantic / segmentation /resolution number + sSemAttr = "DU_sem" + sSgmAttr = "DU_sgm" + sNumAttr = "DU_num" + + sOther = "other" + + #Mapping to new semantic annotation + dAnnotMapping = {"header" :"header", + "heading" :"heading", + "page-number" :"page-number", + "marginalia" : sOther, + "p" :"resolution-paragraph", + "m" :"resolution-marginalia", + "" :"resolution-number", + None : sOther #for strange things + } + creResolutionHumanLabel = re.compile("([mp]?)([0-9]+.?)") #e.g. p1 m23 456 456a + + #The two complementary segmentation labels + sSegmHeigh = "heigh" + sSegmHo = "ho" + + #=== CONFIGURATION ==================================================================== + def __init__(self): + + pass + + + def convertDoc(self, sFilename): + + assert sFilename.endswith(self.sXml_HumanAnnotation_Extension) + + g = Graph_MultiPageXml() + + doc = etree.parse(sFilename, encoding='utf-8') + + #the Heigh/Ho annotation runs over consecutive pages, so we keep those values accross pages + self._initSegmentationLabel() + self.lSeenResoNum = list() + + for pnum, page, domNdPage in g._iter_Page_DocNode(doc): + self._convertPageAnnotation(pnum, page, domNdPage) + + MultiPageXml.setMetadata(doc, None, self.sMetadata_Creator, self.sMetadata_Comments) + + assert sFilename.endswith(self.sXml_HumanAnnotation_Extension) + + sDUFilename = sFilename[:-len(self.sXml_HumanAnnotation_Extension)] + self.sXml_MachineAnnotation_Extension +# doc.save(sDUFilename, encoding='utf-8', pretty_print=True) + doc.write(sDUFilename, + xml_declaration=True, + encoding="utf-8", + pretty_print=True + #compression=0, #0 to 9 + ) + +# doc.saveFormatFileEnc(sDUFilename, "utf-8", True) #True to indent the XML +# doc.freeDoc() + + return sDUFilename + + # ----------------------------------------------------------------------------------------------------------- + + def _initSegmentationLabel(self): + self.prevResolutionNumber, self.prevSgmLbl = None, None + + def _getNextSegmentationLabel(self, sPrevSegmLabel=None): + """ + alternate beween HEIGH and HO, 1st at random + """ + if sPrevSegmLabel == self.sSegmHeigh: return self.sSegmHo + elif sPrevSegmLabel == self.sSegmHo: return self.sSegmHeigh + else: + assert sPrevSegmLabel == None + return self.sSegmHeigh + + def _iter_TextRegionNodeTop2Bottom(self, domNdPage, page): + """ + Get the DOM, the DOM page node, the page object + + iterator on the DOM, that returns nodes + """ + assert self.sxpNode, "CONFIG ERROR: need an xpath expression to enumerate elements corresponding to graph nodes" + lNdBlock = domNdPage.xpath(self.sxpNode, namespaces=self.dNS) + + #order blocks from top to bottom of page + lOrderedNdBlock = list() + for ndBlock in lNdBlock: + + lXY = PageXml.getPointList(ndBlock) #the polygon + if lXY == []: + raise ValueError("Node %x has invalid coordinates" % str(ndBlock)) + + plg = Polygon(lXY) + _, (xg, yg) = plg.getArea_and_CenterOfMass() + + lOrderedNdBlock.append( (yg, ndBlock)) #we want to order from top to bottom, so that TextRegions of different resolution are not interleaved + + lOrderedNdBlock.sort() + + for _, ndBlock in lOrderedNdBlock: yield ndBlock + + return + + + def _convertPageAnnotation(self, pnum, page, domNdPage): + """ + + """ + + #change: on each page we start by Heigh + bRestartAtEachPageWithHeigh = True + if bRestartAtEachPageWithHeigh: self._initSegmentationLabel() + + for nd in self._iter_TextRegionNodeTop2Bottom(domNdPage, page): + + try: + lbl = 
PageXml.getCustomAttr(nd, "structure", "type") + except PageXmlException: + nd.set(self.sSemAttr, self.sOther) + nd.set(self.sSgmAttr, self.sOther) + continue #this node has no annotation whatsoever + + if lbl in ["heading", "header", "page-number", "marginalia"]: + semLabel = lbl + sgmLabel = self.sOther #those elements are not part of a resolution + sResoNum = None + else: + o = self.creResolutionHumanLabel.match(lbl) + if not o: raise ValueError("%s is not a valid human annotation" % lbl) + semLabel = o.group(1) #"" for the resolution number + + #now decide on the segmentation label + sResoNum = o.group(2) + if not sResoNum: raise ValueError("%s is not a valid human annotation - missing resolution number" % lbl) + + #now switch between heigh and ho !! :)) + if self.prevResolutionNumber == sResoNum: + sgmLabel = self.prevSgmLbl + else: + sgmLabel = self._getNextSegmentationLabel(self.prevSgmLbl) + assert bRestartAtEachPageWithHeigh or sResoNum not in self.lSeenResoNum, "ERROR: the ordering of the block has not preserved resolution number contiguity" + self.lSeenResoNum.append(sResoNum) + + self.prevResolutionNumber, self.prevSgmLbl = sResoNum, sgmLabel + + + #always have a semantic label + sNewSemLbl = self.dAnnotMapping[semLabel] + assert sNewSemLbl + nd.set(self.sSemAttr, sNewSemLbl) #DU annotation + + #resolution parts also have a segmentation label and a resolution number + assert sgmLabel + nd.set(self.sSgmAttr, sgmLabel) #DU annotation + + if sResoNum: + nd.set(self.sNumAttr, sResoNum) + +class DU_BAR_Convert_v2(DU_BAR_Convert): + """ + For segmentation labels, we only use 'Heigh' or 'Ho' whatever the semantic label is, so that the task is purely a segmentation task. + + Heading indicate the start of a resolution, and is part of it. + Anything else (Header page-number, marginalia) is part of the resolution. + + """ + + def _initSegmentationLabel(self): + self.prevResolutionNumber = None + self._curSgmLbl = None + + def _switchSegmentationLabel(self): + """ + alternate beween HEIGH and HO, 1st is Heigh + """ + if self._curSgmLbl == None: + self._curSgmLbl = self.sSegmHeigh + else: + self._curSgmLbl = self.sSegmHeigh if self._curSgmLbl == self.sSegmHo else self.sSegmHo + return self._curSgmLbl + + def _getCurrentSegmentationLabel(self): + """ + self.curSgmLbl or Heigh if not yet set! + """ + if self._curSgmLbl == None: self._curSgmLbl = self.sSegmHeigh + return self._curSgmLbl + + def _convertPageAnnotation(self, pnum, page, domNdPage): + """ + + """ + for nd in self._iter_TextRegionNodeTop2Bottom(domNdPage, page): + + try: + sResoNum = None + lbl = PageXml.getCustomAttr(nd, "structure", "type") + + if lbl in ["heading"]: + semLabel = self.dAnnotMapping[lbl] + #heading may indicate a new resolution! + if self.prevResolutionNumber == None: + sgmLabel = self._getCurrentSegmentationLabel() #for instance 2 consecutive headings + else: + sgmLabel = self._switchSegmentationLabel() + self.prevResolutionNumber = None #so that next number does not switch Heigh/Ho label + elif lbl in ["header", "page-number", "marginalia"]: + #continuation of a resolution + semLabel = self.dAnnotMapping[lbl] + sgmLabel = self._getCurrentSegmentationLabel() + else: + o = self.creResolutionHumanLabel.match(lbl) + if not o: raise ValueError("%s is not a valid human annotation" % lbl) + semLabel = self.dAnnotMapping[o.group(1)] #"" for the resolution number + + #Here we have a resolution number! 
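+                    # Editor's note: example matches of creResolutionHumanLabel
+                    # (consistent with test_RE at the end of this file):
+                    #   'm103a' -> ('m', '103a'),  '103' -> ('', '103'),
+                    #   'az103a' -> no match.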
+ sResoNum = o.group(2) + if not sResoNum: raise ValueError("%s is not a valid human annotation - missing resolution number" % lbl) + + #now switch between heigh and ho !! :)) + if self.prevResolutionNumber != None and self.prevResolutionNumber != sResoNum: + #we got a new number, so switching segmentation label! + sgmLabel = self._switchSegmentationLabel() + else: + #either same number or switching already done due to a heading + sgmLabel = self._getCurrentSegmentationLabel() + + self.prevResolutionNumber = sResoNum + + except PageXmlException: + semLabel = self.sOther + sgmLabel = self._getCurrentSegmentationLabel() + + nd.set(self.sSemAttr, semLabel) + nd.set(self.sSgmAttr, sgmLabel) + if sResoNum: + nd.set(self.sNumAttr, sResoNum) #only when the number is part of the humanannotation! + + +class DU_BAR_Convert_BIES(DU_BAR_Convert): + """ + For segmentation labels, we only use B I E S whatever the semantic label is, so that the task is purely a segmentation task. + + Heading indicate the start of a resolution, and is part of it. + Anything else (Header page-number, marginalia) is part of the resolution. + + """ + B = "B" + I = "I" + E = "E" + S = "S" + + def _initSegmentationLabel(self): + self._prevNd = None + self._prevNum = False + self._prevIsB = None + def _convertPageAnnotation(self, pnum, page, domNdPage): + """ + + """ + for nd in self._iter_TextRegionNodeTop2Bottom(domNdPage, page): + sResoNum = None + bCurrentIsAStart = None + try: + lbl = PageXml.getCustomAttr(nd, "structure", "type") + + if lbl == "heading": + semLabel = self.dAnnotMapping[lbl] + #heading indicate the start of a new resolution, unless the previous is already a start! + if self._prevIsB: + bCurrentIsAStart = False + else: + bCurrentIsAStart = True + self._prevNum = False #to prevent starting again when find the resolution number + elif lbl in ["header", "page-number", "marginalia"]: + semLabel = self.dAnnotMapping[lbl] + #continuation of a resolution, except at very beginning (first node) + if self._prevNd == None: + bCurrentIsAStart = True + else: + bCurrentIsAStart = False + else: + o = self.creResolutionHumanLabel.match(lbl) + if not o: + + if False: # strict + raise ValueError("%s is not a valid human annotation" % lbl) + else: + # relaxed + print(" ** WARNING ** strange annotation on node id=%s : '%s'"%(nd.get("id"), lbl)) + semLabel = self.dAnnotMapping[None] + #Here we have a resolution number! + sResoNum = self._prevNum + else: + semLabel = self.dAnnotMapping[o.group(1)] #"" for the resolution number + + #Here we have a resolution number! + sResoNum = o.group(2) + if not sResoNum: raise ValueError("%s is not a valid human annotation - missing resolution number" % lbl) + + if self._prevNum != False and self._prevNum != sResoNum: + #we got a new number, so switching segmentation label! + bCurrentIsAStart = True + else: + #either same number or switching already done due to a heading + bCurrentIsAStart = False + self._prevNum = sResoNum + + + except PageXmlException: + semLabel = self.sOther + bCurrentIsAStart = False + + #Now tagging!! + #Semantic (easy) + nd.set(self.sSemAttr, semLabel) + + # BIES, tough... + if bCurrentIsAStart: + if self._prevIsB: + #make previous a singleton! + if self._prevNd: self._prevNd.set(self.sSgmAttr, self.S) + else: + #make previous a End + if self._prevNd: self._prevNd.set(self.sSgmAttr, self.E) + self._prevIsB = True #for next cycle! 
+ else: + if self._prevIsB: + #confirm previous a a B + if self._prevNd: self._prevNd.set(self.sSgmAttr, self.B) + else: + #confirm previous as a I + if self._prevNd: self._prevNd.set(self.sSgmAttr, self.I) + self._prevIsB = False #for next cycle! + + if sResoNum: nd.set(self.sNumAttr, sResoNum) #only when the number is part of the humanannotation! + self._prevNd = nd #for next cycle! + # end for + + if self._prevIsB: + #make previous a singleton! + if self._prevNd: self._prevNd.set(self.sSgmAttr, self.S) + else: + #make previous a End + if self._prevNd: self._prevNd.set(self.sSgmAttr, self.E) + return + + +#------------------------------------------------------------------------------------------------------ +def test_RE(): + cre = DU_BAR_Convert.creResolutionHumanLabel + + o = cre.match("m103a") + assert o.group(1) == 'm' + assert o.group(2) == '103a' + + o = cre.match("103a") + assert o.group(1) == '' + assert o.group(2) == '103a' + + o = cre.match("103") + assert o.group(1) == '' + assert o.group(2) == '103' + + o = cre.match("az103a") + assert o == None + + +#------------------------------------------------------------------------------------------------------ + + +if __name__ == "__main__": + from optparse import OptionParser + + #prepare for the parsing of the command line + parser = OptionParser(usage="BAR annotation conversion", version="1.0") + +# parser.add_option("--tst", dest='lTst', action="append", type="string" +# , help="Test a model using the given annotated collection.") +# parser.add_option("--fold-init", dest='iFoldInitNum', action="store", type="int" +# , help="Initialize the file lists for parallel cross-validating a model on the given annotated collection. Indicate the number of folds.") +# parser.add_option("--jgjhg", dest='bFoldFinish', action="store_true" +# , help="Evaluate by cross-validation a model on the given annotated collection.") +# parser.add_option("-w", "--warm", dest='warm', action="store_true" +# , help="To make warm-startable model and warm-start if a model exist already.") + + #parse the command line + (options, args) = parser.parse_args() + + # --- + #doer = DU_BAR_Convert() + #doer = DU_BAR_Convert_v2() + doer = DU_BAR_Convert_BIES() + for sFilename in args: + print ("- Processing %s" % sFilename) + sOutputFilename = doer.convertDoc(sFilename) + print (" done --> %s" % sOutputFilename) + + print ("DONE.") + diff --git a/TranskribusDU/tasks/case_BAR/DU_BAR_sem.py b/TranskribusDU/tasks/case_BAR/DU_BAR_sem.py new file mode 100644 index 0000000..c471308 --- /dev/null +++ b/TranskribusDU/tasks/case_BAR/DU_BAR_sem.py @@ -0,0 +1,645 @@ +# -*- coding: utf-8 -*- + +""" + DU task for BAR - see https://read02.uibk.ac.at/wiki/index.php/Document_Understanding_BAR + + Copyright Xerox(C) 2017 JL Meunier + + + + + Developed for the EU project READ. The READ project has received funding + from the European Union's Horizon 2020 research and innovation programme + under grant agreement No 674943. 
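+    Editor's note (sketch of intended use; --ecn/--gat and their config
+    options are declared in the __main__ section at the bottom of this file,
+    --trn is inferred from options.lTrn):
+
+        python DU_BAR_sem.py --trn <colDir>                <modelDir> <modelName>   # CRF
+        python DU_BAR_sem.py --ecn --ecn_config conf.json  <modelDir> <modelName>   # ECN
+        python DU_BAR_sem.py --gat --gat_config conf.json  <modelDir> <modelName>   # GAT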
+ +""" + +import sys, os + +import json + +try: #to ease the use without proper Python installation + import TranskribusDU_version +except ImportError: + sys.path.append( os.path.dirname(os.path.dirname( os.path.abspath(sys.argv[0]) )) ) + import TranskribusDU_version + +from common.trace import traceln + +from crf.Graph_MultiPageXml import Graph_MultiPageXml +from crf.Graph_Multi_SinglePageXml import Graph_MultiSinglePageXml +from crf.NodeType_PageXml import NodeType_PageXml_type_woText, NodeType_PageXml_type +from tasks.DU_CRF_Task import DU_CRF_Task +from crf.FeatureDefinition_PageXml_std import FeatureDefinition_PageXml_StandardOnes +from graph.FeatureDefinition_PageXml_std_noText import FeatureDefinition_PageXml_StandardOnes_noText +from tasks import _checkFindColDir, _exit + + +from crf.FeatureDefinition_PageXml_std_noText import FeatureDefinition_PageXml_StandardOnes_noText + +from gcn.DU_Model_ECN import DU_Model_GAT + + +from tasks.DU_BAR import main as m + +class DU_BAR_sem(DU_CRF_Task): + """ + We will do a typed CRF model for a DU task + , with the below labels + """ + sLabeledXmlFilenamePattern = "*.mpxml" #"*.bar_mpxml" + + bHTR = False # do we have text from an HTR? + bPerPage = False # do we work per document or per page? + bTextLine = True # if False then act as TextRegion + + #=== CONFIGURATION ==================================================================== + @classmethod + def getConfiguredGraphClass(cls): + """ + In this class method, we must return a configured graph class + """ + #DEFINING THE CLASS OF GRAPH WE USE + if cls.bPerPage: + DU_GRAPH = Graph_MultiSinglePageXml # consider each age as if indep from each other + else: + DU_GRAPH = Graph_MultiPageXml + + #lLabels1 = ['heading', 'header', 'page-number', 'resolution-number', 'resolution-marginalia', 'resolution-paragraph', 'other'] + + lLabels1 = ['IGNORE', '577', '579', '581', '608', '32', '3431', '617', '3462', '3484', '615', '49', '3425', '73', '3', '3450', '2', '11', '70', '3451', '637', '77', '3447', '3476', '3467', '3494', '3493', '3461', '3434', '48', '3456', '35', '3482', '74', '3488', '3430', '17', '613', '625', '3427', '3498', '29', '3483', '3490', '362', '638a', '57', '616', '3492', '10', '630', '24', '3455', '3435', '8', '15', '3499', '27', '3478', '638b', '22', '3469', '3433', '3496', '624', '59', '622', '75', '640', '1', '19', '642', '16', '25', '3445', '3463', '3443', '3439', '3436', '3479', '71', '3473', '28', '39', '361', '65', '3497', '578', '72', '634', '3446', '627', '43', '62', '34', '620', '76', '23', '68', '631', '54', '3500', '3480', '37', '3440', '619', '44', '3466', '30', '3487', '45', '61', '3452', '3491', '623', '633', '53', '66', '67', '69', '643', '58', '632', '636', '7', '641', '51', '3489', '3471', '21', '36', '3468', '4', '576', '46', '63', '3457', '56', '3448', '3441', '618', '52', '3429', '3438', '610', '26', '609', '3444', '612', '3485', '3465', '41', '20', '3464', '3477', '3459', '621', '3432', '60', '3449', '626', '628', '614', '47', '3454', '38', '3428', '33', '12', '3426', '3442', '3472', '13', '639', '3470', '611', '6', '40', '14', '3486', '31', '3458', '3437', '3453', '55', '3424', '3481', '635', '64', '629', '3460', '50', '9', '18', '42', '3495', '5', '580'] + + + #the converter changed to other unlabelled TextRegions or 'marginalia' TRs + lIgnoredLabels1 = None + + """ + if you play with a toy collection, which does not have all expected classes, you can reduce those. 
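+        Editor's note (interpretation, hedged): with bTextLine=True the XPath
+        below restricts the graph to .//pc:TextRegion/pc:TextLine[@DU_num],
+        i.e. to text lines that the GT converter actually tagged with a
+        resolution number; lLabels1 is the corresponding inventory of @DU_num
+        values plus 'IGNORE'.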
+ """ + +# lActuallySeen = None +# if lActuallySeen: +# print( "REDUCING THE CLASSES TO THOSE SEEN IN TRAINING") +# lIgnoredLabels = [lLabels[i] for i in range(len(lLabels)) if i not in lActuallySeen] +# lLabels = [lLabels[i] for i in lActuallySeen ] +# print( len(lLabels) , lLabels) +# print( len(lIgnoredLabels) , lIgnoredLabels) + if cls.bHTR: + ntClass = NodeType_PageXml_type + else: + #ignore text + ntClass = NodeType_PageXml_type_woText + + nt1 = ntClass("bar" #some short prefix because labels below are prefixed with it + , lLabels1 + , lIgnoredLabels1 + , False #no label means OTHER + , BBoxDeltaFun=lambda v: max(v * 0.066, min(5, v/3)) #we reduce overlap in this way + ) + nt1.setLabelAttribute("DU_num") + if cls.bTextLine: + nt1.setXpathExpr( (".//pc:TextRegion/pc:TextLine[@DU_num]" #how to find the nodes + , "./pc:TextEquiv") + ) + else: + nt1.setXpathExpr( (".//pc:TextRegion" #how to find the nodes + , "./pc:TextEquiv") #how to get their text + ) + DU_GRAPH.addNodeType(nt1) + + return DU_GRAPH + + + # =============================================================================================================== + + + + # """ + # if you play with a toy collection, which does not have all expected classes, you can reduce those. + # """ + # + # lActuallySeen = None + # if lActuallySeen: + # print "REDUCING THE CLASSES TO THOSE SEEN IN TRAINING" + # lIgnoredLabels = [lLabels[i] for i in range(len(lLabels)) if i not in lActuallySeen] + # lLabels = [lLabels[i] for i in lActuallySeen ] + # print len(lLabels) , lLabels + # print len(lIgnoredLabels) , lIgnoredLabels + # nbClass = len(lLabels) + 1 #because the ignored labels will become OTHER + + + + #=== CONFIGURATION ==================================================================== + def __init__(self, sModelName, sModelDir, sComment=None, C=None, tol=None, njobs=None, max_iter=None, inference_cache=None): + + if self.bHTR: + cFeatureDefinition = FeatureDefinition_PageXml_StandardOnes + dFeatureConfig = { + 'n_tfidf_node':100, 't_ngrams_node':(1,2), 'b_tfidf_node_lc':False + , 'n_tfidf_edge':100, 't_ngrams_edge':(1,2), 'b_tfidf_edge_lc':False } + else: + cFeatureDefinition = FeatureDefinition_PageXml_StandardOnes_noText + dFeatureConfig = { } + #'n_tfidf_node':None, 't_ngrams_node':None, 'b_tfidf_node_lc':None + #, 'n_tfidf_edge':None, 't_ngrams_edge':None, 'b_tfidf_edge_lc':None } + + DU_CRF_Task.__init__(self + , sModelName, sModelDir + , dFeatureConfig = dFeatureConfig + , dLearnerConfig = { + 'C' : .1 if C is None else C + , 'njobs' : 16 if njobs is None else njobs + , 'inference_cache' : 50 if inference_cache is None else inference_cache + #, 'tol' : .1 + , 'tol' : .05 if tol is None else tol + , 'save_every' : 50 #save every 50 iterations,for warm start + , 'max_iter' : 1000 if max_iter is None else max_iter + } + , sComment=sComment + , cFeatureDefinition=cFeatureDefinition +# , cFeatureDefinition=FeatureDefinition_T_PageXml_StandardOnes_noText +# , dFeatureConfig = { +# #config for the extractor of nodes of each type +# "text": None, +# "sprtr": None, +# #config for the extractor of edges of each type +# "text_text": None, +# "text_sprtr": None, +# "sprtr_text": None, +# "sprtr_sprtr": None +# } + ) + + traceln("- classes: ", self.getGraphClass().getLabelNameList()) + + self.bsln_mdl = self.addBaseline_LogisticRegression() #use a LR model trained by GridSearch as baseline + + #=== END OF CONFIGURATION ============================================================= + + + def predict(self, lsColDir):#,sDocId): + """ + Return the 
list of produced files + """ +# self.sXmlFilenamePattern = "*.a_mpxml" + return DU_CRF_Task.predict(self, lsColDir)#,sDocId) + + + def runForExternalMLMethod(self, lsColDir, storeX, applyY, bRevertEdges=False): + """ + Return the list of produced files + """ + self.sXmlFilenamePattern = "*.mpxml" + return DU_CRF_Task.runForExternalMLMethod(self, lsColDir, storeX, applyY, bRevertEdges) + + + + +from tasks.DU_ECN_Task import DU_ECN_Task +import gcn.DU_Model_ECN +class DU_ABPTable_ECN(DU_ECN_Task): + """ + ECN Models + """ + bHTR = False # do we have text from an HTR? + bPerPage = False # do we work per document or per page? + bTextLine = True # if False then act as TextRegion + + sMetadata_Creator = "NLE Document Understanding ECN" + sXmlFilenamePattern = "*.mpxml" + + # sLabeledXmlFilenamePattern = "*.a_mpxml" + sLabeledXmlFilenamePattern = "*.mpxml" + + + sLabeledXmlFilenameEXT = ".mpxml" + + dLearnerConfig = None + + #dLearnerConfig = {'nb_iter': 50, + # 'lr': 0.001, + # 'num_layers': 3, + # 'nconv_edge': 10, + # 'stack_convolutions': True, + # 'node_indim': -1, + # 'mu': 0.0, + # 'dropout_rate_edge': 0.0, + # 'dropout_rate_edge_feat': 0.0, + # 'dropout_rate_node': 0.0, + # 'ratio_train_val': 0.15, + # #'activation': tf.nn.tanh, Problem I can not serialize function HERE + # } + # === CONFIGURATION ==================================================================== + @classmethod + def getConfiguredGraphClass(cls): + """ + In this class method, we must return a configured graph class + """ + #lLabels = ['heading', 'header', 'page-number', 'resolution-number', 'resolution-marginalia', 'resolution-paragraph', 'other'] + + lLabels = ['IGNORE', '577', '579', '581', '608', '32', '3431', '617', '3462', '3484', '615', '49', '3425', '73', '3', '3450', '2', '11', '70', '3451', '637', '77', '3447', '3476', '3467', '3494', '3493', '3461', '3434', '48', '3456', '35', '3482', '74', '3488', '3430', '17', '613', '625', '3427', '3498', '29', '3483', '3490', '362', '638a', '57', '616', '3492', '10', '630', '24', '3455', '3435', '8', '15', '3499', '27', '3478', '638b', '22', '3469', '3433', '3496', '624', '59', '622', '75', '640', '1', '19', '642', '16', '25', '3445', '3463', '3443', '3439', '3436', '3479', '71', '3473', '28', '39', '361', '65', '3497', '578', '72', '634', '3446', '627', '43', '62', '34', '620', '76', '23', '68', '631', '54', '3500', '3480', '37', '3440', '619', '44', '3466', '30', '3487', '45', '61', '3452', '3491', '623', '633', '53', '66', '67', '69', '643', '58', '632', '636', '7', '641', '51', '3489', '3471', '21', '36', '3468', '4', '576', '46', '63', '3457', '56', '3448', '3441', '618', '52', '3429', '3438', '610', '26', '609', '3444', '612', '3485', '3465', '41', '20', '3464', '3477', '3459', '621', '3432', '60', '3449', '626', '628', '614', '47', '3454', '38', '3428', '33', '12', '3426', '3442', '3472', '13', '639', '3470', '611', '6', '40', '14', '3486', '31', '3458', '3437', '3453', '55', '3424', '3481', '635', '64', '629', '3460', '50', '9', '18', '42', '3495', '5', '580'] + + + lIgnoredLabels = None + + """ + if you play with a toy collection, which does not have all expected classes, you can reduce those. 
+ """ + if cls.bPerPage: + DU_GRAPH = Graph_MultiSinglePageXml # consider each age as if indep from each other + else: + DU_GRAPH = Graph_MultiPageXml + + + + lActuallySeen = None + if lActuallySeen: + print("REDUCING THE CLASSES TO THOSE SEEN IN TRAINING") + lIgnoredLabels = [lLabels[i] for i in range(len(lLabels)) if i not in lActuallySeen] + lLabels = [lLabels[i] for i in lActuallySeen] + print(len(lLabels), lLabels) + print(len(lIgnoredLabels), lIgnoredLabels) + + if cls.bHTR: + ntClass = NodeType_PageXml_type + else: + #ignore text + ntClass = NodeType_PageXml_type_woText + + + + # DEFINING THE CLASS OF GRAPH WE USE + nt = ntClass("bar" # some short prefix because labels below are prefixed with it + , lLabels + , lIgnoredLabels + , False # no label means OTHER + , BBoxDeltaFun=lambda v: max(v * 0.066, min(5, v / 3)) + # we reduce overlap in this way + ) + + + + nt.setLabelAttribute("DU_num") + if cls.bTextLine: + nt.setXpathExpr( (".//pc:TextRegion/pc:TextLine[@DU_num]" #how to find the nodes + , "./pc:TextEquiv") + ) + else: + nt.setXpathExpr( (".//pc:TextRegion" #how to find the nodes + , "./pc:TextEquiv") #how to get their text + ) + + + DU_GRAPH.addNodeType(nt) + + return DU_GRAPH + + def __init__(self, sModelName, sModelDir, sComment=None,dLearnerConfigArg=None): + print ( self.bHTR) + + if self.bHTR: + cFeatureDefinition = FeatureDefinition_PageXml_StandardOnes + dFeatureConfig = { 'bMultiPage':False, 'bMirrorPage':False + , 'n_tfidf_node':300, 't_ngrams_node':(1,4), 'b_tfidf_node_lc':False + , 'n_tfidf_edge':300, 't_ngrams_edge':(1,4), 'b_tfidf_edge_lc':False } + else: + cFeatureDefinition = FeatureDefinition_PageXml_StandardOnes_noText + dFeatureConfig = { } + + + if sComment is None: sComment = sModelName + + + if dLearnerConfigArg is not None and "ecn_ensemble" in dLearnerConfigArg: + print('ECN_ENSEMBLE') + DU_ECN_Task.__init__(self + , sModelName, sModelDir + , dFeatureConfig=dFeatureConfig + , + dLearnerConfig=dLearnerConfigArg if dLearnerConfigArg is not None else self.dLearnerConfig + , sComment=sComment + , cFeatureDefinition= cFeatureDefinition + , cModelClass=gcn.DU_Model_ECN.DU_Ensemble_ECN + ) + + + else: + #Default Case Single Model + DU_ECN_Task.__init__(self + , sModelName, sModelDir + , dFeatureConfig=dFeatureConfig + , dLearnerConfig= dLearnerConfigArg if dLearnerConfigArg is not None else self.dLearnerConfig + , sComment= sComment + , cFeatureDefinition=cFeatureDefinition + ) + + #if options.bBaseline: + # self.bsln_mdl = self.addBaseline_LogisticRegression() # use a LR model trained by GridSearch as baseline + + # === END OF CONFIGURATION ============================================================= + def predict(self, lsColDir): + """ + Return the list of produced files + """ + self.sXmlFilenamePattern = "*.mpxml" + return DU_ECN_Task.predict(self, lsColDir) + + + + +class DU_ABPTable_GAT(DU_ECN_Task): + """ + ECN Models + """ + bHTR = True # do we have text from an HTR? + bPerPage = True # do we work per document or per page? 
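+    # Editor's note (making an implicit setting explicit): unlike the ECN task
+    # above, this GAT task sets bHTR=True, so __init__ below selects
+    # FeatureDefinition_PageXml_StandardOnes and extracts textual tf-idf
+    # features; set bHTR=False to fall back to the layout-only features.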
+ bTextLine = True # if False then act as TextRegion + + sMetadata_Creator = "NLE Document Understanding GAT" + + + sXmlFilenamePattern = "*.bar_mpxml" + + # sLabeledXmlFilenamePattern = "*.a_mpxml" + sLabeledXmlFilenamePattern = "*.bar_mpxml" + + sLabeledXmlFilenameEXT = ".bar_mpxml" + + + dLearnerConfigOriginalGAT ={ + 'nb_iter': 500, + 'lr': 0.001, + 'num_layers': 2,#2 Train Acc is lower 5 overfit both reach 81% accuracy on Fold-1 + 'nb_attention': 5, + 'stack_convolutions': True, + # 'node_indim': 50 , worked well 0.82 + 'node_indim': -1, + 'dropout_rate_node': 0.0, + 'dropout_rate_attention': 0.0, + 'ratio_train_val': 0.15, + "activation_name": 'tanh', + "patience": 50, + "mu": 0.00001, + "original_model" : True + + } + + + dLearnerConfigNewGAT = {'nb_iter': 500, + 'lr': 0.001, + 'num_layers': 5, + 'nb_attention': 5, + 'stack_convolutions': True, + 'node_indim': -1, + 'dropout_rate_node': 0.0, + 'dropout_rate_attention' : 0.0, + 'ratio_train_val': 0.15, + "activation_name": 'tanh', + "patience":50, + "original_model": False, + "attn_type":0 + } + dLearnerConfig = dLearnerConfigNewGAT + #dLearnerConfig = dLearnerConfigOriginalGAT + # === CONFIGURATION ==================================================================== + @classmethod + def getConfiguredGraphClass(cls): + """ + In this class method, we must return a configured graph class + """ + lLabels = ['heading', 'header', 'page-number', 'resolution-number', 'resolution-marginalia', 'resolution-paragraph', 'other'] + + lIgnoredLabels = None + + """ + if you play with a toy collection, which does not have all expected classes, you can reduce those. + """ + + lActuallySeen = None + if lActuallySeen: + print("REDUCING THE CLASSES TO THOSE SEEN IN TRAINING") + lIgnoredLabels = [lLabels[i] for i in range(len(lLabels)) if i not in lActuallySeen] + lLabels = [lLabels[i] for i in lActuallySeen] + print(len(lLabels), lLabels) + print(len(lIgnoredLabels), lIgnoredLabels) + + + # DEFINING THE CLASS OF GRAPH WE USE + if cls.bPerPage: + DU_GRAPH = Graph_MultiSinglePageXml # consider each age as if indep from each other + else: + DU_GRAPH = Graph_MultiPageXml + + if cls.bHTR: + ntClass = NodeType_PageXml_type + else: + #ignore text + ntClass = NodeType_PageXml_type_woText + + + nt = ntClass("bar" # some short prefix because labels below are prefixed with it + , lLabels + , lIgnoredLabels + , False # no label means OTHER + , BBoxDeltaFun=lambda v: max(v * 0.066, min(5, v / 3)) + # we reduce overlap in this way + ) + nt.setLabelAttribute("DU_sem") + if cls.bTextLine: + nt.setXpathExpr( (".//pc:TextRegion/pc:TextLine" #how to find the nodes + , "./pc:TextEquiv") + ) + else: + nt.setXpathExpr( (".//pc:TextRegion" #how to find the nodes + , "./pc:TextEquiv") #how to get their text + ) + + + DU_GRAPH.addNodeType(nt) + + return DU_GRAPH + + def __init__(self, sModelName, sModelDir, sComment=None,dLearnerConfigArg=None): + if self.bHTR: + cFeatureDefinition = FeatureDefinition_PageXml_StandardOnes + dFeatureConfig = { 'bMultiPage':False, 'bMirrorPage':False + , 'n_tfidf_node':500, 't_ngrams_node':(2,4), 'b_tfidf_node_lc':False + , 'n_tfidf_edge':250, 't_ngrams_edge':(2,4), 'b_tfidf_edge_lc':False } + else: + cFeatureDefinition = FeatureDefinition_PageXml_StandardOnes_noText + dFeatureConfig = { 'bMultiPage':False, 'bMirrorPage':False + , 'n_tfidf_node':None, 't_ngrams_node':None, 'b_tfidf_node_lc':None + , 'n_tfidf_edge':None, 't_ngrams_edge':None, 'b_tfidf_edge_lc':None } + + + if sComment is None: sComment = sModelName + + + DU_ECN_Task.__init__(self 
+                     , sModelName, sModelDir
+                     , dFeatureConfig=dFeatureConfig
+                     , dLearnerConfig= dLearnerConfigArg if dLearnerConfigArg is not None else self.dLearnerConfig
+                     , sComment=sComment
+                     , cFeatureDefinition=cFeatureDefinition
+                     , cModelClass=DU_Model_GAT
+                     )
+
+        if options.bBaseline:
+            self.bsln_mdl = self.addBaseline_LogisticRegression()    #use a LR model trained by GridSearch as baseline
+
+    # === END OF CONFIGURATION =============================================================
+    def predict(self, lsColDir):
+        """
+        Return the list of produced files
+        """
+        self.sXmlFilenamePattern = "*.bar_mpxml"
+        return DU_ECN_Task.predict(self, lsColDir)
+
+
+# ----------------------------------------------------------------------------
+
+def main(sModelDir, sModelName, options):
+    if options.use_ecn:
+        if options.ecn_json_config:
+            with open(options.ecn_json_config[0]) as f:
+                djson = json.loads(f.read())
+            if "ecn_learner_config" in djson:
+                dLearnerConfig = djson["ecn_learner_config"]
+                doer = DU_ABPTable_ECN(sModelName, sModelDir, dLearnerConfigArg=dLearnerConfig)
+            elif "ecn_ensemble" in djson:
+                dLearnerConfig = djson
+                doer = DU_ABPTable_ECN(sModelName, sModelDir, dLearnerConfigArg=dLearnerConfig)
+        else:
+            doer = DU_ABPTable_ECN(sModelName, sModelDir)
+    elif options.use_gat:
+        if options.gat_json_config:
+            with open(options.gat_json_config[0]) as f:
+                djson = json.loads(f.read())
+            dLearnerConfig = djson["gat_learner_config"]
+            doer = DU_ABPTable_GAT(sModelName, sModelDir, dLearnerConfigArg=dLearnerConfig)
+        else:
+            doer = DU_ABPTable_GAT(sModelName, sModelDir)
+    else:
+        doer = m(DU_BAR_sem)
+
+    if options.rm:
+        doer.rm()
+        return
+
+    lTrn, lTst, lRun, lFold = [_checkFindColDir(lsDir) for lsDir in [options.lTrn, options.lTst, options.lRun, options.lFold]]
+
+    traceln("- classes: ", doer.getGraphClass().getLabelNameList())
+
+    ## use the labeled-file pattern (e.g. *.bar_mpxml)
+    doer.sXmlFilenamePattern = doer.sLabeledXmlFilenamePattern
+
+    if options.iFoldInitNum or options.iFoldRunNum or options.bFoldFinish:
+        if options.iFoldInitNum:
+            """
+            initialization of a cross-validation
+            """
+            splitter, ts_trn, lFilename_trn = doer._nfold_Init(lFold, options.iFoldInitNum, test_size=0.25, random_state=None, bStoreOnDisk=True)
+        elif options.iFoldRunNum:
+            """
+            Run one fold
+            """
+            oReport = doer._nfold_RunFoldFromDisk(options.iFoldRunNum, options.warm, options.pkl)
+            traceln(oReport)
+        elif options.bFoldFinish:
+            tstReport = doer._nfold_Finish()
+            traceln(tstReport)
+        else:
+            assert False, "Internal error"
+        #no more processing!!
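+        # The three fold options implement a disk-based cross-validation
+        # protocol, one step per process invocation (a sketch; the exact
+        # command-line flag names are defined in getBasicTrnTstRunOptionParser):
+        #   1. iFoldInitNum=N : _nfold_Init splits the collection into N folds and stores them on disk
+        #   2. iFoldRunNum=i  : _nfold_RunFoldFromDisk trains and tests fold i and prints its report
+        #   3. bFoldFinish    : _nfold_Finish aggregates the per-fold reports
+        # hence the unconditional exit(0) below once any fold step has run.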
+        exit(0)
+        #-------------------
+
+    if lFold:
+        loTstRpt = doer.nfold_Eval(lFold, 3, .25, None, options.pkl)
+        import graph.GraphModel
+        sReportPickleFilename = os.path.join(sModelDir, sModelName + "__report.txt")
+        traceln("Results are in %s"%sReportPickleFilename)
+        graph.GraphModel.GraphModel.gzip_cPickle_dump(sReportPickleFilename, loTstRpt)
+    elif lTrn:
+        doer.train_save_test(lTrn, lTst, options.warm, options.pkl)
+        try:    traceln("Baseline best estimator: %s"%doer.bsln_mdl.best_params_)   #for GridSearch
+        except: pass
+        traceln(" --- CRF Model ---")
+        traceln(doer.getModel().getModelInfo())
+    elif lTst:
+        doer.load()
+        tstReport = doer.test(lTst)
+        traceln(tstReport)
+        if options.bDetailedReport:
+            traceln(tstReport.getDetailledReport())
+            import graph.GraphModel
+            for test in lTst:
+                sReportPickleFilename = os.path.join('..', test, sModelName + "__report.pkl")
+                traceln('Report dumped into %s'%sReportPickleFilename)
+                graph.GraphModel.GraphModel.gzip_cPickle_dump(sReportPickleFilename, tstReport)
+
+    if lRun:
+        if options.storeX or options.applyY:
+            try:    doer.load()
+            except: pass    #we only need the transformer
+            lsOutputFilename = doer.runForExternalMLMethod(lRun, options.storeX, options.applyY, options.bRevertEdges)
+        else:
+            doer.load()
+            lsOutputFilename = doer.predict(lRun)
+
+        traceln("Done, see in:\n  %s"%lsOutputFilename)
+
+
+# ----------------------------------------------------------------------------
+
+if __name__ == "__main__":
+
+    version = "v.01"
+    usage, description, parser = DU_CRF_Task.getBasicTrnTstRunOptionParser(sys.argv[0], version)
+#    parser.add_option("--annotate", dest='bAnnotate', action="store_true", default=False, help="Annotate the textlines with BIES labels")
+
+    #FOR GCN
+    parser.add_option("--revertEdges", dest='bRevertEdges', action="store_true", help="Revert the direction of the edges")
+    parser.add_option("--detail", dest='bDetailedReport', action="store_true", default=False, help="Display detailed reporting (score per document)")
+    parser.add_option("--baseline", dest='bBaseline', action="store_true", default=False, help="report baseline method")
+    parser.add_option("--ecn", dest='use_ecn', action="store_true", default=False, help="whether to use ECN models")
+    parser.add_option("--ecn_config", dest='ecn_json_config', action="append", type="string", help="The config files for the ECN model")
+    parser.add_option("--gat", dest='use_gat', action="store_true", default=False, help="whether to use GAT models")
+    parser.add_option("--gat_config", dest='gat_json_config', action="append", type="string",
+                      help="The config files for the GAT model")
+    # ---
+    #parse the command line
+    (options, args) = parser.parse_args()
+
+    # ---
+    try:
+        sModelDir, sModelName = args
+    except Exception as e:
+        traceln("Specify a model folder and a model name!")
+        _exit(usage, 1, e)
+
+    main(sModelDir, sModelName, options)
+
diff --git a/TranskribusDU/tasks/case_BAR/DU_BAR_sem_sgm.py b/TranskribusDU/tasks/case_BAR/DU_BAR_sem_sgm.py
new file mode 100644
index 0000000..55e3e0a
--- /dev/null
+++ b/TranskribusDU/tasks/case_BAR/DU_BAR_sem_sgm.py
@@ -0,0 +1,134 @@
+# -*- coding: utf-8 -*-
+
+"""
+    DU task for BAR - see https://read02.uibk.ac.at/wiki/index.php/Document_Understanding_BAR
+
+    Copyright Xerox(C) 2017 JL Meunier
+
+
+    Developed for the EU project READ. The READ project has received funding
+    from the European Union's Horizon 2020 research and innovation programme
+    under grant agreement No 674943.
+
+"""
+
+import sys, os
+
+try: #to ease the use without proper Python installation
+    import TranskribusDU_version
+except ImportError:
+    sys.path.append( os.path.dirname(os.path.dirname( os.path.abspath(sys.argv[0]) )) )
+    import TranskribusDU_version
+
+from common.trace import traceln
+
+from crf.Graph_MultiPageXml import FactorialGraph_MultiContinuousPageXml
+from crf.NodeType_PageXml import NodeType_PageXml_type_woText
+from .DU_CRF_Task import DU_FactorialCRF_Task
+from crf.FeatureDefinition_PageXml_std_noText import FeatureDefinition_T_PageXml_StandardOnes_noText
+from crf.FeatureDefinition_PageXml_std_noText import FeatureDefinition_PageXml_StandardOnes_noText
+
+from .DU_BAR import main
+
+class DU_BAR_sem_sgm(DU_FactorialCRF_Task):
+    """
+    We will do a Factorial CRF model using the Multitype CRF
+    , with the below labels
+    """
+    sLabeledXmlFilenamePattern = "*.du_mpxml"
+
+    # ===============================================================================================================
+    #DEFINING THE CLASS OF GRAPH WE USE
+    DU_GRAPH = FactorialGraph_MultiContinuousPageXml
+
+    #---------------------------------------------
+    lLabels1 = ['heading', 'header', 'page-number', 'resolution-number', 'resolution-marginalia', 'resolution-paragraph', 'other']
+
+    nt1 = NodeType_PageXml_type_woText("sem"   #some short prefix because labels below are prefixed with it
+                          , lLabels1
+                          , None    #keep this to None, unless you know very well what you do. (FactorialCRF!)
+                          , False   #no label means OTHER
+                          , BBoxDeltaFun=lambda v: max(v * 0.066, min(5, v/3))  #we reduce overlap in this way
+                          )
+    nt1.setLabelAttribute("DU_sem")
+    nt1.setXpathExpr( (".//pc:TextRegion"   #how to find the nodes, MUST be same as for other node type!! (FactorialCRF!)
+                      , "./pc:TextEquiv")   #how to get their text
+                    )
+    DU_GRAPH.addNodeType(nt1)
+
+    #---------------------------------------------
+    #lLabels2 = ['heigh', 'ho', 'other']
+    #lLabels2 = ['heigh', 'ho']
+    #lLabels2 = ['B', 'I', 'E']
+    lLabels2 = ['B', 'I', 'E', 'S', 'O']   #we never see any 'S' in practice, but keep it in the label set
+
+    nt2 = NodeType_PageXml_type_woText("sgm"   #some short prefix because labels below are prefixed with it
+                          , lLabels2
+                          , None    #keep this to None, unless you know very well what you do. (FactorialCRF!)
+                          , False   #no label means OTHER
+                          , BBoxDeltaFun=lambda v: max(v * 0.066, min(5, v/3))  #we reduce overlap in this way
+                          )
+    nt2.setLabelAttribute("DU_sgm")
+    nt2.setXpathExpr( (".//pc:TextRegion"   #how to find the nodes, MUST be same as for other node type!! (FactorialCRF!)
+                      , "./pc:TextEquiv")   #how to get their text
+                    )
+    DU_GRAPH.addNodeType(nt2)
+
+    #=== CONFIGURATION ====================================================================
+    def __init__(self, sModelName, sModelDir, sComment=None, C=None, tol=None, njobs=None, max_iter=None, inference_cache=None):
+
+#        #edge feature extractor config is a bit tedious
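+#        (In this factorial model, nt1 and nt2 above select the same TextRegion
+#        nodes, so each region is jointly labelled with a "sem" class and a
+#        "sgm" class; the commented-out sketch below shows the per-label and
+#        per-label-pair feature configuration such a model could take.)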
+# dFeatureConfig = { lbl:None for lbl in self.lLabels1+self.lLabels2 } +# for lbl1 in self.lLabels1: +# for lbl2 in self.lLabels2: +# dFeatureConfig["%s_%s"%(lbl1, lbl2)] = None + + DU_FactorialCRF_Task.__init__(self + , sModelName, sModelDir + , self.DU_GRAPH + , dLearnerConfig = { + 'C' : .1 if C is None else C + , 'njobs' : 8 if njobs is None else njobs + , 'inference_cache' : 50 if inference_cache is None else inference_cache + #, 'tol' : .1 + , 'tol' : .05 if tol is None else tol + , 'save_every' : 50 #save every 50 iterations,for warm start + , 'max_iter' : 1000 if max_iter is None else max_iter + } + , sComment=sComment + , cFeatureDefinition=FeatureDefinition_PageXml_StandardOnes_noText +# , cFeatureDefinition=FeatureDefinition_T_PageXml_StandardOnes_noText +# { +# #config for the extractor of nodes of each type +# "text": None, +# "sprtr": None, +# #config for the extractor of edges of each type +# "text_text": None, +# "text_sprtr": None, +# "sprtr_text": None, +# "sprtr_sprtr": None +# } + ) + + traceln("- classes: ", self.DU_GRAPH.getLabelNameList()) + + self.bsln_mdl = self.addBaseline_LogisticRegression() #use a LR model trained by GridSearch as baseline + + #=== END OF CONFIGURATION ============================================================= + + + def predict(self, lsColDir,sDocId): + """ + Return the list of produced files + """ +# self.sXmlFilenamePattern = "*.a_mpxml" + return DU_FactorialCRF_Task.predict(self, lsColDir,sDocId) + + +if __name__ == "__main__": + main(DU_BAR_sem_sgm) \ No newline at end of file diff --git a/TranskribusDU/tasks/case_BAR/DU_BAR_sgm.py b/TranskribusDU/tasks/case_BAR/DU_BAR_sgm.py new file mode 100644 index 0000000..0b05f60 --- /dev/null +++ b/TranskribusDU/tasks/case_BAR/DU_BAR_sgm.py @@ -0,0 +1,125 @@ +# -*- coding: utf-8 -*- + +""" + DU task for BAR - see https://read02.uibk.ac.at/wiki/index.php/Document_Understanding_BAR + + Copyright Xerox(C) 2017 JL Meunier + + + + + Developed for the EU project READ. The READ project has received funding + from the European Union's Horizon 2020 research and innovation programme + under grant agreement No 674943. + +""" + +import sys, os + +try: #to ease the use without proper Python installation + import TranskribusDU_version +except ImportError: + sys.path.append( os.path.dirname(os.path.dirname( os.path.abspath(sys.argv[0]) )) ) + import TranskribusDU_version + +from common.trace import traceln + +from crf.Graph_MultiPageXml import Graph_MultiContinousPageXml +from crf.NodeType_PageXml import NodeType_PageXml_type_woText +from DU_CRF_Task import DU_CRF_Task +from crf.FeatureDefinition_PageXml_std_noText import FeatureDefinition_T_PageXml_StandardOnes_noText +from crf.FeatureDefinition_PageXml_std_noText import FeatureDefinition_PageXml_StandardOnes_noText + +from DU_BAR import main + +class DU_BAR_sgm(DU_CRF_Task): + """ + We will do a typed CRF model for a DU task + , with the below labels + """ + sLabeledXmlFilenamePattern = "*.du_mpxml" + + # =============================================================================================================== + #DEFINING THE CLASS OF GRAPH WE USE + DU_GRAPH = Graph_MultiContinousPageXml + + + #lLabels2 = ['heigh', 'ho', 'other'] + #lLabels2 = ['heigh', 'ho'] + lLabels2 = ['B', 'I', 'E'] #we never see any S... , 'S'] + + # Some TextRegion have no segmentation label at all, and were labelled'other' by the converter + lIgnoredLabels2 = None + + # """ + # if you play with a toy collection, which does not have all expected classes, you can reduce those. 
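+    #    e.g. setting lActuallySeen = [0, 2] (hypothetical values) would keep
+    #    only the 1st and 3rd labels and ignore the others.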
+ # """ + # + # lActuallySeen = None + # if lActuallySeen: + # print "REDUCING THE CLASSES TO THOSE SEEN IN TRAINING" + # lIgnoredLabels = [lLabels[i] for i in range(len(lLabels)) if i not in lActuallySeen] + # lLabels = [lLabels[i] for i in lActuallySeen ] + # print len(lLabels) , lLabels + # print len(lIgnoredLabels) , lIgnoredLabels + # nbClass = len(lLabels) + 1 #because the ignored labels will become OTHER + + nt2 = NodeType_PageXml_type_woText("sgm" #some short prefix because labels below are prefixed with it + , lLabels2 + , lIgnoredLabels2 + , False #no label means OTHER + , BBoxDeltaFun=lambda v: max(v * 0.066, min(5, v/3)) #we reduce overlap in this way + ) + nt2.setLabelAttribute("DU_sgm") + nt2.setXpathExpr( (".//pc:TextRegion" #how to find the nodes + , "./pc:TextEquiv") #how to get their text + ) + DU_GRAPH.addNodeType(nt2) + + #=== CONFIGURATION ==================================================================== + def __init__(self, sModelName, sModelDir, sComment=None, C=None, tol=None, njobs=None, max_iter=None, inference_cache=None): + + DU_CRF_Task.__init__(self + , sModelName, sModelDir + , self.DU_GRAPH + , dLearnerConfig = { + 'C' : .1 if C is None else C + , 'njobs' : 8 if njobs is None else njobs + , 'inference_cache' : 50 if inference_cache is None else inference_cache + #, 'tol' : .1 + , 'tol' : .05 if tol is None else tol + , 'save_every' : 50 #save every 50 iterations,for warm start + , 'max_iter' : 1000 if max_iter is None else max_iter + } + , sComment=sComment + , cFeatureDefinition=FeatureDefinition_PageXml_StandardOnes_noText +# , cFeatureDefinition=FeatureDefinition_T_PageXml_StandardOnes_noText +# , dFeatureConfig = { +# #config for the extractor of nodes of each type +# "text": None, +# "sprtr": None, +# #config for the extractor of edges of each type +# "text_text": None, +# "text_sprtr": None, +# "sprtr_text": None, +# "sprtr_sprtr": None +# } + ) + + traceln("- classes: ", self.DU_GRAPH.getLabelNameList()) + + self.bsln_mdl = self.addBaseline_LogisticRegression() #use a LR model trained by GridSearch as baseline + + #=== END OF CONFIGURATION ============================================================= + + + def predict(self, lsColDir,sDocId): + """ + Return the list of produced files + """ +# self.sXmlFilenamePattern = "*.a_mpxml" + return DU_CRF_Task.predict(self, lsColDir,sDocId) + + +if __name__ == "__main__": + main(DU_BAR_sgm) \ No newline at end of file diff --git a/TranskribusDU/tasks/case_BAR/__init__.py b/TranskribusDU/tasks/case_BAR/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/TranskribusDU/tasks/case_GTBooks/DU_GTBooks.py b/TranskribusDU/tasks/case_GTBooks/DU_GTBooks.py new file mode 100644 index 0000000..0faead3 --- /dev/null +++ b/TranskribusDU/tasks/case_GTBooks/DU_GTBooks.py @@ -0,0 +1,237 @@ +# -*- coding: utf-8 -*- + +""" + Example DU task for Dodge, using the logit textual feature extractor + + Copyright Xerox(C) 2017 JL. Meunier + + + + + Developed for the EU project READ. The READ project has received funding + from the European Union�s Horizon 2020 research and innovation programme + under grant agreement No 674943. 
+ +""" +import sys, os + +try: #to ease the use without proper Python installation + import TranskribusDU_version +except ImportError: + sys.path.append( os.path.dirname(os.path.dirname( os.path.abspath(sys.argv[0]) )) ) + import TranskribusDU_version + +from common.trace import traceln +from tasks import _checkFindColDir, _exit + +from graph.Graph_MultiPageXml import Graph_MultiPageXml +from graph.NodeType_PageXml import NodeType_PageXml_type_NestedText +from tasks.DU_Task_Factory import DU_Task_Factory +from tasks.DU_CRF_Task import DU_CRF_Task +from graph.FeatureDefinition_PageXml_logit_v2 import FeatureDefinition_PageXml_LogitExtractorV2 + +# =============================================================================================================== + +lLabels = ['TOC-entry' #0 + , 'caption' + , 'catch-word' + , 'footer' + , 'footnote' #4 + , 'footnote-continued' + , 'header' #6 + , 'heading' #7 + , 'marginalia' + , 'page-number' #9 + , 'paragraph' #10 + , 'signature-mark'] +lIgnoredLabels = None + +nbClass = len(lLabels) + +""" +if you play with a toy collection, which does not have all expected classes, you can reduce those. +""" +lActuallySeen = [4, 6, 7, 9, 10] +#lActuallySeen = [4, 6] +""" + 0- TOC-entry 5940 occurences ( 2%) ( 2%) + 1- caption 707 occurences ( 0%) ( 0%) + 2- catch-word 201 occurences ( 0%) ( 0%) + 3- footer 11 occurences ( 0%) ( 0%) + 4- footnote 36942 occurences ( 11%) ( 11%) + 5- footnote-continued 1890 occurences ( 1%) ( 1%) + 6- header 15910 occurences ( 5%) ( 5%) + 7- heading 18032 occurences ( 6%) ( 6%) + 8- marginalia 4292 occurences ( 1%) ( 1%) + 9- page-number 40236 occurences ( 12%) ( 12%) + 10- paragraph 194927 occurences ( 60%) ( 60%) + 11- signature-mark 4894 occurences ( 2%) ( 2%) +""" +lActuallySeen = None +if lActuallySeen: + traceln("REDUCING THE CLASSES TO THOSE SEEN IN TRAINING") + lIgnoredLabels = [lLabels[i] for i in range(len(lLabels)) if i not in lActuallySeen] + lLabels = [lLabels[i] for i in lActuallySeen ] + traceln(len(lLabels) , lLabels) + traceln(len(lIgnoredLabels) , lIgnoredLabels) + nbClass = len(lLabels) + 1 #because the ignored labels will become OTHER + + #DEFINING THE CLASS OF GRAPH WE USE + DU_GRAPH = Graph_MultiPageXml + nt = NodeType_PageXml_type_NestedText("gtb" #some short prefix because labels below are prefixed with it + , lLabels + , lIgnoredLabels + , True #no label means OTHER + ) +else: + #DEFINING THE CLASS OF GRAPH WE USE + DU_GRAPH = Graph_MultiPageXml + nt = NodeType_PageXml_type_NestedText("gtb" #some short prefix because labels below are prefixed with it + , lLabels + , lIgnoredLabels + , False #no label means OTHER + ) +nt.setXpathExpr( (".//pc:TextRegion" #how to find the nodes + , "./pc:TextEquiv") #how to get their text + ) +DU_GRAPH.addNodeType(nt) + +""" +The constraints must be a list of tuples like ( , , , ) +where: +- operator is one of 'XOR' 'XOROUT' 'ATMOSTONE' 'OR' 'OROUT' 'ANDOUT' 'IMPLY' +- states is a list of unary state names, 1 per involved unary. If the states are all the same, you can pass it directly as a single string. +- negated is a list of boolean indicated if the unary must be negated. 
Again, if all values are the same, pass a single boolean value instead of a list +""" +if False: + DU_GRAPH.setPageConstraint( [ ('ATMOSTONE', nt, 'pnum' , False) #0 or 1 catch_word per page + , ('ATMOSTONE', nt, 'title' , False) #0 or 1 heading pare page + ] ) + +# =============================================================================================================== + + +class DU_GTBooks(DU_CRF_Task): + """ + We will do a CRF model for a DU task + , working on a DS XML document at BLOCK level + , with the below labels + """ + sXmlFilenamePattern = "*.mpxml" + +# #In case you want to change the Logistic Regression gird search parameters... +# dGridSearch_LR_conf = {'C':[0.01, 0.1, 1.0, 10.0] } #Grid search parameters for LR baseline method training +# dGridSearch_LR_n_jobs = 4 #Grid search: number of jobs + + #=== CONFIGURATION ==================================================================== + def __init__(self, sModelName, sModelDir, sComment=None, C=None, tol=None, njobs=None, max_iter=None, inference_cache=None): + #NOTE: we might get a list in C tol max_iter inference_cache (in case of gridsearch) + + DU_CRF_Task.__init__(self + , sModelName, sModelDir + , DU_GRAPH + , dFeatureConfig = { + 'nbClass' : nbClass + , 't_ngrams_node' : (2,4) + , 'b_node_lc' : False + , 't_ngrams_edge' : (2,4) + , 'b_edge_lc' : False + , 'n_jobs' : 5 #n_jobs when fitting the internal Logit feat extractor model by grid search + } + , dLearnerConfig = { + 'C' : .1 if C is None else C + , 'njobs' : 5 if njobs is None else njobs + , 'inference_cache' : 50 if inference_cache is None else inference_cache + #, 'tol' : .1 + , 'tol' : .05 if tol is None else tol + , 'save_every' : 50 #save every 50 iterations,for warm start + , 'max_iter' : 1000 if max_iter is None else max_iter + } + , sComment=sComment + , cFeatureDefinition=FeatureDefinition_PageXml_LogitExtractorV2 + ) + + self.setNbClass(nbClass) #so that we check if all classes are represented in the training set + + self.bsln_mdl = self.addBaseline_LogisticRegression() #use a LR model trained by GridSearch as baseline + #=== END OF CONFIGURATION ============================================================= + + +if __name__ == "__main__": + + version = "v.01" + usage, description, parser = DU_Task_Factory.getStandardOptionsParser(sys.argv[0], version) + + # --- + #parse the command line + (options, args) = parser.parse_args() + + # --- + try: + sModelDir, sModelName = args + except Exception as e: + traceln("Specify a model folder and a model name!") + _exit(usage, 1, e) + + doer = DU_GTBooks(sModelName, sModelDir, + C = options.crf_C, + tol = options.crf_tol, + njobs = options.crf_njobs, + max_iter = options.max_iter, + inference_cache = options.crf_inference_cache) + + if options.rm: + doer.rm() + sys.exit(0) + + traceln("- classes: ", DU_GRAPH.getLabelNameList()) + + if options.best_params: + dBestParams = doer.getModelClass().loadBestParams(sModelDir, options.best_params) + doer.setLearnerConfiguration(dBestParams) + + lTrn, lTst, lRun, lFold = [_checkFindColDir(lsDir) for lsDir in [options.lTrn, options.lTst, options.lRun, options.lFold]] + + if options.iFoldInitNum or options.iFoldRunNum or options.bFoldFinish: + if options.iFoldInitNum: + """ + initialization of a cross-validation + """ + splitter, ts_trn, lFilename_trn = doer._nfold_Init(lFold, options.iFoldInitNum, bStoreOnDisk=True) + elif options.iFoldRunNum: + """ + Run one fold + """ + oReport = doer._nfold_RunFoldFromDisk(options.iFoldRunNum, options.warm) + traceln(oReport) + elif 
options.bFoldFinish:
+            tstReport = doer._nfold_Finish()
+            traceln(tstReport)
+        else:
+            assert False, "Internal error"
+        #no more processing!!
+        exit(0)
+        #-------------------
+
+    if lFold:
+        loTstRpt = doer.nfold_Eval(lFold, 3, .25, None)
+        import graph.GraphModel
+        sReportPickleFilename = os.path.join(sModelDir, sModelName + "__report.txt")
+        traceln("Results are in %s"%sReportPickleFilename)
+        graph.GraphModel.GraphModel.gzip_cPickle_dump(sReportPickleFilename, loTstRpt)
+    elif lTrn:
+        doer.train_save_test(lTrn, lTst, options.warm)
+        try:    traceln("Baseline best estimator: %s"%doer.bsln_mdl.best_params_)   #for GridSearch
+        except: pass
+        traceln(" --- CRF Model ---")
+        traceln(doer.getModel().getModelInfo())
+    elif lTst:
+        doer.load()
+        tstReport = doer.test(lTst)
+        traceln(tstReport)
+
+    if lRun:
+        doer.load()
+        lsOutputFilename = doer.predict(lRun)
+        traceln("Done, see in:\n  %s"%lsOutputFilename)
+
diff --git a/TranskribusDU/tasks/case_GTBooks/DU_GTBooks_BL.py b/TranskribusDU/tasks/case_GTBooks/DU_GTBooks_BL.py
new file mode 100644
index 0000000..62f6d9b
--- /dev/null
+++ b/TranskribusDU/tasks/case_GTBooks/DU_GTBooks_BL.py
@@ -0,0 +1,176 @@
+# -*- coding: utf-8 -*-
+
+"""
+    Example DU task for Dodge, using the logit textual feature extractor
+
+    Copyright Xerox(C) 2017 JL. Meunier
+
+
+    Developed for the EU project READ. The READ project has received funding
+    from the European Union's Horizon 2020 research and innovation programme
+    under grant agreement No 674943.
+
+"""
+import sys, os
+from crf import FeatureDefinition_PageXml_GTBooks
+
+try: #to ease the use without proper Python installation
+    import TranskribusDU_version
+except ImportError:
+    sys.path.append( os.path.dirname(os.path.dirname( os.path.abspath(sys.argv[0]) )) )
+    import TranskribusDU_version
+
+from common.trace import traceln
+from tasks import _checkFindColDir, _exit
+
+from crf.Graph_MultiPageXml import Graph_MultiPageXml
+from crf.NodeType_PageXml import NodeType_PageXml_type_NestedText
+from DU_CRF_Task import DU_CRF_Task
+from DU_BL_Task import DU_Baseline
+from crf.FeatureDefinition_PageXml_GTBooks import FeatureDefinition_GTBook
+
+# ===============================================================================================================
+
+lLabels = ['TOC-entry'          #0
+           , 'caption'
+           , 'catch-word'
+           , 'footer'
+           , 'footnote'         #4
+           , 'footnote-continued'
+           , 'header'           #6
+           , 'heading'          #7
+           , 'marginalia'
+           , 'page-number'      #9
+           , 'paragraph'        #10
+           , 'signature-mark']
+lIgnoredLabels = None
+
+nbClass = len(lLabels)
+
+"""
+if you play with a toy collection, which does not have all expected classes, you can reduce those.
+""" +lActuallySeen = [4, 6, 7, 9, 10] +#lActuallySeen = [4, 6] +""" + 0- TOC-entry 5940 occurences ( 2%) ( 2%) + 1- caption 707 occurences ( 0%) ( 0%) + 2- catch-word 201 occurences ( 0%) ( 0%) + 3- footer 11 occurences ( 0%) ( 0%) + 4- footnote 36942 occurences ( 11%) ( 11%) + 5- footnote-continued 1890 occurences ( 1%) ( 1%) + 6- header 15910 occurences ( 5%) ( 5%) + 7- heading 18032 occurences ( 6%) ( 6%) + 8- marginalia 4292 occurences ( 1%) ( 1%) + 9- page-number 40236 occurences ( 12%) ( 12%) + 10- paragraph 194927 occurences ( 60%) ( 60%) + 11- signature-mark 4894 occurences ( 2%) ( 2%) +""" +lActuallySeen = None +if lActuallySeen: + traceln("REDUCING THE CLASSES TO THOSE SEEN IN TRAINING") + lIgnoredLabels = [lLabels[i] for i in range(len(lLabels)) if i not in lActuallySeen] + lLabels = [lLabels[i] for i in lActuallySeen ] + traceln(len(lLabels) , lLabels) + traceln(len(lIgnoredLabels) , lIgnoredLabels) + nbClass = len(lLabels) + 1 #because the ignored labels will become OTHER + + #DEFINING THE CLASS OF GRAPH WE USE + DU_GRAPH = Graph_MultiPageXml + nt = NodeType_PageXml_type_NestedText("gtb" #some short prefix because labels below are prefixed with it + , lLabels + , lIgnoredLabels + , True #no label means OTHER + ) +else: + #DEFINING THE CLASS OF GRAPH WE USE + DU_GRAPH = Graph_MultiPageXml + nt = NodeType_PageXml_type_NestedText("gtb" #some short prefix because labels below are prefixed with it + , lLabels + , lIgnoredLabels + , False #no label means OTHER + ) +nt.setXpathExpr( (".//pc:TextRegion" #how to find the nodes + , "./pc:TextEquiv") #how to get their text + ) +DU_GRAPH.addNodeType(nt) + +""" +The constraints must be a list of tuples like ( , , , ) +where: +- operator is one of 'XOR' 'XOROUT' 'ATMOSTONE' 'OR' 'OROUT' 'ANDOUT' 'IMPLY' +- states is a list of unary state names, 1 per involved unary. If the states are all the same, you can pass it directly as a single string. +- negated is a list of boolean indicated if the unary must be negated. 
Again, if all values are the same, pass a single boolean value instead of a list +""" +if False: + DU_GRAPH.setPageConstraint( [ ('ATMOSTONE', nt, 'pnum' , False) #0 or 1 catch_word per page + , ('ATMOSTONE', nt, 'title' , False) #0 or 1 heading pare page + ] ) + +# =============================================================================================================== + + +class DU_BL_V1(DU_Baseline): + def __init__(self, sModelName, sModelDir,logitID,sComment=None): + DU_Baseline.__init__(self, sModelName, sModelDir,DU_GRAPH,logitID) + + + +if __name__ == "__main__": + + version = "v.01" + usage, description, parser = DU_CRF_Task.getBasicTrnTstRunOptionParser(sys.argv[0], version) + + # --- + #parse the command line + (options, args) = parser.parse_args() + # --- + try: + sModelDir, sModelName = args + except Exception as e: + _exit(usage, 1, e) + + doer = DU_BL_V1(sModelName, sModelDir,'logit_5') + + if options.rm: + doer.rm() + sys.exit(0) + + traceln("- classes: ", DU_GRAPH.getLabelNameList()) + + if hasattr(options,'l_train_files') and hasattr(options,'l_test_files'): + f=open(options.l_train_files) + lTrn=[] + for l in f: + fname=l.rstrip() + lTrn.append(fname) + f.close() + + g=open(options.l_test_files) + lTst=[] + for l in g: + fname=l.rstrip() + lTst.append(fname) + + tstReport=doer.train_save_test(lTrn, lTst, options.warm,filterFilesRegexp=False) + traceln(tstReport) + + + else: + + lTrn, lTst, lRun = [_checkFindColDir(lsDir) for lsDir in [options.lTrn, options.lTst, options.lRun]] + + if lTrn: + doer.train_save_test(lTrn, lTst, options.warm) + elif lTst: + doer.load() + tstReport = doer.test(lTst) + traceln(tstReport) + + if lRun: + doer.load() + lsOutputFilename = doer.predict(lRun) + traceln("Done, see in:\n %s"%lsOutputFilename) + diff --git a/TranskribusDU/tasks/case_GTBooks/__init__.py b/TranskribusDU/tasks/case_GTBooks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/TranskribusDU/tasks/cluster2Region.py b/TranskribusDU/tasks/cluster2Region.py new file mode 100644 index 0000000..fa155b1 --- /dev/null +++ b/TranskribusDU/tasks/cluster2Region.py @@ -0,0 +1,182 @@ +# -*- coding: utf-8 -*- + +""" +Transform clusters into TextRegions and populate them with TextLines + +Created on August 2019 + +Copyright NAVER LABS Europe 2019 +@author: Hervé Déjean +""" + +import sys, os, glob +from optparse import OptionParser +from copy import deepcopy +from collections import Counter +from collections import defaultdict + +from lxml import etree +import numpy as np +from shapely.ops import cascaded_union + + +try: #to ease the use without proper Python installation + import TranskribusDU_version +except ImportError: + sys.path.append( os.path.dirname(os.path.dirname( os.path.abspath(sys.argv[0]) )) ) + import TranskribusDU_version +TranskribusDU_version + +from common.trace import traceln, trace +from xml_formats.PageXml import PageXml +from util.Shape import ShapeLoader +dNS = {"pg":"http://schema.primaresearch.org/PAGE/gts/pagecontent/2013-07-15"} +# ---------------------------------------------------------------------------- + + +def getClusterCoords(lElts): + + lp = [] + for e in lElts: + try: + lp.append(ShapeLoader.node_to_Polygon(e)) + except ValueError: + pass + contour = cascaded_union([p if p.is_valid else p.convex_hull for p in lp ]) + # print(contour.wkt) + try:spoints = ' '.join("%s,%s"%(int(x[0]),int(x[1])) for x in contour.convex_hull.exterior.coords) + except: + try: spoints = ' '.join("%s,%s"%(int(x[0]),int(x[1])) for x in 
contour.convex_hull.coords) + # JL got once a: NotImplementedError: Multi-part geometries do not provide a coordinate sequence + except: spoints = "" + return spoints + +def deleteRegionsinDOM(page,lRegionsNd): + [page.remove(c) for c in lRegionsNd] + +def main(sInputDir + , bVerbose=False): + + lSkippedFile = [] + + # filenames without the path + lsFilename = [os.path.basename(name) for name in os.listdir(sInputDir) if name.endswith("_du.mpxml")] + traceln(" - %d .mpxml files to process" % len(lsFilename)) + for sMPXml in lsFilename: + trace(" - .mpxml FILE : ", sMPXml) + if bVerbose: traceln() + + # 0 - load input file + doc = etree.parse(os.path.join(sInputDir,sMPXml)) + cluster2Region(doc,bVerbose) + + doc.write(os.path.join(sInputDir,sMPXml), + xml_declaration = True, + encoding="utf-8", + pretty_print=True + #compression=0, #0 to 9 + ) + + +def propagateTypeToRegion(ndRegion): + """ + compute the most frequent type in the Textlines and assigns it to the new region + """ + dType=Counter() + for t in ndRegion: + dType[t.get('type')]+=1 + mc = dType.most_common(1) + if mc : + if mc[0][0]:ndRegion.set('type',mc[0][0]) + # structure {type:page-number;} + # custom="structure {type:page-number;}" + if mc[0][0]:ndRegion.set('custom',"structure {type:%s;}"%mc[0][0]) + + +def addRegionToDom(page,ipage,lc,bVerbose): + """ + create a dom node for each cluster + update DU_cluster for each Textline + """ + for ic,dC in enumerate(lc): + ndRegion = PageXml.createPageXmlNode('TextRegion') + + #update elements + lTL = lc[dC] + print (lTL) +# for id in c.get('content').split(): +# elt = page.xpath('.//*[@id="%s"]'%id)[0] +# elt.getparent().remove(elt) +# ndRegion.append(elt) +# lTL.append((elt)) + ndRegion.set('id',"p%d_r%d"%(ipage,ic)) + coords = PageXml.createPageXmlNode('Coords') + ndRegion.append(coords) + coords.set('points',getClusterCoords(lTL)) + propagateTypeToRegion(ndRegion) + + page.append(ndRegion) + +def getCLusters(ndPage): + dCluster=defaultdict(list) + lTL= ndPage.xpath(".//*[@DU_cluster]", namespaces=dNS) + for x in lTL:dCluster[x.get('DU_cluster')].append(x) + return dCluster + +def cluster2Region(doc, fTH=0.5,bVerbose=True): + """ + + """ + root = doc.getroot() + + # no use @DU_CLuster: + xpCluster = ".//pg:Cluster" + xpTextRegions = ".//pg:TextRegion" + + # get pages + for iPage, ndPage in enumerate(PageXml.xpath(root, "//pc:Page")): + # get cluster + dClusters= getCLusters(ndPage) #ndPage.xpath(xpCluster, namespaces=dNS) + lRegionsNd = ndPage.xpath(xpTextRegions, namespaces=dNS) + if bVerbose:traceln("\n%d clusters and %d regions found" %(len(dClusters),len(lRegionsNd))) + + addRegionToDom(ndPage,iPage+1,dClusters,bVerbose) + if bVerbose:traceln("%d regions created" %(len(dClusters))) + deleteRegionsinDOM(ndPage, lRegionsNd) + + return doc + + + +# ---------------------------------------------------------------------------- +if __name__ == "__main__": + + version = "v.01" + sUsage=""" +Usage: %s + +""" % (sys.argv[0], 90) + + parser = OptionParser(usage=sUsage) + parser.add_option("-v", "--verbose", dest='bVerbose', action="store_true" + , help="Verbose mode") + (options, args) = parser.parse_args() + + try: + sInputDir = args[0] + except ValueError: + sys.stderr.write(sUsage) + sys.exit(1) + + # ... 
checking folders
+    if not os.path.normpath(sInputDir).endswith("col") : sInputDir = os.path.join(sInputDir, "col")
+    # all must be ok by now
+    lsDir = [sInputDir]
+    if not all(os.path.isdir(s) for s in lsDir):
+        for s in lsDir:
+            if not os.path.isdir(s): sys.stderr.write("Not a directory: %s\n"%s)
+        sys.exit(2)
+
+    main(sInputDir, bVerbose=options.bVerbose)
+
+    traceln("Done.")
\ No newline at end of file
diff --git a/TranskribusDU/tasks/compareReport.py b/TranskribusDU/tasks/compareReport.py
index 15c3d01..5073c37 100644
--- a/TranskribusDU/tasks/compareReport.py
+++ b/TranskribusDU/tasks/compareReport.py
@@ -6,18 +6,7 @@
     Copyright Naver Labs Europe(C) 2018
     @author H. Déjean
-    This program is free software: you can redistribute it and/or modify
-    it under the terms of the GNU General Public License as published by
-    the Free Software Foundation, either version 3 of the License, or
-    (at your option) any later version.
-
-    This program is distributed in the hope that it will be useful,
-    but WITHOUT ANY WARRANTY; without even the implied warranty of
-    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-    GNU General Public License for more details.
-
-    You should have received a copy of the GNU General Public License
-    along with this program. If not, see <http://www.gnu.org/licenses/>.
+
     Developed for the EU project READ. The READ project has received funding
diff --git a/TranskribusDU/tasks/do_keep_if_text.py b/TranskribusDU/tasks/do_keep_if_text.py
new file mode 100644
index 0000000..da4d229
--- /dev/null
+++ b/TranskribusDU/tasks/do_keep_if_text.py
@@ -0,0 +1,86 @@
+# -*- coding: utf-8 -*-
+
+"""
+    Keep documents where more than a given ratio of the TextLines have text
+
+    Copyright Naver Labs Europe(C) 2018 JL Meunier
+
+
+    Developed for the EU project READ. The READ project has received funding
+    from the European Union's Horizon 2020 research and innovation programme
+    under grant agreement No 674943.
+
+"""
+
+import sys, os
+from optparse import OptionParser
+import shutil
+
+from lxml import etree
+
+try: #to ease the use without proper Python installation
+    import TranskribusDU_version
+except ImportError:
+    sys.path.append( os.path.dirname(os.path.dirname( os.path.abspath(sys.argv[0]) )) )
+    import TranskribusDU_version
+
+from common.trace import traceln
+from xml_formats.PageXml import PageXml
+from tasks import _exit
+
+
+def isTexted(sFilename, fRatio):
+    parser = etree.XMLParser(remove_blank_text=True)
+    doc = etree.parse(sFilename, parser)
+
+    cntTxt, cnt = PageXml.countTextLineWithText(doc)
+
+    fDocRatio = float(cntTxt) / cnt
+
+    del doc
+
+    if fDocRatio > fRatio:
+        return True
+    elif fDocRatio > 0:
+        traceln("Warning: %d texted out of %d (%.2f)  %s" % (cntTxt, cnt, fDocRatio, sFilename))
+
+    return False
+
+# ----------------------------------------------------------------------------
+if __name__ == "__main__":
+    usage = """ """
+    version = "v.01"
+    parser = OptionParser(usage=usage, version="0.1")
+    parser.add_option("--ratio", dest='fRatio', action="store"
+                      , type=float
+                      , help="Keep documents where more than this ratio of TextLines have text"
+                      , default=0.75)
+
+    # ---
+    #parse the command line
+    (options, args) = parser.parse_args()
+
+    traceln(options)
+
+    if len(args) == 2 and os.path.isdir(args[0]) and os.path.isdir(args[1]):
+        # ok, let's work differently...
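+        # directory mode (a sketch; the two folder paths are hypothetical):
+        #     python tasks/do_keep_if_text.py --ratio 0.75 inDir/col outDir/col
+        # copies every *pxml file of the first folder whose ratio of texted
+        # TextLines exceeds 0.75 into the second folder; others are skipped.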
+        sFromDir, sToDir = args
+        for s in os.listdir(sFromDir):
+            if not s.endswith("pxml"): continue
+            sFilename = sFromDir + "/" + s
+            if isTexted(sFilename, options.fRatio):
+                traceln(sFilename, " --> ", sToDir)
+                shutil.copy(sFilename, sToDir)
+            else:
+                traceln(" skipping: ", sFilename)
+    else:
+        for sFilename in args:
+            if isTexted(sFilename, options.fRatio):
+                traceln("texted : %s"%sFilename)
+            else:
+                traceln("no text: %s"%sFilename)
diff --git a/TranskribusDU/tasks/ecn_16Lay1Conv.json b/TranskribusDU/tasks/ecn_16Lay1Conv.json
new file mode 100644
index 0000000..4bfc40d
--- /dev/null
+++ b/TranskribusDU/tasks/ecn_16Lay1Conv.json
@@ -0,0 +1,19 @@
+{
+    "ecn_learner_config":
+    {
+        "name":"16Lay1Conv",
+        "dropout_rate_edge": 0.2,
+        "dropout_rate_edge_feat": 0.2,
+        "dropout_rate_node": 0.2,
+        "lr": 0.001,
+        "mu": 0.0001,
+        "nb_iter": 4800,
+        "nconv_edge": 1,
+        "node_indim": -1,
+        "num_layers": 16,
+        "ratio_train_val": 0.3,
+        "patience":100,
+        "activation_name":"relu",
+        "stack_convolutions": false
+    }
+}
diff --git a/TranskribusDU/tasks/ecn_1Lay1Conv.json b/TranskribusDU/tasks/ecn_1Lay1Conv.json
new file mode 100644
index 0000000..8153cec
--- /dev/null
+++ b/TranskribusDU/tasks/ecn_1Lay1Conv.json
@@ -0,0 +1,19 @@
+{
+    "ecn_learner_config":
+    {
+        "name":"1Lay1Conv",
+        "dropout_rate_edge": 0.2,
+        "dropout_rate_edge_feat": 0.2,
+        "dropout_rate_node": 0.2,
+        "lr": 0.001,
+        "mu": 0.0001,
+        "nb_iter": 1200,
+        "nconv_edge": 1,
+        "node_indim": -1,
+        "num_layers": 1,
+        "ratio_train_val": 0.2,
+        "patience":100,
+        "activation_name":"relu",
+        "stack_convolutions": false
+    }
+}
diff --git a/TranskribusDU/tasks/ecn_4Lay1Conv.json b/TranskribusDU/tasks/ecn_4Lay1Conv.json
new file mode 100644
index 0000000..ec9b351
--- /dev/null
+++ b/TranskribusDU/tasks/ecn_4Lay1Conv.json
@@ -0,0 +1,19 @@
+{
+    "ecn_learner_config":
+    {
+        "name":"4Lay1Conv",
+        "dropout_rate_edge": 0.2,
+        "dropout_rate_edge_feat": 0.2,
+        "dropout_rate_node": 0.2,
+        "lr": 0.001,
+        "mu": 0.0001,
+        "nb_iter": 500,
+        "nconv_edge": 1,
+        "node_indim": 32,
+        "num_layers": 4,
+        "ratio_train_val": 0.3,
+        "patience":10,
+        "activation_name":"relu",
+        "stack_convolutions": false
+    }
+}
diff --git a/TranskribusDU/tasks/ecn_8Lay1Conv.json b/TranskribusDU/tasks/ecn_8Lay1Conv.json
new file mode 100644
index 0000000..1ced732
--- /dev/null
+++ b/TranskribusDU/tasks/ecn_8Lay1Conv.json
@@ -0,0 +1,19 @@
+{
+    "ecn_learner_config":
+    {
+        "name":"8Lay1Conv",
+        "dropout_rate_edge": 0.2,
+        "dropout_rate_edge_feat": 0.2,
+        "dropout_rate_node": 0.2,
+        "lr": 0.001,
+        "mu": 0.0001,
+        "nb_iter": 1200,
+        "nconv_edge": 1,
+        "node_indim": -1,
+        "num_layers": 8,
+        "ratio_train_val": 0.2,
+        "patience":100,
+        "activation_name":"relu",
+        "stack_convolutions": false
+    }
+}
diff --git a/TranskribusDU/tasks/ecn_8Lay1ConvLR.json b/TranskribusDU/tasks/ecn_8Lay1ConvLR.json
new file mode 100644
index 0000000..8cb0017
--- /dev/null
+++ b/TranskribusDU/tasks/ecn_8Lay1ConvLR.json
@@ -0,0 +1,19 @@
+{
+    "ecn_learner_config":
+    {
+        "name":"8Lay1Conv",
+        "dropout_rate_edge": 0.2,
+        "dropout_rate_edge_feat": 0.2,
+        "dropout_rate_node": 0.2,
+        "lr": 0.001,
+        "mu": 0.0001,
+        "nb_iter": 800,
+        "nconv_edge": 1,
+        "node_indim": -1,
+        "num_layers": 8,
+        "ratio_train_val": 0.5,
+        "patience":100,
+        "activation_name":"relu",
+        "stack_convolutions": false
+    }
+}
diff --git a/TranskribusDU/tasks/ecn_8Lay1Conv_dropout.json b/TranskribusDU/tasks/ecn_8Lay1Conv_dropout.json
new file mode 100644
index 0000000..d6ff942
--- /dev/null
+++ b/TranskribusDU/tasks/ecn_8Lay1Conv_dropout.json
@@ -0,0 +1,19 @@
+{
+    
"ecn_learner_config": + { + "name":"8Lay1Conv", + "dropout_rate_edge": 0.0, + "dropout_rate_edge_feat": 0.0, + "dropout_rate_node": 0.0, + "lr": 0.001, + "mu": 0.0001, + "nb_iter": 1200, + "nconv_edge": 1, + "node_indim": -1, + "num_layers": 8, + "ratio_train_val": 0.2, + "patience":100, + "activation_name":"relu", + "stack_convolutions": false + } +} diff --git a/TranskribusDU/tasks/ensemble.json b/TranskribusDU/tasks/ensemble.json new file mode 100644 index 0000000..4add23a --- /dev/null +++ b/TranskribusDU/tasks/ensemble.json @@ -0,0 +1,74 @@ +{ + "_comment": "1 relu and 1 tanh models, twice, defined in a configuration file", + "ratio_train_val": 0.2, + "ecn_ensemble": [ + { + "type": "ecn", + "name": "default_8Lay1Conv_A", + "dropout_rate_edge": 0.2, + "dropout_rate_edge_feat": 0.2, + "dropout_rate_node": 0.2, + "lr": 0.0001, + "mu": 0.0001, + "nb_iter": 1200, + "nconv_edge": 1, + "node_indim": 64, + "num_layers": 8, + "ratio_train_val": 0.1, + "patience": 50, + "activation_name": "relu", + "stack_convolutions": false + }, + { + "type": "ecn", + "name": "default_8Lay1Conv_A", + "dropout_rate_edge": 0.2, + "dropout_rate_edge_feat": 0.2, + "dropout_rate_node": 0.2, + "lr": 0.0001, + "mu": 0.0001, + "nb_iter": 1200, + "nconv_edge": 1, + "node_indim": 64, + "num_layers": 8, + "ratio_train_val": 0.1, + "patience": 50, + "activation_name": "tanh", + "stack_convolutions": false + }, + { + "type": "ecn", + "name": "default_8Lay1Conv_B", + "dropout_rate_edge": 0.2, + "dropout_rate_edge_feat": 0.2, + "dropout_rate_node": 0.2, + "lr": 0.0001, + "mu": 0.0001, + "nb_iter": 1200, + "nconv_edge": 1, + "node_indim": 64, + "num_layers": 8, + "ratio_train_val": 0.1, + "patience": 50, + "activation_name": "relu", + "stack_convolutions": false + }, + { + "type": "ecn", + "name": "default_8Lay1Conv_B", + "dropout_rate_edge": 0.2, + "dropout_rate_edge_feat": 0.2, + "dropout_rate_node": 0.2, + "lr": 0.0001, + "mu": 0.0001, + "nb_iter": 1200, + "nconv_edge": 1, + "node_indim": 64, + "num_layers": 8, + "ratio_train_val": 0.1, + "patience": 50, + "activation_name": "tanh", + "stack_convolutions": false + } + ] +} \ No newline at end of file diff --git a/TranskribusDU/tasks/intersect_cluster.py b/TranskribusDU/tasks/intersect_cluster.py new file mode 100644 index 0000000..3c7f5c9 --- /dev/null +++ b/TranskribusDU/tasks/intersect_cluster.py @@ -0,0 +1,278 @@ +# -*- coding: utf-8 -*- + +""" +We expect XML file with cluster defined by several algo. +For each Page: + We intersect the cluster of one algo with cluster of the other and + We generate new clusters named after the algo names, e.g. 
(A_I_B) + +Overwrite the input XML files, adding new cluster definitions + +Created on 9/9/2019 + +Copyright NAVER LABS Europe 2019 + +@author: JL Meunier +""" + +import sys, os +from optparse import OptionParser + +from lxml import etree + +try: #to ease the use without proper Python installation + import TranskribusDU_version +except ImportError: + sys.path.append( os.path.dirname(os.path.dirname( os.path.abspath(sys.argv[0]) )) ) + import TranskribusDU_version +TranskribusDU_version + +from common.trace import traceln, trace +from util.Shape import ShapeLoader +from xml_formats.PageXml import PageXml + +# ---------------------------------------------------------------------------- +xpCluster = ".//pg:Cluster" +# sFMT = "(%s_∩_%s)" pb with visu +sFMT = "(%s_I_%s)" +sAlgoAttr = "algo" +xpPage = ".//pg:Page" +dNS = {"pg":"http://schema.primaresearch.org/PAGE/gts/pagecontent/2013-07-15"} +# ---------------------------------------------------------------------------- + +class Cluster: + cnt = 0 + + def __init__(self, name, setID, shape=None): + self.name = name + self.setID = setID + self.shape = shape + # self.node = ... the load method can set a .node attribute pointing to the DOM node + + def getSetID(self): return self.setID + + def __len__(self): return len(self.setID) + + @classmethod + def remove(cls, ndPage, sAlgo): + """ + Given an algo, remove all its clusters from a page + """ + i = 0 + for nd in ndPage.xpath(xpCluster+"[@%s='%s']"%(sAlgoAttr, sAlgo) + , namespaces=dNS): + ndPage.remove(nd) + i += 1 + return i + + @classmethod + def load(cls, ndPage, sAlgo, bNode=False): + """ + Given an algo, load all its cluster from the page. + Compute their shape, if not provided in the XML, as a minimum rotated rectangle + """ + l = [] + for nd in ndPage.xpath(xpCluster+"[@%s='%s']"%(sAlgoAttr, sAlgo) + , namespaces=dNS): + c = cls.loadClusterNode(ndPage, nd, sAlgo) + if not c is None: + if bNode: c.node = nd + l.append(c) + return l + + @classmethod + def loadClusterNode(cls, ndPage, nd, sAlgo, bComputeShape=True): + """ + Load a cluster from its XML node + Compute its shape, if not provided in the XML, as a minimum rotated rectangle + """ + name = nd.get("name") + if name is None: + name = "%s_%d"%(sAlgo, cls.cnt) + cls.cnt += 1 + nd.set("name", name) + setID = set(nd.get("content").split()) + if bool(setID): + try: + shape = ShapeLoader.node_to_Polygon(nd) + except IndexError: + if bComputeShape: + shape = cls.computeShape(ndPage, setID) + else: + shape = None + return cls(name, setID, shape) + else: + return None + + @classmethod + def store(cls, ndPage, lCluster, sAlgo): + """ + Store those "qlgo" clusters in the page node + """ + ndPage.append(etree.Comment("\nClusters created by cluster intersection\n")) + + for c in lCluster: + ndPage.append(c.makeClusterNode(sAlgo)) + + def makeClusterNode(self, sAlgo): + """ + Create an XML node reflecting the cluster + """ + ndCluster = PageXml.createPageXmlNode('Cluster') + ndCluster.set("name", self.name) + ndCluster.set("algo", sAlgo) + # add the space separated list of node ids + ndCluster.set("content", " ".join(self.setID)) + ndCoords = PageXml.createPageXmlNode('Coords') + ndCluster.append(ndCoords) + if self.shape is None: + ndCoords.set('points', "") + else: + ndCoords.set('points', ShapeLoader.getCoordsString(self.shape)) + ndCluster.tail = "\n" + return ndCluster + + @classmethod + def intersect(cls, one, other): + """ + return None or a cluster made by intersecting two cluster + the shape of the intersection if the intersection of 
shapes, or None if not applicable + """ + setID = one.setID.intersection(other.setID) + if bool(setID): + try: + shapeInter = one.shape.intersection(other.shape) + except ValueError: + shapeInter = None + return cls(sFMT % (one.name, other.name), setID, shapeInter) + else: + return None + + @classmethod + def computeShape(cls, ndPage, setID, bConvexHull=False): + """ + compute a shape for this cluster, as the minimum rotated rectangle of its content + or optionally as the convex hull + """ + # let's find the nodes and compute the shape + lNode = [ndPage.xpath(".//*[@id='%s']"%_id, namespaces=dNS)[0] for _id in setID] + return ShapeLoader.convex_hull(lNode, bShapelyObject=True) \ + if bConvexHull \ + else ShapeLoader.minimum_rotated_rectangle(lNode, bShapelyObject=True) + + +def main(sInputDir, sAlgoA, sAlgoB, bShape=False, bConvexHull=False, bVerbose=False): + sAlgoC = sFMT % (sAlgoA, sAlgoB) + + # filenames without the path + lsFilename = [os.path.basename(name) for name in os.listdir(sInputDir) if name.endswith("_du.pxml") or name.endswith("_du.mpxml")] + traceln(" - %d files to process, to produce clusters '%s'" % ( + len(lsFilename) + , sAlgoC)) + + for sFilename in lsFilename: + sFullFilename = os.path.join(sInputDir, sFilename) + traceln(" - FILE : ", sFullFilename) + cntCluster, cntPage = 0, 0 + doc = etree.parse(sFullFilename) + + for iPage, ndPage in enumerate(doc.getroot().xpath(xpPage, namespaces=dNS)): + nRemoved = Cluster.remove(ndPage, sAlgoC) + + lClusterA = Cluster.load(ndPage, sAlgoA) + lClusterB = Cluster.load(ndPage, sAlgoB) + + if bVerbose: + trace("Page %d : (%d clusters REMOVED), %d cluster '%s' %d clusters '%s'" %(iPage+1 + , nRemoved + , len(lClusterA), sAlgoA + , len(lClusterB), sAlgoB)) + + lClusterC = [] + for A in lClusterA: + for B in lClusterB: + C = Cluster.intersect(A, B) + if not C is None: + lClusterC.append(C) + + if bVerbose: traceln( " -> %d clusters" % (len(lClusterC))) + if bShape or bConvexHull: + for c in lClusterC: + c.shape = Cluster.computeShape(ndPage, c.setID, bConvexHull=bConvexHull) + + cntCluster += len(lClusterC) + cntPage += 1 + + Cluster.store(ndPage, lClusterC, sAlgoC) + + doc.write(sFullFilename, + xml_declaration=True, + encoding="utf-8", + pretty_print=True + #compression=0, #0 to 9 + ) + + del doc + traceln(" %d clusters over %d pages" % (cntCluster, cntPage)) + + traceln(" done (%d files)" % len(lsFilename)) + + + +# ---------------------------------------------------------------------------- +if __name__ == "__main__": + + version = "v.01" + sUsage=""" +Produce the intersection of two types of clusters, selected by their @algo attrbute. + +Usage: %s + +""" % (sys.argv[0]) + + parser = OptionParser(usage=sUsage) + parser.add_option("-v", "--verbose", dest='bVerbose', action="store_true" + , help="Verbose mode") + parser.add_option("-s", "--shape", dest='bShape', action="store_true" + , help="Compute the shape of the intersection content as minimum rotated rectangle, instead of intersection of shapes") + parser.add_option("--hull", dest='bConvexHull', action="store_true" + , help="Compute the shape of the intersection content as convex hull, instead of intersection of shapes") + (options, args) = parser.parse_args() + + try: + sInputDir, sA, sB = args + except ValueError: + sys.stderr.write(sUsage) + sys.exit(1) + + # ... 
checking folders + if not os.path.normpath(sInputDir).endswith("col") : sInputDir = os.path.join(sInputDir, "col") + + if not os.path.isdir(sInputDir): + sys.stderr.write("Not a directory: %s\n"%sInputDir) + sys.exit(2) + + # ok, go! + traceln("Input is : ", os.path.abspath(sInputDir)) + traceln("algo A is : ", sA) + traceln("algo B is : ", sB) + if options.bShape or options.bConvexHull: + traceln("Shape of intersections based on content!") + else: + traceln("Shape of intersections is the intersection of shapes!") + + main(sInputDir, sA, sB, options.bShape, options.bConvexHull, options.bVerbose) + + traceln("Input was : ", os.path.abspath(sInputDir)) + traceln("algo A was : ", sA) + traceln("algo B was : ", sB) + if options.bShape or options.bConvexHull: + trace("Shape of intersections based on content: ") + if options.bConvexHull: + traceln(" as a convex hull") + else: + traceln(" as a minimum rotated rectangle") + else: + traceln("Shape of intersections is the intersection of shapes!") + + traceln("Done.") \ No newline at end of file diff --git a/TranskribusDU/tasks/performCVLLA.py b/TranskribusDU/tasks/performCVLLA.py index 58763cc..79cbd96 100644 --- a/TranskribusDU/tasks/performCVLLA.py +++ b/TranskribusDU/tasks/performCVLLA.py @@ -10,18 +10,7 @@ copyright Xerox 2017 READ project - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. The READ project has received funding @@ -622,7 +611,7 @@ def regularTextLines(self,doc): #plg = Polygon(lXY) try: line=LineString(lXY) except ValueError: continue # LineStrings must have at least 2 coordinate tuples - topline=translate(line,yoff=-40).simplify(10) + topline=translate(line,yoff=-20) #iHeight = 20 # in pixel #x1,y1, x2,y2 = topline.getBoundingBox() if coord is not None: @@ -635,7 +624,7 @@ def regularTextLines(self,doc): coord.set('points',spoints) else: print (tl) -# print tl +# print tl def run(self,doc): """ diff --git a/TranskribusDU/tasks/project_GT_by_location.py b/TranskribusDU/tasks/project_GT_by_location.py new file mode 100644 index 0000000..335aa61 --- /dev/null +++ b/TranskribusDU/tasks/project_GT_by_location.py @@ -0,0 +1,475 @@ +# -*- coding: utf-8 -*- + +""" +Typically for use with ABP tables, to match the GT documents with their HTRed + counterpart. + +We have: +- an input collection obtained by downloading a Transkribus collection using + (PyClient) Transkribus_downloader.py +- a GT collection containing the definition of areas in each page. (Can be + table cells, or menu region, or whatever) + +We want +1 - to generate a new document, where the "elements of interest" (e.g TextLine) + of the input collection are matched against the GT areas by the location, + so that each element is either inserted in an area that matches or left + outside any area. 
+2 - (optionnally) to normalize the bounding area of the "element of interest" + This is done by making a box of predefined height from the Baseline + +Generate a new collection, with input documents enriched with GT areas. + +Any input document without GT counterpart is ignored. + +Created on 23 août 2019 + +Copyright NAVER LABS Europe 2019 +@author: JL Meunier +""" + +import sys, os +from optparse import OptionParser +from copy import deepcopy +from collections import defaultdict + +from lxml import etree +from numpy import argmax as argmax +from shapely.affinity import translate + +try: #to ease the use without proper Python installation + import TranskribusDU_version +except ImportError: + sys.path.append( os.path.dirname(os.path.dirname( os.path.abspath(sys.argv[0]) )) ) + import TranskribusDU_version +TranskribusDU_version + +from common.trace import traceln, trace +from util.Shape import ShapeLoader as ShapeLoader + +# ---------------------------------------------------------------------------- +iNORMALIZED_HEIGHT = 43 +xpELEMENT1 = ".//pg:TextRegion" +xpELEMENT2 = ".//pg:TextLine" + +xpAREA1 = ".//pg:TableRegion" +xpAREA2 = ".//pg:TableCell" + +xpBASELINE = ".//pg:Baseline" +dNS = {"pg":"http://schema.primaresearch.org/PAGE/gts/pagecontent/2013-07-15"} +# ---------------------------------------------------------------------------- + +def main(sInputDir, sGTDir, sOutputDir + , xpElement1, xpElement2 + , xpArea1, xpArea2 + , bNorm, iNorm, bNormOnly + , bSep + , lsRmId + , bEval + , bWarm + , bVerbose=False): + + lSkippedFile = [] + + # filenames without the path + lsFilename = [os.path.basename(name) for name in os.listdir(sInputDir) if name.endswith(".mpxml") and not name.endswith("_du.mpxml")] + traceln(" - %d .mpxml files to process" % len(lsFilename)) + for sMPXml in lsFilename: + trace(" - .mpxml FILE : ", sMPXml) + if bVerbose: traceln() + + # -- find individual subfiles + sSubDir = os.path.join(sInputDir, sMPXml[:-len(".mpxml")]) + if os.path.isdir(sSubDir): + traceln(" (-> ", sSubDir, ")") + lsPXml = [os.path.basename(name) for name in os.listdir(sSubDir) if name.endswith(".pxml")] + if bVerbose: traceln("\t%d files to process"%len(lsPXml)) + else: + sSubDir = sInputDir + lsPXml = [sMPXml] + if bVerbose: traceln("\tprocessing the .mpxml file") + + # -- find GT... + for sInputXml in lsPXml: + trace("\t", sMPXml, " -- ", sInputXml) + + sGTFN = os.path.join(sGTDir, sInputXml) + if not os.path.isfile(sGTFN): + # maybe it is also a folder downloaded from Transkribus? + if os.path.isfile(os.path.join(sGTDir, sMPXml[:-len(".mpxml")], sInputXml)): + sGTFN = os.path.join(sGTDir, sMPXml[:-len(".mpxml")], sInputXml) + else: + # hummm, maybe it is a mpxml instead... 
:-/ + sGTFN = sGTFN[:-len(".pxml")] + ".mpxml" + if not os.path.isfile(sGTFN): + traceln(" *** NO GT *** file skipped ") + lSkippedFile.append(sInputXml) + continue + # ok GT file found + trace(" ...") + + # input Xml + sInFN = os.path.join(sSubDir, sInputXml) + sOutFN = os.path.join(sOutputDir, sInputXml) + + if bWarm and os.path.exists(sOutFN): + # check existence and freshness + t_in = os.path.getmtime(sInFN) + t_gt = os.path.getmtime(sGTFN) + t_out = os.path.getmtime(sOutFN) + if t_out > t_in and t_out > t_gt: + traceln("\t\t fresh output file found on disk: %s - skipping it!"%sOutFN) + continue + + # 0 - load input file + doc = etree.parse(sInFN) + + # 1 - normalize input elements + if bNorm: + doc = normaliseDocElements(doc, xpElement2, iNorm) + + # 2 - project GT + if not bNormOnly: + gtdoc = etree.parse(sGTFN) + if True: + doc = project_Elt_to_GT(gtdoc, doc + , xpElement1, xpElement2 + , xpArea2, bSep, lsRmId, bEval) + else: + doc = project_Areas_to_Input(gtdoc, doc + , xpElement1, xpElement2, xpArea1, xpArea2 + , bSep, lsRmId, bEval) + + # 3 - save + doc.write(sOutFN, + xml_declaration=True, + encoding="utf-8", + pretty_print=True + #compression=0, #0 to 9 + ) + + # done + + del doc + traceln(" done") + + + traceln(" - %d .pxml files skipped" % len(lSkippedFile)) + + +# --------------------------------------------------------------------------- +# Normalizing the box of TextElement, by translating a copy of the Baseline +def normaliseDocElements(doc, xpElement, iNorm): + for ndPage in doc.getroot().xpath("//pg:Page", namespaces=dNS): + for ndElt in ndPage.xpath(xpElement, namespaces=dNS): + try: + normaliseElement(ndElt, iNorm) + except NormaliseException as e: + traceln(str(e)) + traceln("Removing this element") + ndElt.getparent().remove(ndElt) + + return doc + + +class NormaliseException(Exception): + pass + + +def normaliseElement(nd, iNorm): + try: + ndBaseline = nd.xpath(xpBASELINE, namespaces=dNS)[0] + except IndexError: + raise NormaliseException("WARNING: skipped element normalisation: no Baseline: %s" % etree.tostring(nd)) + + try: + line = ShapeLoader.node_to_LineString(ndBaseline) + except ValueError: + raise NormaliseException("WARNING: skipped element normalisation: invalid Coords: %s" % etree.tostring(nd)) + topline = translate(line, yoff=-iNorm) + + # serialise both in circular sequence + spoints = ' '.join("%s,%s"%(int(x[0]),int(x[1])) for x in line.coords) + lp=list(topline.coords) + lp.reverse() + spoints = spoints+ ' ' +' '.join("%s,%s"%(int(x[0]),int(x[1])) for x in lp) + + # ad-hoc way of setting the element coodinates + ndCoords = nd.xpath(".//pg:Coords", namespaces=dNS)[0] + ndCoords.set("points",spoints) + + return + +# --------------------------------------------------------------------------- +# projection of the GT area onto the doc + +class GTProjectionException(Exception): pass + +def project_Elt_to_GT(gtdoc, doc + , xpElement1, xpElement2 + , xpArea2 + , bSep, lsRmId, bEval + , fTH=0.5): + """ + Here we take the element out of the production file to put them in the GT + doc + + WE IGNORE xpArea1 (no need for it) + + We return the GT doc + """ + gtroot = gtdoc.getroot() + + # Evaluation + # we build a table of list of TextLineId from the GT to check this SW + # table_id -> row -> col -> list of element id + dTable = defaultdict(lambda : defaultdict(lambda : defaultdict(list))) + nOk, nTot = 0, 0 + + if lsRmId: + nbEltRemoved = 0 + for sRmId in lsRmId: + # for _nd in gtroot.xpath('//pg:*[@id="%s"]'%sRmId, namespaces=dNS): + for _nd in 
gtroot.xpath('//*[@id="%s"]'%sRmId): + _nd.getparent().remove(_nd) + nbEltRemoved += 1 + trace(" (Rm by ID: %d elements removed)" % nbEltRemoved) + + # remove all elements of interest from GT + # inside TableRegion, we have TextLine, outside we have TextRegion + for ndElt in gtroot.xpath(xpElement1, namespaces=dNS): + if bEval: + for ndElt2 in ndElt.xpath(xpElement2, namespaces=dNS): + dTable[None][None][None].append(ndElt2.get("id")) + ndElt.getparent().remove(ndElt) + for ndElt in gtroot.xpath(xpElement2, namespaces=dNS): + ndCell = ndElt.getparent() + if bEval: dTable[ndCell.getparent().get("id")][ndCell.get("row")][ndCell.get("col")].append(ndElt.get("id")) + ndCell.remove(ndElt) + if bEval: traceln("\npEvaluation mode") + + if bSep: + nbSepRemoved, nbSepAdded = 0, 0 + for _nd in gtroot.xpath('//pg:SeparatorRegion', namespaces=dNS): + _nd.getparent().remove(_nd) + nbSepRemoved += 1 + trace(" (Separators: %d removed" % nbSepRemoved) + + # project the GT areas, page by page + lNdPage = doc.getroot().xpath("//pg:Page", namespaces=dNS) + lNdPageGT = gtroot.xpath("//pg:Page", namespaces=dNS) + if len(lNdPage) != len(lNdPageGT): + raise GTProjectionException("GT and input have different numbers of pages") + + uniqID = 1 + for ndPage, ndPageGT in zip(lNdPage, lNdPageGT): + lNdArea2 = ndPageGT.xpath(xpArea2, namespaces=dNS) + loArea2 = [ShapeLoader.node_to_Polygon(nd) for nd in lNdArea2] + + for ndElt in ndPage.xpath(xpElement2, namespaces=dNS): + oElt = ShapeLoader.node_to_Polygon(ndElt) + + lOvrl = [oElt.intersection(o).area for o in loArea2] + iMax = argmax(lOvrl) + vMax = lOvrl[iMax] + + # where to add it? + if vMax > 0 and vMax / oElt.area > fTH: + # ok, this is a match + ndCell = lNdArea2[iMax] + # add it directly to the area2 (TableCell) + ndCell.append(deepcopy(ndElt)) + if bEval: + if ndElt.get("id") in dTable[ndCell.getparent().get("id")][ndCell.get("row")][ndCell.get("col")]: + nOk += 1 + else: + try: traceln('FAILED:in table: id="%s" "%s"' % (ndElt.get("id"), ndElt.xpath(".//pg:Unicode", namespaces=dNS)[0].text)) + except IndexError:traceln('FAILED:in table: id="%s" NOTEXT"' % (ndElt.get("id"))) + + else: + # add it outside of any area + bestNd = ndPageGT + # add it in its own TextRegion + ndTR = etree.Element("TextRegion") + ndTR.set("id", "prjct_region_%d" % uniqID) + uniqID += 1 + ndTR.set("custom", "") + ndTR.append(deepcopy(ndElt.xpath("./pg:Coords", namespaces=dNS)[0])) + ndTR.append(deepcopy(ndElt)) + bestNd.append(ndTR) + if bEval: + if ndElt.get("id") in dTable[None][None][None]: + nOk += 1 + else: + try: traceln('FAILED:in table: id="%s" "%s"' % (ndElt.get("id"), ndElt.xpath(".//pg:Unicode", namespaces=dNS)[0].text)) + except IndexError:traceln('FAILED:in table: id="%s" NOTEXT"' % (ndElt.get("id"))) + + nTot += 1 + + if bSep: + for _nd in ndPage.xpath('//pg:SeparatorRegion', namespaces=dNS): + ndPageGT.append(deepcopy(_nd)) + nbSepAdded += 1 + if bSep: trace(", %d added.) 
" % nbSepAdded) + + if bEval: + traceln("-"*40) + trace(" - evaluation: %d ok out of %d = %.2f%%\n" % (nOk, nTot, 100*nOk / (nTot+0.0001))) + + return gtdoc + + +def project_Areas_to_Input(gtdoc, doc, xpElement, xpArea1, xpArea2, bSep, lsRmId, bEval): + """ + Here we extract teh areas and put them in the input file + The element must be moved to the right areas + we return the doc + """ + raise GTProjectionException("Not implemented") + + +# ---------------------------------------------------------------------------- +if __name__ == "__main__": + + version = "v.01" + sUsage=""" +Typically for use with ABP tables, to match the GT documents with their HTRed + counterpart. +We want to extract the HTRed text and , optionally, the separators from a + Transkribus processed collection, and inject them in a GT collection, to + replace the GT text, (and possibly the GT separators). + +We have: +- an input collection obtained by downloading a Transkribus collection using + (PyClient) Transkribus_downloader.py +- a GT collection containing the definition of nested areas in each page. + (Can be table cells in a table region, or whatever) + The nesting has 2 levels for now. + +In term of nesting, we assume: + [not CURRENTLY - xpArea1 are under Page XML element (xpArea1 is IGNORED and USELESS) + - xpArea2 (TableCell) are nested under xpArea1 (TableRegion) + - xpElement1 are under Page XML element + - xpElement2 (TextLine) are either under xpElement1 (TextRegion) or under xpArea2 (TableCell) + - SeparatorRegion are under PAGE XML element + +We want +1 - to generate a new document, where the "elements of interest" (e.g TextLine) + of the input collection are matched against the GT areas by the location, + so that each element is either inserted in an area that matches or left + outside any area. +2 - (optionnally) to normalize the bounding area of the "element of interest" + This is done by making a box of predefined height from the Baseline, which + becomes the bottom side of the box. +3 - (optionnaly) to discard SeparatorRegion from the GT and get instead those + from Transkribus. + +This is done page by page, for each document. + +Generate a new collection, with input documents enriched with GT areas. + +Any input document without GT counterpart is ignored. + +Usage: %s + [--normalize (%d above the Baseline) + [--normalize_height = (this height above the Baseline) + [--normalize-only] + [--separator] replace GT SeparatorRegion by those from input. 
+ [--xpElement1 = <xpath>]  (defaults to "%s")
+ [--xpElement2 = <xpath>]  (defaults to "%s")
+ [--xpArea1 = <xpath>]     (defaults to "%s")  (CURRENTLY IGNORED and USELESS)
+ [--xpArea2 = <xpath>]     (defaults to "%s")
+ [--eval]
+
+""" % (sys.argv[0], iNORMALIZED_HEIGHT
+       , xpELEMENT1, xpELEMENT2
+       , xpAREA1, xpAREA2)
+
+    parser = OptionParser(usage=sUsage)
+    parser.add_option("--xpElement1", dest='xpElement1', action="store", type="string"
+                      , help="xpath of the elements lvl1"
+                      , default=xpELEMENT1)
+    parser.add_option("--xpElement2", dest='xpElement2', action="store", type="string"
+                      , help="xpath of the elements lvl2 to project"
+                      , default=xpELEMENT2)
+    parser.add_option("--xpArea1", dest='xpArea1', action="store", type="string"
+                      , help="xpath of the areas level 1 in GT"
+                      , default=xpAREA1)
+    parser.add_option("--xpArea2", dest='xpArea2', action="store", type="string"
+                      , help="xpath of the areas level 2 (nested) in GT"
+                      , default=xpAREA2)
+    parser.add_option("--normalize", dest='bNorm', action="store_true"
+                      , help="normalise the box of elements of interest")
+    parser.add_option("--separator", dest='bSep', action="store_true"
+                      , help="replace any separator by those from the Transkribus collection")
+    parser.add_option("--normalize_height", dest='iNormHeight', action="store", type="int"
+                      , help="height of the normalised box, above the Baseline")
+    parser.add_option("--normalize-only", dest='bNormOnly', action="store_true"
+                      , help="only normalise, do not project the GT")
+    parser.add_option("--rm_by_id", dest='lsRmId', action="append"
+                      , help="Remove those elements from the output XML")
+    parser.add_option("--eval", dest='bEval', action="store_true"
+                      , help="evaluation mode, pass GT as input!!")
+    parser.add_option("--warm", dest='bWarm', action="store_true"
+                      , help="Warm mode: skip input files that have a fresh output already there")
+    parser.add_option("-v", "--verbose", dest='bVerbose', action="store_true"
+                      , help="Verbose mode")
+    (options, args) = parser.parse_args()
+
+    try:
+        sInputDir, sGTDir, sOutputDir = args
+    except ValueError:
+        sys.stderr.write(sUsage)
+        sys.exit(1)
+
+    # ... normalization
+    bNorm = bool(options.bNorm) or bool(options.iNormHeight) or bool(options.bNormOnly)
+    iNorm = options.iNormHeight if bool(options.iNormHeight) else iNORMALIZED_HEIGHT
+
+    # ... checking folders
+    if not os.path.normpath(sInputDir).endswith("col") : sInputDir = os.path.join(sInputDir, "col")
+    if not os.path.normpath(sGTDir).endswith("col")    : sGTDir = os.path.join(sGTDir, "col")
+    if os.path.isdir(sInputDir) and os.path.isdir(sGTDir):
+        # create the output folders if required
+        if os.path.normpath(sOutputDir).endswith("col") :
+            pass # we expect the user knows what s/he does
+        else:
+            # try to create them
+            try: os.mkdir(sOutputDir);
+            except: pass
+            sOutputDir = os.path.join(sOutputDir, "col")
+            try: os.mkdir(sOutputDir);
+            except: pass
+    # all must be ok by now
+    lsDir = [sInputDir, sGTDir, sOutputDir]
+    if not all(os.path.isdir(s) for s in lsDir):
+        for s in lsDir:
+            if not os.path.isdir(s): sys.stderr.write("Not a directory: %s\n"%s)
+        sys.exit(2)
+
+    # ok, go!
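+    # Illustrative invocation (a sketch: the script and folder names are
+    # placeholders, the options are those defined above):
+    #     python <this_script>.py --normalize --separator  trnskrbs_col  gt_col  out_col
+    # i.e. each element of interest gets a box of iNORMALIZED_HEIGHT (43) pixels
+    # above its Baseline, the elements are projected into the GT areas, and the
+    # GT SeparatorRegions are replaced by those of the input collection.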
+ traceln("Input is : ", os.path.abspath(sInputDir)) + traceln("GT is in : ", os.path.abspath(sGTDir)) + traceln("Ouput in : ", os.path.abspath(sOutputDir)) + traceln("Elements lvl 1: ", repr(options.xpElement1)) + traceln("Elements lvl 2: ", repr(options.xpElement2)) + traceln("GT areas lvl 1 : " , repr(options.xpArea1)) + traceln("GT areas lvl 2 (nested) : " , repr(options.xpArea2)) + traceln("Normalise elements : ", bNorm) + traceln("Normalise to height : ", iNorm) + traceln("Get separators : ", options.bSep) + traceln("Remove elements with @id: ", options.lsRmId) + + if os.listdir(sOutputDir): traceln("WARNING: *** output folder NOT EMPTY ***") + + main(sInputDir, sGTDir, sOutputDir + , options.xpElement1, options.xpElement2 + , options.xpArea1, options.xpArea2 + , bNorm, iNorm, options.bNormOnly + , options.bSep + , options.lsRmId + , options.bEval + , options.bWarm + , options.bVerbose) + + traceln("Done.") \ No newline at end of file diff --git a/TranskribusDU/tasks/tabulate_cell_cluster.py b/TranskribusDU/tasks/tabulate_cell_cluster.py new file mode 100644 index 0000000..219ead8 --- /dev/null +++ b/TranskribusDU/tasks/tabulate_cell_cluster.py @@ -0,0 +1,644 @@ +# -*- coding: utf-8 -*- + +""" +We expect XML file with cluster defined by one algo. + +For each Page: + We tabulate the clusters (build a table where each cluster is a cell) + We compute the row, col, row_span, col_span attributes of each cluster + +Overwrite the input XML files, adding attributes to the cluster definitions + +If the cluster do not have a defined shape, we compute a shape based on a minimum_rotated_rectangle + +Created on 26/9/2019 + +Copyright NAVER LABS Europe 2019 + +@author: JL Meunier +""" + +import sys, os +from optparse import OptionParser +from collections import defaultdict +from lxml import etree + +import numpy as np +import shapely.ops +from shapely import affinity + +try: #to ease the use without proper Python installation + import TranskribusDU_version +except ImportError: + sys.path.append( os.path.dirname(os.path.dirname( os.path.abspath(sys.argv[0]) )) ) + import TranskribusDU_version +TranskribusDU_version + +from common.trace import traceln, trace +from xml_formats.PageXml import PageXml + +from tasks.intersect_cluster import Cluster +from graph.Block import Block +from util.Shape import ShapeLoader + +# ---------------------------------------------------------------------------- +xpCluster = ".//pg:Cluster" +xpClusterEdge = ".//pg:ClusterEdge" +xpEdge = ".//pg:Edge" +# sFMT = "(%s_∩_%s)" pb with visu +sAlgoAttr = "algo" +xpPage = ".//pg:Page" +dNS = {"pg":"http://schema.primaresearch.org/PAGE/gts/pagecontent/2013-07-15"} +# ---------------------------------------------------------------------------- + + +class TableCluster(Cluster, Block): + thTopAligned = 20 # a difference less than 20 pixel on y1 means top-aliogned + # scale BB by these ratio (horizontally and vertically) + scale_H = 0.66 # better if same as in DU_Table_Col_Cut + # scale_H = 1.0 # to get hard cases + scale_V = 1 # do not shrink + + cnt = 0 + + def __init__(self, name, setID, shape=None): + Cluster.__init__(self, name, setID, shape=shape) + # we do not __init__ Block - useless, we just need a few methods + self.dsEdge = defaultdict(set) # dic edge_type -> neighbours set + self.cnt = TableCluster.cnt + TableCluster.cnt += 1 + + @classmethod + def induceClusterEdge(cls, ndPage, lCluster): + """ + compute inter- cluster edges from inter- cluster-item edges + + no so good for horizontal edges... 
:-/ + """ + # revert dictionay itemID -Cluster + dCluster_by_Item = { x:c for c in lCluster for x in c.getSetID() } + for _nd in ndPage.xpath(xpEdge, namespaces=dNS): + _A, _B = _nd.get("src"), _nd.get("tgt") + _AC, _BC = dCluster_by_Item[_A], dCluster_by_Item[_B] + if _AC != _BC: + TableCluster.link(_AC, _BC, edge_type=_nd.get("type")) + del dCluster_by_Item + + @classmethod + def computeClusterEdge(cls, _ndPage, lCluster): + """ + compute edge using g2 method from class Block :-) + A bit computationally heavy, but safe code... + """ + lHEdge, lVEdge = Block.findPageNeighborEdges(lCluster, bShortOnly=False, iGraphMode=2) + for edge in lHEdge: + TableCluster.link(edge.A, edge.B, "HorizontalEdge") + for edge in lVEdge: + TableCluster.link(edge.A, edge.B, "VerticalEdge") + + @classmethod + def addEdgesToXml(cls, ndPage, sAlgo, lCluster): + cnt = 0 + ndPage.append(etree.Comment("\nInter-cluster edges by tabulate_cluster scale_H=%.2f sclae_V=%.2f\n" %( + cls.scale_H, cls.scale_V))) + + setEdges = set() + + for A in lCluster: + for edge_type, lLinked in A.dsEdge.items(): + for B in lLinked: + if A.cnt >= B.cnt: continue + if (A, B, edge_type) not in setEdges: + # ok, let's add the edge A <--> B + ndEdge = PageXml.createPageXmlNode("ClusterEdge") + ndEdge.set("src", A.name) + ndEdge.set("tgt", B.name) + ndEdge.set("type", edge_type) + ndEdge.set("algo", sAlgo) + if True: + ptA = A.shape.representative_point() + ptB = B.shape.representative_point() + + else: + ptA, ptB = shapely.ops.nearest_points(A.shape, B.shape) + PageXml.setPoints(ndEdge, list(ptA.coords) + list(ptB.coords)) + ndEdge.tail = "\n" + ndPage.append(ndEdge) + + setEdges.add((A, B, edge_type)) + cnt += 1 + del setEdges + + return cnt + + @classmethod + def removeEdgesFromXml(cls, ndPage): + """ + Given an algo, remove all its clusters from a page + """ + i = 0 + for nd in ndPage.xpath(xpClusterEdge, namespaces=dNS): + ndPage.remove(nd) + i += 1 + return i + + @classmethod + def link(cls, A, B, edge_type=""): + """ + record an edge between those 2 clusters + """ + assert A != B + A.dsEdge[edge_type].add(B) + B.dsEdge[edge_type].add(A) + + @classmethod + def computeClusterBoundingBox(cls, lCluster): + for c in lCluster: + c.setBB(c.shape.bounds) + assert c.x1 < c.x2 + assert c.y1 < c.y2 + if cls.scale_H != 1 or cls.scale_V != 1: + c.scaled_shape = affinity.scale(c.shape, xfact=cls.scale_H, yfact=cls.scale_V) + else: + c.scaled_shape = c.shape + + @classmethod + def setTableAttribute(self, ndPage, setID, sAttr1, s1, sAttr2=None, s2=None): + """ + set attributes such as "col" and "colSPan" of a set of objects given by their ID + """ + lNode = [ndPage.xpath(".//*[@id='%s']"%_id, namespaces=dNS)[0] for _id in setID] + for nd in lNode: + nd.set(sAttr1, str(s1)) + if bool(sAttr2): + nd.set(sAttr2, str(s2)) + + @classmethod + def tabulate(cls, ndPage, lCluster, bVerbose=False): + """ + Top-down tabulation in the 4 directions + """ + + cls.tabulate_top_down(lCluster) + for c in lCluster: + c.row1 = c.minrow + c.node.set("row", str(c.row1)) + maxRow = max(c.row1 for c in lCluster) + #c.node.set("col", str(c.mincol)) + #c.node.set("rowSpan", str(c.maxrow - c.minrow + 1)) + #c.node.set("colSpan", str(c.maxcol - c.mincol + 1)) + + cls.rotateClockWise90deg(lCluster, bVerbose=bVerbose) + cls.tabulate_top_down(lCluster) + for c in lCluster: + c.col1 = c.minrow + c.node.set("col", str(c.col1)) + maxCol = max(c.col1 for c in lCluster) + + cls.rotateClockWise90deg(lCluster, bVerbose=bVerbose) + cls.tabulate_top_down(lCluster) + for c in lCluster: + 
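+            # the two 90° rotations above have turned the page upside down, so this
+            # top-down pass counts rows from the bottom: maxRow - c.minrow is the
+            # cluster's last row in the original orientation (row2, hence rowSpan)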
c.row2 = maxRow - c.minrow + rowSpan = str(1 + c.row2 - c.row1) + c.node.set("rowSpan", rowSpan) + cls.setTableAttribute(ndPage, c.getSetID(), "row", c.row1, "rowSpan", rowSpan) + + cls.rotateClockWise90deg(lCluster, bVerbose=bVerbose) + cls.tabulate_top_down(lCluster) + for c in lCluster: + c.col2 = maxCol - c.minrow + colSpan = str(1 + c.col2 - c.col1) + c.node.set("colSpan", colSpan) + cls.setTableAttribute(ndPage, c.getSetID(), "col", c.col1, "colSpan", colSpan) + + @classmethod + def tabulate_rows(cls, ndPage, lCluster, bVerbose=False): + """ + Top-down and bottom-up tabulations + """ + + cls.tabulate_top_down(lCluster) + + maxRow = max(c.minrow for c in lCluster) + traceln(" maxRow=", maxRow) + +# if False: +# for c in lCluster: +# c.row1 = c.minrow +# c.node.set("row", str(c.row1)) +# cls.rotateClockWise180deg(lCluster, bVerbose=bVerbose) +# cls.tabulate_top_down(lCluster) +# for c in lCluster: +# c.row2 = max(maxRow - c.minrow, c.row1) +# rowSpan = str(1 + c.row2 - c.row1) +# c.node.set("rowSpan", rowSpan) +# cls.setTableAttribute(ndPage, c.getSetID(), "row", c.row1, "rowSpan", rowSpan) +# elif False: +# for c in lCluster: +# c.node.set("row", str(c.minrow)) +# rowSpan = str(9) +# c.node.set("rowSpan", rowSpan) +# cls.setTableAttribute(ndPage, c.getSetID(), "row", c.minrow, "rowSpan", rowSpan) + # tabulate top-down, then compute the separators and use them for + # deciding the row and rowSpan + # we get a list of linear separators, to be reflected as SeparatorRegion + cls.map_to_rows(ndPage, maxRow, lCluster) + + for c in lCluster: + c.node.set("row", str(c.row1)) + rowSpan = str(1 + c.row2 - c.row1) + c.node.set("rowSpan", rowSpan) + cls.setTableAttribute(ndPage, c.getSetID(), "row", c.row1, "rowSpan", rowSpan) + + + @classmethod + def use_cut_columns(cls, ndPage): + """ + use the name of the cut cluster to compute the col + colSPan is always 1 in that case + """ + # + for ndCluster in ndPage.xpath(xpCluster+"[@algo='cut']", namespaces=dNS): + col = str(int(ndCluster.get("name")) - 1) + setID = set(ndCluster.get("content").split()) + ndCluster.set("col", col) + ndCluster.set("colSpan", "1") + cls.setTableAttribute(ndPage, setID, "col", col, "colSpan", "1") + + @classmethod + def tabulate_top_down(cls, lCluster): + """ + compute minrow and maxrow values + """ + for c in lCluster: + assert c.x1 <= c.x2 + assert c.y1 <= c.y2 + + step = 1 + step_max = len(lCluster) + 1 + + for c in lCluster: c.minrow = -1 + + lTodoCluster = lCluster + prevSetUpdated = None + bNoLoop = True + while lTodoCluster and bNoLoop: + setUpdated = set() + traceln(" - STEP %d"%step) + # since we keep increasing the minrow, its maximum value cannot + # exceed len(lCluster), which is reached with at most step_max steps + assert step <= step_max, "algorithm error" + + # visit all vertically from top cluster + lTodoCluster.sort(key=lambda o: o.y1) + # faster?? 
lCurCluster.sort(key=operator.attrgetter("x1")) +# print([c.name for c in lTodoCluster]) +# for i in [0, 1]: +# print(lCluster[i].name, " y1=", lCluster[i].y1, " y2=", lCluster[i].y2) + for c in lTodoCluster: + setUpdated.update(c.visitStackDown(0)) + # visit all, horizontally from leftest clusters + lTodoCluster.sort(key=lambda o: o.x1) + for c in lTodoCluster: + setUpdated.update(c.visitPeerRight()) + + lTodoCluster.sort(key=lambda o: o.x2, reverse=True) + for c in lTodoCluster: + setUpdated.update(c.visitPeerLeft()) + + if not prevSetUpdated is None and prevSetUpdated == setUpdated: + traceln(" - loop detected - stopping now.") + bNoLoop = False + prevSetUpdated = setUpdated + lTodoCluster = list(setUpdated) + traceln(" ... %d updated" % len(lTodoCluster)) + step += 1 + + if not bNoLoop: + # need to fix the problem... + # because of the loop, we have holes in the list of row numbers + lMinrow = list(set(c.minrow for c in lCluster)) + lMinrow.sort() + curRow = 0 + for iMinrow in range(len(lMinrow)): + minrow = lMinrow[iMinrow] + if minrow > curRow: + # missing row number... + delta = minrow - curRow + for c in lCluster: + if c.minrow >= curRow: + c.minrow -= delta + for j in range(iMinrow, len(lMinrow)): + lMinrow[j] = lMinrow[j] - delta + curRow += 1 + + def visitStackDown(self, minrow, setVisited=set()): + """ + e.g., stacking from top to bottom, we get a visit from upward, so we update our minrow accordingly + return the set of updated items + """ + #if self.name == "(6_I_agglo_345866)" and minrow > 17: print(self.name, minrow) + setUpdated = set() + + if minrow > self.minrow: + # the stack above us tells us about our minrow! + self.minrow = minrow + setUpdated.add(self) + + for c in self.dsEdge["VerticalEdge"]: + # make sure we go downward + # if c.y1 > self.y1: + # and that the edge is a valid one + # which implies the 1st condition! + if self.y2 < c.y1: + if self.minrow >= c.minrow: + # otherwise no need... + setUpdated.update(c.visitStackDown(self.minrow + 1, setVisited)) + elif self.y1 < c.y1: + # c starts within self... + # maybe there is skewing? + if self.scaled_shape.intersects(c.scaled_shape): + # since we do not increase minrow, we need to make sure + # we do not infinite loop... 
+ # (increasing minrow forces us to move downward the page and to end at some point) + if self.minrow > c.minrow or not self in setVisited: + setVisited.add(self) + setUpdated.update(c.visitStackDown(self.minrow, setVisited)) + else: + # I believe one is mostly above the other + if self.minrow >= c.minrow: + setUpdated.update(c.visitStackDown(self.minrow + 1, setVisited)) + + return setUpdated + + def visitPeerRight(self): + """ + go from left to right, making sure the minrow is consistent with the geometric relationships + """ + setUpdated = set() + a = self + for b in self.dsEdge["HorizontalEdge"]: + # make sure we go in good direction: rightward + if a.x2 <= b.x1: + minrow = max(a.minrow, b.minrow) + bAB = TableCluster.isTopAligned(a, b) # top justified + bA = bAB or a.y1 > b.y1 # a below b + bB = bAB or a.y1 < b.y1 # a above b + + if bA and minrow > a.minrow: + a.minrow = minrow + setUpdated.add(a) + + if bB and minrow > b.minrow: + b.minrow = minrow + setUpdated.add(b) + setUpdated.update(b.visitPeerRight()) + return setUpdated + + def visitPeerLeft(self): + """ + go from left to right, making sure the minrow is consistent with the geometric relationships + """ + setUpdated = set() + a = self + for b in self.dsEdge["HorizontalEdge"]: + # make sure we go in good direction: leftward + if b.x2 <= a.x1: + minrow = max(a.minrow, b.minrow) + bAB = TableCluster.isTopAligned(a, b) # top justified + bA = bAB or a.y1 > b.y1 # a below b + bB = bAB or a.y1 < b.y1 # a above b + + if bA and minrow > a.minrow: + a.minrow = minrow + setUpdated.add(a) + + if bB and minrow > b.minrow: + b.minrow = minrow + setUpdated.add(b) + setUpdated.update(b.visitPeerRight()) + + return setUpdated + + @classmethod + def isTopAligned(cls, a, b): + return abs(a.y1 - b.y1) < cls.thTopAligned + + + @classmethod + def rotateClockWise90deg(cls, lCluster, bVerbose=True): + if bVerbose: traceln(" -- rotation 90° clockwise") + for c in lCluster: + c.x1, c.y1, c.x2, c.y2 = -c.y2, c.x1, -c.y1, c.x2 + c.dsEdge["HorizontalEdge"], c.dsEdge["VerticalEdge"] = c.dsEdge["VerticalEdge"], c.dsEdge["HorizontalEdge"] + return + + @classmethod + def rotateClockWise180deg(cls, lCluster, bVerbose=True): + if bVerbose: traceln(" -- rotation 180° clockwise") + for c in lCluster: + c.x1, c.y1, c.x2, c.y2 = -c.x2, -c.y2, -c.x1, -c.y1 + return + + @classmethod + def map_to_rows(cls, ndPage, maxRow, lCluster): + """ + find lienar separators separating rows + """ + # reflect each cluster by the highest point (highest ending points of baselines) + dMinYByRow = defaultdict(lambda :9999999999) + n = 2 * sum(len(c) for c in lCluster) + X = np.zeros(shape=(n, 2)) # x,y coordinates + i = 0 + for c in lCluster: + c.maxY = -1 + c.minY = 9999999999 + for _id in c.getSetID(): + """ + + + + ung. + + """ + nd = ndPage.xpath(".//*[@id='%s']/pg:Baseline"%_id, namespaces=dNS)[0] + ls = ShapeLoader.node_to_LineString(nd) + pA, pB = ls.boundary.geoms + minY = min(pA.y, pB.y) + c.minY = min(c.minY, minY) + c.maxY = max(c.maxY, max((pA.y, pB.y))) + dMinYByRow[c.minrow] = min(dMinYByRow[c.minrow], minY) + # for the linear separators + X[i,:] = (pA.x, pA.y) + i = i + 1 + X[i,:] = (pB.x, pB.y) + i = i + 1 + + # check consistency + for c in lCluster: + for i in range(maxRow, c.minrow, -1): + if c.minY > dMinYByRow[i]: + assert c.minrow < i + # how possible??? fix!! 
+ c.minrow = i + break + + # compute row1 and row2 + for c in lCluster: + c.row1 = c.minrow + c.row2 = c.minrow + for i in range(0, maxRow+1): + if c.maxY > dMinYByRow[i]: + c.row2 = i + else: + break + + # now compute maxRow - 1 separators! + w = float(ndPage.get("imageWidth")) + Y = np.zeros(shape=(n,)) # labels +# lAB = [getLinearSeparator(X, np.clip(Y, row, row+1)) +# for row in range(maxRow-1)] + + for nd in ndPage.xpath(".//pg:SeparatorRegion[@algo]", namespaces=dNS): + ndPage.remove(nd) + + for row in range(maxRow+1): + Y0 = dMinYByRow[row] - 20 + Yw = Y0 + ndSep = PageXml.createPageXmlNode("SeparatorRegion") + ndSep.set("algo", "tabulate_rows") + ndCoords = PageXml.createPageXmlNode("Coords") + ndCoords.set("points", "%d,%d %d,%d" %(0, Y0, w, Yw)) + ndSep.append(ndCoords) + ndSep.tail = "\n" + ndPage.append(ndSep) + + return + + +def main(sInputDir, sAlgo, bCol=False, scale_H=None, scale_V=None, bVerbose=False): + + if not scale_H is None: TableCluster.scale_H = scale_H + if not scale_V is None: TableCluster.scale_V = scale_V + + traceln("scale_H=", TableCluster.scale_H) + traceln("scale_V=", TableCluster.scale_V) + + # filenames without the path + lsFilename = [os.path.basename(name) for name in os.listdir(sInputDir) if name.endswith("_du.pxml") or name.endswith("_du.mpxml")] + traceln(" - %d files to process, to tabulate clusters '%s'" % ( + len(lsFilename) + , sAlgo)) + lsFilename.sort() + for sFilename in lsFilename: + sFullFilename = os.path.join(sInputDir, sFilename) + traceln(" -------- FILE : ", sFullFilename) + cnt = 0 + doc = etree.parse(sFullFilename) + + for iPage, ndPage in enumerate(doc.getroot().xpath(xpPage, namespaces=dNS)): + lCluster = TableCluster.load(ndPage, sAlgo, bNode=True) # True to keep a pointer to the DOM node + + if bVerbose: + trace(" --- Page %d : %d cluster '%s' " %(iPage+1, len(lCluster), sAlgo)) + if len(lCluster) == 0: + traceln("*** NO cluster '%s' *** we keep this page unchanged"%sAlgo) + continue + _nbRm = TableCluster.removeEdgesFromXml(ndPage) + if bVerbose: + traceln("\n %d ClusterEdge removed"%_nbRm) + + TableCluster.computeClusterBoundingBox(lCluster) + + if True: + # edges are better this way! 
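+                # clusters are temporarily shrunk (by scale_H, scale_V) so that boxes
+                # which barely touch do not produce spurious neighbour edges; the
+                # original bounding boxes are restored right after edge computation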
+                lBB = []
+                for c in lCluster:
+                    lBB.append(c.getBB())
+                    c.scale(TableCluster.scale_H, TableCluster.scale_V)
+                TableCluster.computeClusterEdge(ndPage, lCluster)
+                for c, bb in zip(lCluster, lBB):
+                    c.setBB(bb)
+                # for c in lCluster: c.scale(1.0/TableCluster.scale_H, 1.0/TableCluster.scale_V)
+            else:
+                # compute inter-cluster edges from inter- cluster-item edges
+                TableCluster.induceClusterEdge(ndPage, lCluster)
+
+            # store inter-cluster edges
+            cntPage = TableCluster.addEdgesToXml(ndPage, sAlgo, lCluster)
+            if bVerbose:
+                traceln(" %d inter-cluster edges " %(cntPage))
+
+            # compute min/max row/col for each cluster
+            # WARNING - side effect on lCluster content and edges
+            if bCol:
+                TableCluster.tabulate(ndPage, lCluster, bVerbose=bVerbose)
+            else:
+                TableCluster.tabulate_rows(ndPage, lCluster, bVerbose=bVerbose)
+                TableCluster.use_cut_columns(ndPage)
+
+            cnt += cntPage
+        traceln("%d inter-cluster edges" %(cnt))
+
+
+        doc.write(sFullFilename,
+                  xml_declaration=True,
+                  encoding="utf-8",
+                  pretty_print=True
+                  #compression=0, #0 to 9
+                  )
+
+        del doc
+
+    traceln(" done (%d files)" % len(lsFilename))
+
+
+
+# ----------------------------------------------------------------------------
+if __name__ == "__main__":
+
+    version = "v.01"
+    sUsage="""
+Tabulate the clusters of the given @algo and compute the row, col, row_span, col_span attributes of each cluster
+
+Usage: %s
+
+""" % (sys.argv[0])
+
+    parser = OptionParser(usage=sUsage)
+    parser.add_option("--scale_h", dest='fScaleH', action="store", type="float"
+                      , help="objects are horizontally scaled by this factor")
+    parser.add_option("--scale_v", dest='fScaleV', action="store", type="float"
+                      , help="objects are vertically scaled by this factor")
+    parser.add_option("--col", dest='bCol', action="store_true"
+                      , help="Columns also tabulated instead of derived from 'cut' clusters")
+    parser.add_option("-v", "--verbose", dest='bVerbose', action="store_true"
+                      , help="Verbose mode")
+    (options, args) = parser.parse_args()
+
+    try:
+        sInputDir, sA = args
+    except ValueError:
+        sys.stderr.write(sUsage)
+        sys.exit(1)
+
+    # ... checking folders
+    if not os.path.normpath(sInputDir).endswith("col") : sInputDir = os.path.join(sInputDir, "col")
+
+    if not os.path.isdir(sInputDir):
+        sys.stderr.write("Not a directory: %s\n"%sInputDir)
+        sys.exit(2)
+
+    # ok, go!
+    traceln("Input is : ", os.path.abspath(sInputDir))
+    traceln("algo is : ", sA)
+    if options.bCol:
+        traceln("columns also tabulated")
+    else:
+        traceln("columns are those of the projection profile")
+
+    main(sInputDir, sA, bCol=options.bCol
+         , scale_H=options.fScaleH, scale_V=options.fScaleV
+         , bVerbose=options.bVerbose)
+
+    traceln("Done.")
\ No newline at end of file
diff --git a/TranskribusDU/tasks/tabulate_final.py b/TranskribusDU/tasks/tabulate_final.py
new file mode 100644
index 0000000..c5aef3a
--- /dev/null
+++ b/TranskribusDU/tasks/tabulate_final.py
@@ -0,0 +1,248 @@
+# -*- coding: utf-8 -*-
+
+"""
+We expect XML files with TextLine elements having the row, col, rowSpan, colSpan attributes
+
+For each Page:
+    We delete any empty table (or complain if not empty)
+    We select TextLine with rowSpan=1 and colSpan=1
+    We create one cell for each pair of row and col number
+    We inject the TextLine into its cell
+    We create a TableRegion to contain the cells
+    We delete empty regions
+    We resize non-empty regions
+
+We compute the cell and table geometries and store them.
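+
+For illustration, the core of the per-page processing (a sketch using the names
+defined below; not executable as-is):
+
+    table = TableRegion(pagenum, tablenum)    # cells: (row, col) -> TextLine nodes
+    table.addToCell(row, col, ndTextLine)     # one call per non-spanning TextLine
+    ndPage.append(table.makeTableNode())      # one TableCell per non-empty (row, col)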
+ +Created on 21/10/2019 + +Copyright NAVER LABS Europe 2019 + +@author: JL Meunier +""" + +import sys, os +from optparse import OptionParser +from collections import defaultdict +from lxml import etree + +from shapely.ops import cascaded_union + +try: #to ease the use without proper Python installation + import TranskribusDU_version +except ImportError: + sys.path.append( os.path.dirname(os.path.dirname( os.path.abspath(sys.argv[0]) )) ) + import TranskribusDU_version +TranskribusDU_version + +from common.trace import traceln +from xml_formats.PageXml import PageXml + +from util.Shape import ShapeLoader + +# ---------------------------------------------------------------------------- +xpPage = ".//pg:Page" +dNS = {"pg":"http://schema.primaresearch.org/PAGE/gts/pagecontent/2013-07-15"} +# ---------------------------------------------------------------------------- + + +def processRegions(ndPage,bVerbose=False): + """ + Delete empty regions + resize no empty regions + """ + lDel=[] + lndRegions = ndPage.xpath(".//pg:TextRegion", namespaces=dNS) + for ndRegion in lndRegions: + lTL= ndRegion.xpath(".//pg:TextLine", namespaces=dNS) + if lTL == []: + # to be deleted + lDel.append(ndRegion) + else: + #resize it + oHull = ShapeLoader.convex_hull(lTL, bShapelyObject=True) + PageXml.getChildByName(ndRegion,'Coords')[0].set("points", ShapeLoader.getCoordsString(oHull, bFailSafe=True)) +# contour = cascaded_union([p if p.is_valid else p.convex_hull for p in lTL ]) +# o = contour.minimum_rotated_rectangle +# ndRegion.getChildByName('Coords').set("points", ShapeLoader.getCoordsString(o, bFailSafe=True)) + + # delete empty regions + [ ndRegion.getparent().remove(ndRegion) for ndRegion in lDel] + + if bVerbose: + traceln(" - %d regions deleted"%(len(lDel))) + traceln(" - %d regions updated"%(len(lndRegions) - len(lDel))) + +class TableRegion: + + def __init__(self, pagenum, tablenum): + self.pagenum = pagenum + self.tablenum = tablenum + # (row, col) -> list of nodes + self._dCellNd = defaultdict(list) + + def addToCell(self, row, col, nd): + self._dCellNd[(row, col)].append(nd) + + def makeTableNode(self): + """ + Make a DOM tree for this table + """ + lK = self._dCellNd.keys() + lRow = list(set(_row for _row, _col in lK)) + lRow.sort() + lCol = list(set(_col for _row, _col in lK)) + lCol.sort() + + ndTable = PageXml.createPageXmlNode("TableRegion") + ndTable.set("id", "p%s_%s" % (self.pagenum, self.tablenum)) + ndTable.tail = "\n" + lCellShape = [] + lNdCell = [] + for row in lRow: + for col in lCol: + lNdText = self._dCellNd[(row, col)] + # + # + + if lNdText: + ndCell = PageXml.createPageXmlNode("TableCell") + ndCell.set("id", "p%s_t%s_r%s_c%s"%(self.pagenum, self.tablenum, row, col)) + + # shape of the cell + oHull = ShapeLoader.convex_hull(lNdText, bShapelyObject=True) + lCellShape.append(oHull) # keep those to compute table contour + + # Coords sub-element + ndCoords = PageXml.createPageXmlNode("Coords") + ndCoords.set("points", ShapeLoader.getCoordsString(oHull, bFailSafe=True)) + ndCoords.tail = "\n" + ndCell.append(ndCoords) + + # row="0" col="0" rowSpan="1" colSpan="1" leftBorderVisible="false" rightBorderVisible="false" topBorderVisible="false" bottomBorderVisible="false" + ndCell.set("row" , str(row)) + ndCell.set("rowSpan", "1") + ndCell.set("col" , str(col)) + ndCell.set("colSpan", "1") + ndCell.tail = "\n" + + for nd in lNdText: ndCell.append(nd) + + lNdCell.append(ndCell) + + # Table geometry + ndCoords = PageXml.createPageXmlNode("Coords") + contour = cascaded_union([p if p.is_valid 
else p.convex_hull for p in lCellShape ])
+        o = contour.minimum_rotated_rectangle
+        ndCoords.set("points", ShapeLoader.getCoordsString(o, bFailSafe=True))
+        ndCoords.tail = "\n"
+        ndTable.append(ndCoords)
+
+        for nd in lNdCell:
+            ndTable.append(nd)
+
+        return ndTable
+
+
+def main(sInputDir, bForce=False, bVerbose=False):
+
+    # filenames without the path
+    lsFilename = [os.path.basename(name) for name in os.listdir(sInputDir) if name.endswith("_du.pxml") or name.endswith("_du.mpxml")]
+    traceln(" - %d files to process, to tabulate clusters" % (
+        len(lsFilename)))
+    lsFilename.sort()
+    for sFilename in lsFilename:
+        sFullFilename = os.path.join(sInputDir, sFilename)
+        traceln(" -------- FILE : ", sFullFilename)
+        cnt = 0
+        doc = etree.parse(sFullFilename)
+
+        for iPage, ndPage in enumerate(doc.getroot().xpath(xpPage, namespaces=dNS)):
+
+            # find and delete any pre-existing table
+            # if bForce, then move any TextLine under Page before table deletion
+            lNdTable = ndPage.xpath(".//pg:TableRegion", namespaces=dNS)
+            if bVerbose:
+                if bForce:
+                    traceln(" - %d pre-existing table(s) to be deleted, preserving their contents by moving them under the Page node" % len(lNdTable))
+                else:
+                    traceln(" - %d pre-existing table(s) to be deleted IF EMPTY" % len(lNdTable))
+            for ndTable in lNdTable:
+                lNd = ndTable.xpath(".//pg:TextLine", namespaces=dNS)
+                if lNd:
+                    if bForce:
+                        for nd in lNd:
+                            nd.getparent().remove(nd)
+                            ndPage.append(nd)
+                    else:
+                        raise ValueError("Pre-existing Table not empty")
+                ndTable.getparent().remove(ndTable)
+
+            # enumerate text, and add to cell
+            # ignore any text in col|row-spanning cells
+            table = TableRegion(iPage+1, 1)  # only one table for now!
+            lNdText = ndPage.xpath('.//pg:TextLine[@rowSpan="1" and @colSpan="1"]', namespaces=dNS)
+            for ndText in lNdText:
+                ndText.getparent().remove(ndText)
+                table.addToCell( int(ndText.get("row"))
+                               , int(ndText.get("col"))
+                               , ndText)
+
+            # make the TableRegion node!
+            ndTable = table.makeTableNode()
+            # add it to the page
+            ndPage.append(ndTable)
+
+            processRegions(ndPage, bVerbose)
+
+        doc.write(sFullFilename,
+                  xml_declaration=True,
+                  encoding="utf-8",
+                  pretty_print=True
+                  #compression=0, #0 to 9
+                  )
+
+        del doc
+
+    traceln(" done (%d files)" % len(lsFilename))
+
+
+
+# ----------------------------------------------------------------------------
+if __name__ == "__main__":
+
+    version = "v.01"
+    sUsage="""
+Create a TableRegion for non-spanning cells.
+Relies on the row, col, rowSpan, colSpan attributes of the TextLine
+
+Usage: %s
+
+""" % (sys.argv[0])
+
+    parser = OptionParser(usage=sUsage)
+    parser.add_option("-v", "--verbose", dest='bVerbose', action="store_true"
+                      , help="Verbose mode")
+    parser.add_option("-f", "--force", dest='bForce', action="store_true"
+                      , help="Force deletion of pre-existing tables; if not empty, keep their contents")
+    (options, args) = parser.parse_args()
+
+    try:
+        [sInputDir] = args
+    except ValueError:
+        sys.stderr.write(sUsage)
+        sys.exit(1)
+
+    # ... checking folders
+    if not os.path.normpath(sInputDir).endswith("col") : sInputDir = os.path.join(sInputDir, "col")
+
+    if not os.path.isdir(sInputDir):
+        sys.stderr.write("Not a directory: %s\n"%sInputDir)
+        sys.exit(2)
+
+    # ok, go!
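+    # Illustrative invocation (the folder name is a placeholder):
+    #     python tabulate_final.py -v --force  my_collection
+    # With --force, any pre-existing non-empty table has its TextLines moved
+    # under the Page node before deletion; without it, a non-empty pre-existing
+    # table raises a ValueError.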
+ traceln("Input is : ", os.path.abspath(sInputDir)) + + main(sInputDir, bForce=options.bForce, bVerbose=options.bVerbose) + + traceln("Done.") \ No newline at end of file diff --git a/TranskribusDU/test_install/test_install.py b/TranskribusDU/test_install/test_install.py index bd6d21e..3a35216 100644 --- a/TranskribusDU/test_install/test_install.py +++ b/TranskribusDU/test_install/test_install.py @@ -5,18 +5,7 @@ Copyright Xerox(C) 2016 JL. Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . Developed for the EU project READ. The READ project has received funding diff --git a/TranskribusDU/tests/test_DU_ABPTABLE.py b/TranskribusDU/tests/test_DU_ABPTABLE.py index 99ed3af..82802a4 100644 --- a/TranskribusDU/tests/test_DU_ABPTABLE.py +++ b/TranskribusDU/tests/test_DU_ABPTABLE.py @@ -8,8 +8,6 @@ @author: meunier ''' -from __future__ import absolute_import, print_function - import sys import os.path @@ -19,7 +17,7 @@ sDATA_DIR = os.path.join(sTESTS_DIR, "data") sys.path.append(os.path.dirname(sTESTS_DIR)) -import crf.Graph +import graph.Graph import tasks.DU_ABPTable @@ -37,7 +35,7 @@ def __init__(self): self.pkl = False self.rm = False self.crf_njobs = 2 - self.crf_max_iter = 2 + self.max_iter = 2 self.crf_C = None self.crf_tol = None self.crf_inference_cache = None @@ -47,7 +45,7 @@ def __init__(self): self.applyY = None def test_ABPTable_train(): - crf.Graph.Graph.resetNodeTypes() + graph.Graph.Graph.resetNodeTypes() sModelDir = os.path.join(sTESTS_DIR, "models") sModelName = "test_ABPTable_train" diff --git a/TranskribusDU/util/CollectionSplitter.py b/TranskribusDU/util/CollectionSplitter.py index 24cee52..356851c 100644 --- a/TranskribusDU/util/CollectionSplitter.py +++ b/TranskribusDU/util/CollectionSplitter.py @@ -7,18 +7,7 @@ Copyright NAVER(C) 2019 Jean-Luc Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . Developed for the EU project READ. The READ project has received funding from the European Union's Horizon 2020 research and innovation programme diff --git a/TranskribusDU/util/Polygon.py b/TranskribusDU/util/Polygon.py index 7bb5395..6f512d0 100644 --- a/TranskribusDU/util/Polygon.py +++ b/TranskribusDU/util/Polygon.py @@ -6,18 +6,7 @@ Copyright Xerox(C) 2016 H. Déjean, JL. 
Meunier
-
-    This program is free software: you can redistribute it and/or modify
-    it under the terms of the GNU General Public License as published by
-    the Free Software Foundation, either version 3 of the License, or
-    (at your option) any later version.
-
-    This program is distributed in the hope that it will be useful,
-    but WITHOUT ANY WARRANTY; without even the implied warranty of
-    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-    GNU General Public License for more details.
-
-    You should have received a copy of the GNU General Public License
-    along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
     Developed for the EU project READ. The READ project has received funding
diff --git a/TranskribusDU/util/Shape.py b/TranskribusDU/util/Shape.py
index 12047d9..e4395a2 100644
--- a/TranskribusDU/util/Shape.py
+++ b/TranskribusDU/util/Shape.py
@@ -4,20 +4,9 @@
 
     Utilities to deal with the PageXMl 2D objects using shapely
 
-    Copyright NAVER(C) 2018 JL. Meunier
+    Copyright Xerox(C) 2018 JL. Meunier
 
-    This program is free software: you can redistribute it and/or modify
-    it under the terms of the GNU General Public License as published by
-    the Free Software Foundation, either version 3 of the License, or
-    (at your option) any later version.
-
-    This program is distributed in the hope that it will be useful,
-    but WITHOUT ANY WARRANTY; without even the implied warranty of
-    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-    GNU General Public License for more details.
-
-    You should have received a copy of the GNU General Public License
-    along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
     Developed for the EU project READ. The READ project has received funding
@@ -28,6 +17,7 @@
 
 import shapely.geometry as geom
 from shapely.prepared import prep
+from shapely.ops import cascaded_union
 from rtree import index
 
 import numpy as np
@@ -36,6 +26,77 @@
 
 class ShapeLoader:
 
+    @classmethod
+    def getCoordsString(cls, o, bFailSafe=False):
+        """
+        Produce the usual content of the "Coords" attribute, e.g.:
+        "3162,1205 3162,1410 126,1410 126,1205 3162,1205"
+        may raise an exception
+        """
+        try:
+            lt2 = o.exterior.coords  # e.g. [(0.0, 0.0), (1.0, 1.0), (1.0, 0.0)]
+        except:
+            if bFailSafe:
+                try:
+                    lt2 = o.coords
+                except:
+                    return ""
+            else:
+                lt2 = o.coords
+        return " ".join("%d,%d" % (a,b) for a,b in lt2)
+
+
+    @classmethod
+    def contourObject(cls, lNd):
+        """
+        return the contour of the list of PageXml node, as a Shapely object:
+        the cascaded union of the nodes' polygons (an invalid polygon is
+        replaced by its convex hull; nodes whose polygon cannot be parsed
+        are ignored).
+        """
+        lp = []
+        for nd in lNd:
+            try:
+                lp.append(ShapeLoader.node_to_Polygon(nd))
+            except:
+                pass
+
+        o = cascaded_union([p if p.is_valid else p.convex_hull for p in lp ])
+        return o
+
+    @classmethod
+    def minimum_rotated_rectangle(cls, lNd, bShapelyObject=False):
+        """
+        return the stringified list of coordinates of the minimum rotated
+        rectangle for the list of PageXml node.
+        e.g. "3162,1205 3162,1410 126,1410 126,1205 3162,1205"
+        return "" upon error
+
+        if bShapelyObject is True, then return the Shapely object instead
+        (may raise an Exception upon error)
+        """
+        contour = cls.contourObject(lNd)
+        o = contour.minimum_rotated_rectangle
+        return o if bShapelyObject else cls.getCoordsString(o, bFailSafe=True)
+
+    @classmethod
+    def convex_hull(cls, lNd, bShapelyObject):
+        """
+        return the stringified list of coordinates of the convex hull
+        of the list of PageXml node.
+        e.g.
"3162,1205 3162,1410 126,1410 126,1205 3162,1205" + return "" upon error + + if bShapelyObjecy is True, then return the Shapely object + raise an Exception upon error + """ + contour = cls.contourObject(lNd) + o = contour.convex_hull + return o if bShapelyObject else cls.getCoordsString(o, bFailSafe=True) + @classmethod def node_to_Point(cls, nd): """ @@ -73,14 +134,18 @@ def node_to_SingleLine(cls, nd): return cls.LinearRegression(o) @classmethod - def node_to_Polygon(cls, nd): + def node_to_Polygon(cls, nd, bValid=True): """ Find the points attribute (either in the DOM node itself or in a children Coord node) Parse the points series Return a Polygon shapely object """ - return cls._shapeFromNodePoints(nd, geom.Polygon) + p = cls._shapeFromNodePoints(nd, geom.Polygon) + if bValid and not p.is_valid: + # making sure it is a valid shape + p = p.buffer(0) + return p @classmethod def children_to_LineString(cls, node, name, fun=None): @@ -393,7 +458,12 @@ def test_ShapeLoader(): o = ShapeLoader._shapeFromPoints("0,0 0,9", geom.LineString) assert o.length == 9 assert o.area == 0.0 - + +def test_ShapeLoader_Coords(): + s = "3162,1205 3162,1410 126,1410 3162,1205" + o = ShapeLoader._shapeFromPoints(s, geom.Polygon) + assert ShapeLoader.getCoordsString(o) == s + # ----------------------------------------------------------------------- def test_ShapePartition_object_above(capsys): with capsys.disabled(): diff --git a/TranskribusDU/util/dtw.py b/TranskribusDU/util/dtw.py index 6e25989..6e13a23 100644 --- a/TranskribusDU/util/dtw.py +++ b/TranskribusDU/util/dtw.py @@ -5,18 +5,7 @@ Copyright Naver Labs Europe(C) 2019 H. Déjean, JL. Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . Developed for the EU project READ. The READ project has received funding diff --git a/TranskribusDU/util/hungarian.py b/TranskribusDU/util/hungarian.py index fac52ca..42b54f2 100644 --- a/TranskribusDU/util/hungarian.py +++ b/TranskribusDU/util/hungarian.py @@ -5,18 +5,7 @@ Copyright Naver Labs Europe(C) 2019 H. Déjean, JL. Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . Developed for the EU project READ. 
The READ project has received funding @@ -84,4 +73,18 @@ def test_simple(): assert evalHungarian([(1,)], lref, 0.6) == (0, 1, 3) +def test_simple_unordered(): + + lref = [ (3,4), (1,2), (99,6) ] + + l1 = [ (2,1), (4,3), ( 5,6) ] + + assert evalHungarian(l1, l1, 0.4) == (3, 0, 0) + assert evalHungarian(l1, lref, 0.3) == (3, 0, 0) + assert evalHungarian(l1, lref, 0.6) == (2, 1, 1) + + l2 = [ (3,4), (1,2), (66,6), (99, 999)] + assert evalHungarian(l2, lref, 0.6) == (2, 2, 1) + + assert evalHungarian([(1,)], lref, 0.6) == (0, 1, 3) \ No newline at end of file diff --git a/TranskribusDU/util/iou.py b/TranskribusDU/util/iou.py index b394f79..ae91d09 100644 --- a/TranskribusDU/util/iou.py +++ b/TranskribusDU/util/iou.py @@ -7,18 +7,7 @@ Copyright Naver Labs Europe(C) 2019 H. Déjean, JL. Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . Developed for the EU project READ. The READ project has received funding diff --git a/TranskribusDU/util/jaccard.py b/TranskribusDU/util/jaccard.py index 5d2b833..3103dce 100644 --- a/TranskribusDU/util/jaccard.py +++ b/TranskribusDU/util/jaccard.py @@ -7,18 +7,7 @@ Copyright Naver Labs Europe(C) 2019 H. Déjean, JL. Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . Developed for the EU project READ. The READ project has received funding diff --git a/TranskribusDU/util/lcs.py b/TranskribusDU/util/lcs.py index c913176..baf9c13 100644 --- a/TranskribusDU/util/lcs.py +++ b/TranskribusDU/util/lcs.py @@ -37,53 +37,7 @@ def matchLCS(perc, t1, t2): return ((val >= perc), val) -def testlcs(self,X,Y,m,n): - L = [[0 for x in range(n+1)] for x in range(m+1)] - - # Following steps build L[m+1][n+1] in bottom up fashion. 
Note - # that L[i][j] contains length of LCS of X[0..i-1] and Y[0..j-1] - for i in range(m+1): - for j in range(n+1): - if i == 0 or j == 0: - L[i][j] = 0 - elif X[i-1] == Y[j-1]: - L[i][j] = L[i-1][j-1] + 1 - else: - L[i][j] = max(L[i-1][j], L[i][j-1]) - - # Following code is used to print LCS - index = L[m][n] - - # Create a character array to store the lcs string - lcs = [""] * (index+1) - lcs[index] = "" - lmapping = [] - # Start from the right-most-bottom-most corner and - # one by one store characters in lcs[] - i = m - j = n - while i > 0 and j > 0: - - # If current character in X[] and Y are same, then - # current character is part of LCS - if X[i-1] == Y[j-1]: - lcs[index-1] = X[i-1] - lmapping.append((i-1,j-1)) - i-=1 - j-=1 - index-=1 - - # If not same, then find the larger of two and - # go in the direction of larger value - elif L[i-1][j] > L[i][j-1]: - i-=1 - else: - j-=1 - - lmapping.reverse() - xx =[(X[x],Y[y]) for x,y in lmapping] - return xx - + #--------- LCS code # Return the length of the longest common string of a and b. def lcs(a, b): @@ -104,7 +58,6 @@ def lcs(a, b): else: curLcs = max(prevRow[j+1], curRow[j]) curRow[j+1] = curLcs - print (curRow) return curRow[na] def fastlcs(a,b,Dmax=None): diff --git a/TranskribusDU/util/masking.py b/TranskribusDU/util/masking.py index a489caa..bca7825 100644 --- a/TranskribusDU/util/masking.py +++ b/TranskribusDU/util/masking.py @@ -5,18 +5,7 @@ Copyright Naver Labs Europe(C) 2019 JL. Meunier - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . Developed for the EU project READ. The READ project has received funding @@ -74,7 +63,7 @@ def applyMask2(lView, lViewMask): Assumes the input views do not overlap each other Garanties that the output view do not overlap each other """ - for a,b in lView: assert a < b, "invalid view: %s, %s" %(a,b) + for a,b in lView: assert a <= b, "invalid view: %s, %s" %(a,b) ovrl = 0 # total overlap with the masks # apply each mask in turn @@ -94,6 +83,8 @@ def applyMask2(lView, lViewMask): ovrl += (_right - _left) else: # keep it as it is + # filter our when a == b + #if a != b: lNewView.append( (a,b) ) lView = lNewView if not lView: break # stop if the view is empty! diff --git a/TranskribusDU/util/partitionEvaluation.py b/TranskribusDU/util/partitionEvaluation.py index aabfa62..079f349 100644 --- a/TranskribusDU/util/partitionEvaluation.py +++ b/TranskribusDU/util/partitionEvaluation.py @@ -9,18 +9,7 @@ copyright Naver Labs 2018 READ project - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. 
- - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. The READ project has received funding diff --git a/TranskribusDU/util/statSeparator.py b/TranskribusDU/util/statSeparator.py new file mode 100644 index 0000000..5d79d6b --- /dev/null +++ b/TranskribusDU/util/statSeparator.py @@ -0,0 +1,50 @@ +# -*- coding: utf-8 -*- +""" +Currently provides only the computation of a linear separator + +H. Déjean, JL Meunier, Copyright Naver Labs Europe 2019 +""" + +from sklearn import svm + + +def getLinearSeparator(X, Y): + """ + Linear separator + + return a,b so that the linear separator has the form Y = a X + b + """ + + #C = 1.0 # SVM regularization parameter + # clf = svm.SVC(kernel = 'linear', gamma=0.7, C=C ) + clf = svm.SVC(kernel = 'linear') + clf.fit(X, Y) + w = clf.coef_[0] + a = -w[0] / w[1] + b = - (clf.intercept_[0]) / w[1] + return a, b + + +def test_getLinearSeparator(): + import numpy as np + + lP = [(i, 10) for i in range(10)] + lV = [(i, -2) for i in range(10)] + X = np.array(lP+lV) + Y = np.array([1]*10 + [0]*10) + + a,b = getLinearSeparator(X, Y) + assert abs(a) < 0.001 + assert abs(b-4) < 0.001 + #print(a,b) + + lP = [(i, 10+i) for i in range(10)] + lV = [(i, -2+i) for i in range(10)] + X = np.array(lP+lV) + Y = np.array([1]*10 + [0]*10) + + a,b = getLinearSeparator(X, Y) + assert abs(a-1) < 0.001 + assert abs(b-4) < 0.001 + # print(a,b) + diff --git a/TranskribusDU/visu/MyFrame.py b/TranskribusDU/visu/MyFrame.py index 571f78f..9aa785b 100644 --- a/TranskribusDU/visu/MyFrame.py +++ b/TranskribusDU/visu/MyFrame.py @@ -9,6 +9,7 @@ from document import Document from config import Config +from deco import DecoImage # begin wxGlade: dependencies # end wxGlade @@ -97,10 +98,12 @@ def __init__(self, *args, **kwds): id_saveas = wx.NewId() id_close = wx.NewId() id_quit = wx.NewId() + id_imgfolder = wx.NewId() wxglade_tmp_menu.Append(id_load, "&Load Xml File", "", wx.ITEM_NORMAL) wxglade_tmp_menu.Append(id_reload, "&Re-load the Xml File", "", wx.ITEM_NORMAL) wxglade_tmp_menu.Append(id_save, "&Save Xml File", "", wx.ITEM_NORMAL) wxglade_tmp_menu.Append(id_saveas, "Save &As Xml File", "", wx.ITEM_NORMAL) + wxglade_tmp_menu.Append(id_imgfolder, "Select image folder", "", wx.ITEM_NORMAL) #MARCHE PAS wxglade_tmp_menu.Append(id_reloadini, "&Reload INI File", "", wx.ITEM_NORMAL) #wxglade_tmp_menu.Append(id_close, "&Close", "", wx.ITEM_NORMAL) @@ -129,6 +132,7 @@ def __init__(self, *args, **kwds): self.Bind(wx.EVT_MENU, self.OnMenu_ReloadINI, id=id_reloadini) self.Bind(wx.EVT_MENU, self.OnMenu_SaveXML, id=id_save) self.Bind(wx.EVT_MENU, self.OnMenu_SaveAsXML, id=id_saveas) + self.Bind(wx.EVT_MENU, self.OnMenu_ImgFolder, id=id_imgfolder) self.Bind(wx.EVT_MENU, self.OnMenu_Quit, id=id_quit) self.Bind(wx.EVT_MENU, self.OnMenu_Help, id=id_help) @@ -330,6 +334,18 @@ def OnMenu_SaveAsXML(self, event): dlg.Destroy() if ret: self.bModified = False + def OnMenu_ImgFolder(self, event): + curdir = os.path.dirname(self.doc.getFilename()) + if not curdir: curdir = os.getcwd() + dlg = wx.DirDialog (None, "Select the image folder", "", + wx.DD_DEFAULT_STYLE | wx.DD_DIR_MUST_EXIST) + dlg.CenterOnScreen() + val = dlg.ShowModal() + if val == wx.ID_OK: + DecoImage.sImageFolder = dlg.GetPath() + dlg.Destroy() + self.bModified = True + self.display_page() def OnMenu_Quit(self, event): if self.bModified: @@ -433,22 +449,37 @@ def OnToolbar_ChangePage(self, evt): def OnCanvas_RightMouse(self, obj): """Click on a widget in the 
canvas with the right mouse""" - menu = wx.Menu() - # get the id of the corresponding node - self.n = self.doc.obj_n[obj] - tree_id = wx.NewId() - self.Bind(wx.EVT_MENU, self.OnPopup_RightCanvas, id=tree_id) - menu.Append(tree_id, "XPath lab") - c = self.wysi.Canvas - pos = (c.PixelToWorld(obj.XY[0]), - c.PixelToWorld(c.GetSize()[1]-obj.XY[1])) - self.PopupMenu(menu, pos) - menu.Destroy() +# menu = wx.Menu() +# # get the id of the corresponding node +# self.n = self.doc.obj_n[obj] +# tree_id = wx.NewId() +# self.Bind(wx.EVT_MENU, self.OnPopup_RightCanvas, id=tree_id) +# menu.Append(tree_id, "XPath lab") +# c = self.wysi.Canvas +# pos = (c.PixelToWorld(obj.XY[0]), +# c.PixelToWorld(c.GetSize()[1]-obj.XY[1])) +# self.PopupMenu(menu, pos) +# menu.Destroy() + print("Clicked: ", obj) + try: + txt = etree.tostring(self.doc.obj_n[obj].getparent()) + except KeyError: + print("No deco associated") + return + txt = unicode(txt, sEncoding) + tip = wx.TipWindow(self, txt, maxLength=1200) + wx.FutureCall(30000, tip.Close) + def OnCanvas_LeftMouse(self, obj): - txt = etree.tostring(self.doc.obj_n[obj]) + print("Clicked: ", obj) + try: + txt = etree.tostring(self.doc.obj_n[obj]) + except KeyError: + print("No deco associated") + return txt = unicode(txt, sEncoding) - tip = wx.TipWindow(self, txt) + tip = wx.TipWindow(self, txt, maxLength=1200) wx.FutureCall(30000, tip.Close) def OnCanvas_LeftMouseDecoAction(self, obj): @@ -493,10 +524,18 @@ def OnPopup_RightCanvas(self, event): def cbkDecoCheckBox(self, event): # wxGlade: MyFrame. """enable or disbale a decoration type""" deco = self.dChekBox2Deco[ event.GetEventObject() ] - deco.setEnabled(event.IsChecked()) + + # try to only update the page, when the user adds a decoration, for display speedup + # of course, the deco will be on top of all others, even if not at bottom of list of deco + b = event.IsChecked() + deco.setEnabled(b) if self.doc: d = self.doc.displayed - self.display_page(d) + if b: + # deco was enabled + self.update_page(d, deco) + else: + self.display_page(d) def cbkDecoNext(self, event): """jump on the next page that has such a decoration""" @@ -606,7 +645,7 @@ def display_page(self, i=0, (x,y,w,h)=(None, None, None, None)): LineColor=self.config.page_border_color, FillColor=self.config.page_background_color, FillStyle="Solid") - self.doc.obj_n[page_rect] = self.current_page_node + # useless self.doc.obj_n[page_rect] = self.current_page_node page_rect.Bind(FloatCanvas.EVT_FC_RIGHT_DOWN, self.OnCanvas_RightMouse) page_rect.Bind(FloatCanvas.EVT_FC_LEFT_DOWN, self.OnCanvas_LeftMouse) # page_rect.Bind(FloatCanvas.EVT_FC_ENTER_OBJECT, self.OnCanvas_Enter) @@ -629,14 +668,14 @@ def display_page(self, i=0, (x,y,w,h)=(None, None, None, None)): #let's bind on the first object of the list if lo: - obj = lo[0] - self.doc.obj_n[obj] = n - obj.Bind(FloatCanvas.EVT_FC_RIGHT_DOWN, self.OnCanvas_RightMouse) + for obj in lo: self.doc.obj_n[obj] = n + obj0 = lo[0] + obj0.Bind(FloatCanvas.EVT_FC_RIGHT_DOWN, self.OnCanvas_RightMouse) if deco.isActionable(): - self.doc.obj_deco[obj] = deco - obj.Bind(FloatCanvas.EVT_FC_LEFT_DOWN, self.OnCanvas_LeftMouseDecoAction) + self.doc.obj_deco[obj0] = deco + obj0.Bind(FloatCanvas.EVT_FC_LEFT_DOWN, self.OnCanvas_LeftMouseDecoAction) else: - obj.Bind(FloatCanvas.EVT_FC_LEFT_DOWN, self.OnCanvas_LeftMouse) + obj0.Bind(FloatCanvas.EVT_FC_LEFT_DOWN, self.OnCanvas_LeftMouse) # obj.Bind(FloatCanvas.EVT_FC_ENTER_OBJECT, self.OnCanvas_Enter) deco.endPage(self.current_page_node) @@ -654,3 +693,53 @@ def display_page(self, i=0, 
(x,y,w,h)=(None, None, None, None)): c.ZoomToBB() + def update_page(self, i, deco): + """Update the page in the interface by drawing a single deco (just enabled) + """ + assert deco.isEnabled() + + c = self.wysi.Canvas + # self.doc.new_page(i) + + try: + self.current_page_node = self.doc.getPageByIndex(i) + except IndexError: + dlg = wx.MessageDialog(self, message="This XML file has no such page (%dth '%s' element) Try passing another .ini file as application parameter."%(i+1, self.config.page_tag), + caption="Error", + style=wx.ICON_ERROR) + dlg.CenterOnScreen() + val = dlg.ShowModal() + dlg.Destroy() + return + + #Now let's decorate the page according to the configuration + ln = self.doc.xpathEval( deco.getMainXPath(), self.current_page_node ) + deco.beginPage(self.current_page_node) + + for n in ln: + #TODO: deal with that!!! + inc = 1 + try: + lo = deco.draw(c, n) + except: + lo = None + traceback.print_exc() + + #let's bind on the first object of the list + if lo: + # bind all objects... (since the click fails from time to time...) +# obj = lo[0] +# self.doc.obj_n[obj] = n + for obj in lo: self.doc.obj_n[obj] = n + obj0 = lo[0] + obj0.Bind(FloatCanvas.EVT_FC_RIGHT_DOWN, self.OnCanvas_RightMouse) + if deco.isActionable(): + self.doc.obj_deco[obj0] = deco + obj0.Bind(FloatCanvas.EVT_FC_LEFT_DOWN, self.OnCanvas_LeftMouseDecoAction) + else: + obj0.Bind(FloatCanvas.EVT_FC_LEFT_DOWN, self.OnCanvas_LeftMouse) +# obj.Bind(FloatCanvas.EVT_FC_ENTER_OBJECT, self.OnCanvas_Enter) + deco.endPage(self.current_page_node) + + c.ZoomToBB() + diff --git a/TranskribusDU/visu/deco.py b/TranskribusDU/visu/deco.py index 08c1f1d..8d2bddc 100644 --- a/TranskribusDU/visu/deco.py +++ b/TranskribusDU/visu/deco.py @@ -6,6 +6,8 @@ import types, os from collections import defaultdict import glob +import logging +import random from lxml import etree #import cStringIO import wx @@ -90,19 +92,49 @@ def setXPathContext(self, xpCtxt): def xpathError(self, node, xpExpr, eExcpt, sMsg=""): """report an xpath error""" + try: + Deco._s_prev_xpath_error + except AttributeError: + Deco._s_prev_xpath_error = "" + Deco._prev_xpath_error_count = 0 + iMaxLen = 200 # to truncate the node serialization - print "-"*60 - print "--- XPath ERROR on class %s"%self.__class__ - print "--- xpath=%s" % xpExpr - print "--- Python Exception=%s" % str(eExcpt) - if sMsg: print "--- Info: %s" % sMsg + s = "-"*60 + s += "\n--- XPath ERROR on class %s"%self.__class__ + s += "\n--- xpath=%s" % xpExpr + s += "\n--- Python Exception=%s" % str(eExcpt) + if sMsg: s += "\n--- Info: %s" % sMsg + + if s == Deco._s_prev_xpath_error: + # let's not overload the console. + return + Deco._s_prev_xpath_error = s + + Deco._prev_xpath_error_count += 1 + if Deco._prev_xpath_error_count > 10: + return + try: sNode = etree.tostring(node) except: sNode = str(node) if len(sNode) > iMaxLen: sNode = sNode[:iMaxLen] + "..." 
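
Note on update_page above: enabling a decoration now draws just that deco on top of the already-rendered page, while disabling one still triggers a full display_page, since a single deco cannot be erased from the canvas; the stated trade-off is that the newly enabled deco sits on top of the z-order even if it is not last in the deco list. Schematically (names simplified, not the actual wx code):

    def on_deco_toggled(enabled, page, deco, draw_one_deco, redraw_all):
        if enabled:
            draw_one_deco(page, deco)   # fast path: paint only the new deco, on top
        else:
            redraw_all(page)            # one deco cannot be un-painted: full redraw
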
- print "--- XML node = %s" % sNode - print "-"*60 + s += "\n--- XML node = %s" % sNode + s += "\n" + "-"*60 + "\n" + logging.warning(s) + + def warning(self, sMsg): + """report an xpath error""" + try: + Deco._s_prev_warning + except AttributeError: + Deco._s_prev_warning = "" + Deco._warning_count = 0 + # if sMsg != Deco._s_prev_warning and Deco._warning_count < 1000: + if sMsg != Deco._s_prev_warning: + logging.warning(sMsg) + Deco._warning_count += 1 + Deco._s_prev_warning = sMsg def toInt(cls, s): try: @@ -338,7 +370,7 @@ def getText(self, wxh, node): try: return eval('u"\\u%04x"' % int(sEncodedText, self.base)) except ValueError: - print "DecoUnicodeChar: ERROR: base=%d code=%s"%(self.base, sEncodedText) + logging.error("DecoUnicodeChar: ERROR: base=%d code=%s"%(self.base, sEncodedText)) return "" @@ -370,7 +402,7 @@ def draw(self, wxh, node): obj = wxh.AddScaledBitmap(img, (x,-y), h) lo.append(obj) except Exception, e: - print "DecoImageBox ERROR: File %s: %s"%(sFilePath, str(e)) + self.warning("DecoImageBox ERROR: File %s: %s"%(sFilePath, str(e))) lo.append( DecoRectangle.draw(self, wxh, node) ) return lo @@ -379,6 +411,8 @@ def draw(self, wxh, node): class DecoImage(DecoBBXYWH): """An image """ + # in case the use wants to specify it via the menu + sImageFolder = None def __init__(self, cfg, sSurname, xpCtxt): DecoBBXYWH.__init__(self, cfg, sSurname, xpCtxt) @@ -399,6 +433,20 @@ def draw(self, wxh, node): x,y,w,h,inc = self.runXYWHI(node) sFilePath = self.xpathToStr(node, self.xpHRef, "") if sFilePath: + if self.sImageFolder: + sCandidate = os.path.join(self.sImageFolder, sFilePath) + if os.path.exists(sCandidate): + sFilePath = sCandidate + else: + # maybe the file is in a subfolder ? + # e.g. "S_Aicha_an_der_Donau_004-03_0005.jpg" is in folder "S_Aicha_an_der_Donau_004-03" + try: + sDir = sFilePath[:sFilePath.rindex("_")] + sCandidate = os.path.join(self.sImageFolder, sDir, sFilePath) + if os.path.exists(sCandidate): + sFilePath = sCandidate + except ValueError: + pass if not os.path.exists(sFilePath): #maybe the image is in a folder with same name as XML file? (Transkribus style) sUrl = node.getroottree().docinfo.URL.decode('utf-8') # py2 ... 
@@ -421,7 +469,7 @@ def draw(self, wxh, node): bKO = False break if bKO: - print "WARNING: deco Image: file does not exists: '%s'"%sFilePath + self.warning("WARNING: deco Image: file does not exists: '%s'"%sFilePath) sFilePath = None if bool(sFilePath): img = wx.Image(sFilePath, wx.BITMAP_TYPE_ANY) @@ -432,7 +480,7 @@ def draw(self, wxh, node): obj = wxh.AddScaledBitmap(img, (x,-y), img.GetHeight()) lo.append(obj) except Exception, e: - print "DecoImage ERROR: File %s: %s"%(sFilePath, str(e)) + self.warning("DecoImage ERROR: File %s: %s"%(sFilePath, str(e))) return lo @@ -537,13 +585,20 @@ def __init__(self, cfg, sSurname, xpCtxt): def _getCoordList(self, node): sCoords = self.xpathToStr(node, self.xpCoords, "") + if not sCoords: + if node.get("id") is None: + self.warning("No coordinates: node = %s" % etree.tostring(node)) + else: + self.warning("No coordinates: node id = %s" % node.get("id")) + return [(0,0)] try: ltXY = [] for _sPair in sCoords.split(' '): (sx, sy) = _sPair.split(',') ltXY.append((Deco.toInt(sx), Deco.toInt(sy))) - except Exception, e: - print "ERROR: polyline coords are bad: '%s'"%sCoords + except Exception as e: + logging.error("ERROR: polyline coords are bad: '%s' -> '%s'" % ( + self.xpCoords, sCoords)) raise e return ltXY @@ -584,27 +639,31 @@ def _getFontSize(self, node, ltXY, txt, Family=wx.FONTFAMILY_TELETYPE): return iFontSize, ExtentX, ExtentY """ (x1, y1), (x2, y2) = self._coordList_to_BB(ltXY) - - dc = wx.ScreenDC() - # compute for font size of 24 and do proportional - dc.SetFont(wx.Font(24, Family, wx.FONTSTYLE_NORMAL, wx.FONTWEIGHT_NORMAL)) - Ex, Ey = dc.GetTextExtent("x") - try: - iFontSizeX = 24 * abs(x2-x1) / Ex / len(txt) - except: - print "absence of text: cannot compute font size along X axis" - iFontSizeX = 8 - iFontSizeY = 24 * abs(y2-y1) / Ey sFit = self.xpathToStr(node, self.xpFit, 'xy', bShowError=False) - if sFit == "x": - iFontSize = iFontSizeX - elif sFit == "y": - iFontSize = iFontSizeY - else: - iFontSize = min(iFontSizeX, iFontSizeY) - dc.SetFont(wx.Font(iFontSize, Family, wx.FONTSTYLE_NORMAL, wx.FONTWEIGHT_NORMAL)) - Ex, Ey = dc.GetTextExtent("x") - del dc + + try: + iFontSize = int(sFit) + Ex, Ey = None, None + except ValueError: + dc = wx.ScreenDC() + # compute for font size of 24 and do proportional + dc.SetFont(wx.Font(24, Family, wx.FONTSTYLE_NORMAL, wx.FONTWEIGHT_NORMAL)) + Ex, Ey = dc.GetTextExtent("x") + try: + iFontSizeX = 24 * abs(x2-x1) / Ex / len(txt) + except: + self.warning("absence of text: cannot compute font size along X axis") + iFontSizeX = 8 + iFontSizeY = 24 * abs(y2-y1) / Ey + if sFit == "x": + iFontSize = iFontSizeX + elif sFit == "y": + iFontSize = iFontSizeY + else: + iFontSize = min(iFontSizeX, iFontSizeY) + dc.SetFont(wx.Font(iFontSize, Family, wx.FONTSTYLE_NORMAL, wx.FONTWEIGHT_NORMAL)) + Ex, Ey = dc.GetTextExtent("x") + del dc return iFontSize, Ex, Ey @@ -624,7 +683,9 @@ def draw(self, wxh, node): iFontSize, Ex, Ey = self._getFontSize(node, ltXY, txt, Family=wx.FONTFAMILY_TELETYPE) - x, y = ltXY[0] + # x, y = ltXY[0] + (x, _y1), (_x2, y) = self._coordList_to_BB(ltXY) + obj = wxh.AddScaledText(txt, (x, -y+iFontSize/6), Size=iFontSize , Family=wx.FONTFAMILY_TELETYPE , Position='tl' @@ -845,7 +906,141 @@ def draw(self, wxh, node): lo.append(obj) return lo - + +class DecoClusterCircle(DecoREAD): + """ + [Cluster] + type=DecoClusterCircle + xpath=.//Cluster + xpath_content=@content + xpath_radius=40 + xpath_item_lxy=./pg:Coords/@points + xpath_LineWidth="1" + xpath_FillStyle="Transparent" + LineColors="BLUE SIENNA 
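
Note on the _getFontSize rework above: xpath_fit_text_size may now be a literal point size (the new dbgTable* sections use 36) as well as a fit mode x, y or xy; the int() parse is tried first, and only on ValueError does the code fall back to scaling from a 24pt reference extent measured on a wx.ScreenDC. A condensed sketch of that decision (simplified names; the extent measurement is omitted):

    def resolve_font_size(sFit, bb_w, bb_h, n_chars, ex_24, ey_24):
        try:
            return int(sFit)                              # literal size, e.g. "36"
        except ValueError:
            size_x = 24 * bb_w / ex_24 / max(n_chars, 1)  # fit the box width
            size_y = 24 * bb_h / ey_24                    # fit the box height
            if sFit == "x":
                return size_x
            if sFit == "y":
                return size_y
            return min(size_x, size_y)                    # "xy": fit both axes
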
YELLOW ORANGE RED GREEN" + FillColors="BLUE SIENNA YELLOW ORANGE RED GREEN" + enabled=1 + """ + count = 0 + def __init__(self, cfg, sSurname, xpCtxt): + DecoREAD.__init__(self, cfg, sSurname, xpCtxt) + self.xpCluster = cfg.get(sSurname, "xpath") + self.xpContent = cfg.get(sSurname, "xpath_content") + self.xpRadius = cfg.get(sSurname, "xpath_radius") + self.xpLineWidth = cfg.get(sSurname, "xpath_LineWidth") + self.xpFillStyle = cfg.get(sSurname, "xpath_FillStyle") + self.lsLineColor = cfg.get(sSurname, "LineColors").split() + self.lsFillColor = cfg.get(sSurname, "FillColors").split() + #cached values + self._node = None + self._laxyr = None + + print "DecoClusterCircle lsLineColor = ", self.lsLineColor + print "DecoClusterCircle lsFillColor = ", self.lsFillColor + + def __str__(self): + s = "%s="%self.__class__ + s += "+(coords=%s)" % (self.xpCoords) + return s + + def getArea_and_CenterOfMass(self, lXY): + """ + https://fr.wikipedia.org/wiki/Aire_et_centre_de_masse_d'un_polygone + + return A, (Xg, Yg) which are the area and the coordinates (float) of the center of mass of the polygon + """ + if len(lXY) < 2: raise ValueError("Only one point: polygon area is undefined.") + + fA = 0.0 + xSum, ySum = 0, 0 + + + xprev, yprev = lXY[-1] + for x, y in lXY: + iTerm = xprev*y - yprev*x + fA += iTerm + xSum += iTerm * (xprev+x) + ySum += iTerm * (yprev+y) + xprev, yprev = x, y + if fA == 0.0: raise ValueError("surface == 0.0") + fA = fA / 2 + xg, yg = xSum/6/fA, ySum/6/fA + + if fA <0: + return -fA, (xg, yg) + else: + return fA, (xg, yg) + assert fA >0 and xg >0 and yg >0, "%s\t%s"%(lXY (fA, (xg, yg))) + return fA, (xg, yg) + + def draw(self, wxh, node): + """draw itself using the wx handle + return a list of created WX objects""" + + DecoClusterCircle.count = DecoClusterCircle.count + 1 + + lo = DecoREAD.draw(self, wxh, node) + if self._node != node: + self._laxyr = [] + #need to go thru each item + ndPage = node.xpath("ancestor::*[local-name()='Page']")[0] + sIds = self.xpathEval(node, self.xpContent)[0] + for sId in sIds.split(): + l = self.xpathEval(ndPage, './/*[@id="%s"]'%sId) + ndItem = l[0] + lxy = self._getCoordList(ndItem) + fA, (xg, yg) = self.getArea_and_CenterOfMass(lxy) + r = self.xpathToInt(ndItem, self.xpRadius, 1) + self._laxyr.append( (fA, xg, yg, r) ) + self._node = node + + if self._laxyr: + iMaxFC = len(self.lsFillColor) + iMaxLC = len(self.lsLineColor) + if False: + Nf = DecoClusterCircle.count + Nl = Nf + else: + Nf = random.randrange(iMaxFC) + Nl = random.randrange(iMaxFC) + + iLineWidth = self.xpathToInt(node, self.xpLineWidth, 1) + sFillStyle = self.xpathToStr(node, self.xpFillStyle, "Solid") + for (_a, x, y, r) in self._laxyr: + #draw a circle + sFillColor = self.lsFillColor[Nf % iMaxFC] + if self.lsLineColor: + sLineColor = self.lsLineColor[Nl % iMaxLC] + else: + sLineColor = sFillColor + obj = wxh.AddCircle((x, -y), r, + LineWidth=iLineWidth, + LineColor=sLineColor, + FillColor=sFillColor, + FillStyle=sFillStyle) +# obj = wxh.AddRectangle((x, -y), (20, 20), +# LineWidth=iLineWidth, +# LineColor=sLineColor, +# FillColor=sFillColor, +# FillStyle=sFillStyle) + + lo.append(obj) + + """ + lo = DecoBBXYWH.draw(self, wxh, node) + x,y,w,h,inc = self.runXYWHI(node) + sLineColor = self.xpathToStr(node, self.xpLineColor, "#000000") + iLineWidth = self.xpathToInt(node, self.xpLineWidth, 1) + sFillColor = self.xpathToStr(node, self.xpFillColor, "#000000") + sFillStyle = self.xpathToStr(node, self.xpFillStyle, "Solid") + obj = wxh.AddRectangle((x, -y), (w, -h), + LineWidth=iLineWidth, 
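
Note on getArea_and_CenterOfMass above: this is the shoelace formula from the cited Wikipedia page, A = 1/2 * sum(x_i*y_(i+1) - x_(i+1)*y_i), with centroid components sum((x_i + x_(i+1)) * cross_i) / (6A) and likewise for y; the sign of A encodes polygon orientation, hence the -fA branch for clockwise input. (In draw, note that Nl is drawn with randrange(iMaxFC) rather than iMaxLC, which only matters if the two colour lists differ in length.) A standalone check of the formula:

    def polygon_area_centroid(points):
        a = cx = cy = 0.0
        xp, yp = points[-1]               # close the polygon
        for x, y in points:
            cross = xp * y - yp * x
            a += cross
            cx += cross * (xp + x)
            cy += cross * (yp + y)
            xp, yp = x, y
        a /= 2.0
        return abs(a), (cx / (6 * a), cy / (6 * a))

    # unit square: area 1.0, centroid (0.5, 0.5)
    print(polygon_area_centroid([(0, 0), (1, 0), (1, 1), (0, 1)]))
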
+ LineColor=sLineColor, + FillColor=sFillColor, + FillStyle=sFillStyle) + """ + return lo + class DecoLink(Deco): """A link from x1,y1 to x2,y2 """ diff --git a/TranskribusDU/visu/mpxml_viewer.bat b/TranskribusDU/visu/mpxml_viewer.bat index c431345..5846fb7 100644 --- a/TranskribusDU/visu/mpxml_viewer.bat +++ b/TranskribusDU/visu/mpxml_viewer.bat @@ -1,2 +1,8 @@ -C:\python27\python.exe %0.py %0.ini %1 -rem set /p temp="Hit enter to continue" \ No newline at end of file +rem --- install python 2.7 +rem --- install wxpython version 2.9 +rem --- > pip install lxml numpy +rem --- to use: > C:\python27\python.exe mpxml_viewer.bat.py mpxml_viewer.bat.ini + +C:\Python27\python.exe %0.py %0.ini %1 + +rem set /p temp="Hit enter to continue" diff --git a/TranskribusDU/visu/mpxml_viewer.bat.ini b/TranskribusDU/visu/mpxml_viewer.bat.ini index ab27a82..592a2bd 100644 --- a/TranskribusDU/visu/mpxml_viewer.bat.ini +++ b/TranskribusDU/visu/mpxml_viewer.bat.ini @@ -16,9 +16,9 @@ decos=Image sprtr TextRegionRectangle TextLineRectangle Baseline TextLine_Unicode READ_Unicode READ_x_Unicode - sprtr MENU_section MENU_section_heading MENU_item MENU_other - sprtr MENU_Item_name MENU_Item_description MENU_Item_price MENU_Item_quantity MENU_Item_number - sprtr MENU_Rest_name MENU_Rest_address MENU_Rest_phone_number MENU_Rest_url MENU_Rest_hours +# sprtr MENU_section MENU_section_heading MENU_item MENU_other +# sprtr MENU_Item_name MENU_Item_description MENU_Item_price MENU_Item_quantity MENU_Item_number +# sprtr MENU_Rest_name MENU_Rest_address MENU_Rest_phone_number MENU_Rest_url MENU_Rest_hours # sprtr Word_Unicode Word_Plain TextLine_Unicode TextLine_Plain TextRegion_Unicode TextRegion_Plain # sprtr type:heading type:page-number type:marginalia type:header type:catch-word type:UNKNOWN @@ -28,7 +28,8 @@ decos=Image sprtr TextRegionRectangle TextLineRectangle Baseline row:number_as_text col:number_as_text sprtr TableRectangle CellRectangle PredictedCellRectangle - sprtr Separator + sprtr cut Separator Separator_rows +# Separator_S Separator_I # sprtr Bsln:S Bsln:I Bsln:O # sprtr Grid Grid+ # sprtr cut cut:S cut:B cut:I cut:E cut:other @@ -37,9 +38,13 @@ decos=Image sprtr TextRegionRectangle TextLineRectangle Baseline # sprtr TableRectangle CellRectangle # sprtr Separator #decos=Image ImageRectangle sprtr TextRegionRectangle TextLineRectangle RegionText LineText - sprtr Edge EdgeCon Cluster + sprtr Edge EdgeCon ClusterEdge ClusterEdge_H ClusterEdge_V + sprtr Cluster ClusterColor Cluster_cut ClusterColor_cut Cluster_agglo ClusterColor_agglo Cluster_I ClusterColor_I +# Cluster_edge ClusterColor_edge # HorizontalEdge VerticalEdge # Edge_BL Edge_LL + dbgTableRow dbgTableCol + dbgTableRow_agglo dbgTableCol_agglo #------------------ # Where the files are situated by default @@ -166,7 +171,7 @@ xpath_LineColor="SIENNA" xpath_FillStyle="Transparent" xpath_incr="0" enabled=0 -xpath_LineWidth=1 +xpath_LineWidth=2 [Baseline] type=DecoPolyLine @@ -287,7 +292,7 @@ xpath_lxy=./pg:Baseline/@points xpath_content=./pg:TextEquiv/pg:Unicode xpath_font_color="BLUE" xpath_LineColor="RED" -enabled=1 +enabled=0 [MENU_Item_description] xpath_label="Item-description" @@ -299,7 +304,7 @@ xpath_content=./pg:TextEquiv/pg:Unicode xpath_font_color="BLUE" xpath_LineColor="GREEN" xpath_incr="0" -enabled=1 +enabled=0 [MENU_Item_price] xpath_label="Item-price" @@ -311,7 +316,7 @@ xpath_content=./pg:TextEquiv/pg:Unicode xpath_font_color="BLUE" xpath_LineColor="BLUE" xpath_incr="0" -enabled=1 +enabled=0 [MENU_Item_quantity] 
xpath_label="Item-quantity" @@ -323,7 +328,7 @@ xpath_content=./pg:TextEquiv/pg:Unicode xpath_font_color="BLUE" xpath_LineColor="VIOLET" xpath_incr="0" -enabled=1 +enabled=0 [MENU_Item_number] xpath_label="Item-number" @@ -335,7 +340,7 @@ xpath_content=./pg:TextEquiv/pg:Unicode xpath_font_color="BLUE" xpath_LineColor="BLACK" xpath_incr="0" -enabled=1 +enabled=0 # - - - - - - - - - - - [MENU_Rest_name] @@ -349,7 +354,7 @@ xpath_font_color="BLUE" xpath_LineColor="RED" xpath_background_color="LIGHT GREY" xpath_incr="0" -enabled=1 +enabled=0 [MENU_Rest_address] xpath_label="Restaurant-address" @@ -362,7 +367,7 @@ xpath_font_color="BLUE" xpath_LineColor="GREEN" xpath_background_color="LIGHT GREY" xpath_incr="0" -enabled=1 +enabled=0 [MENU_Rest_phone_number] xpath_label="Restaurant-phone-number" @@ -375,7 +380,7 @@ xpath_font_color="BLUE" xpath_LineColor="BLUE" xpath_background_color="LIGHT GREY" xpath_incr="0" -enabled=1 +enabled=0 [MENU_Rest_hours] xpath_label="Restaurant-hours" @@ -388,7 +393,7 @@ xpath_font_color="BLUE" xpath_LineColor="VIOLET" xpath_background_color="LIGHT GREY" xpath_incr="0" -enabled=1 +enabled=0 [MENU_Rest_url] xpath_label="Restaurant-url" @@ -401,7 +406,7 @@ xpath_font_color="BLUE" xpath_LineColor="BLACK" xpath_background_color="LIGHT GREY" xpath_incr="0" -enabled=1 +enabled=0 # ----------------------------------------------------------------------------- @@ -417,7 +422,7 @@ xpath_w="0" xpath_h="0" xpath_font_color="BLUE" xpath_incr="0" -enabled=1 +enabled=0 # Here we try to separate the row from the col for a better display [row:number_as_text] @@ -431,7 +436,7 @@ xpath_w="0" xpath_h="0" xpath_font_color="BLUE" xpath_incr="0" -enabled=1 +enabled=0 [col:number_as_text] type=DecoText @@ -444,7 +449,7 @@ xpath_w="0" xpath_h="0" xpath_font_color="BLUE" xpath_incr="0" -enabled=1 +enabled=0 # [row_col:number_as_text] # type=DecoREADTextLine @@ -454,7 +459,7 @@ enabled=1 # xpath_lxy=./pg:Coords/@points # xpath_font_color="BLUE" # xpath_incr="0" -# enabled=1 +# enabled=0 # ----------------------------------------------------------------------------- @@ -502,6 +507,33 @@ xpath_FillStyle="Transparent" xpath_incr="0" enabled=0 +[Separator_S] +type=DecoPolyLine +xpath=.//pg:SeparatorRegion[@DU_Sep="S"] +xpath_lxy=./pg:Coords/@points +xpath_LineColor="BLUE" +xpath_FillStyle="Transparent" +xpath_incr="0" +enabled=0 + +[Separator_I] +type=DecoPolyLine +xpath=.//pg:SeparatorRegion[@DU_Sep="I"] +xpath_lxy=./pg:Coords/@points +xpath_LineColor="RED" +xpath_FillStyle="Transparent" +xpath_incr="0" +enabled=0 + +[Separator_rows] +type=DecoPolyLine +xpath=.//pg:SeparatorRegion[@algo] +xpath_lxy=./pg:Coords/@points +xpath_LineColor="RED" +xpath_FillStyle="Transparent" +xpath_incr="0" +enabled=0 + #-------------------------------------------------------------------- [row:S] type=DecoClosedPolyLine @@ -510,7 +542,7 @@ xpath_lxy=./pg:Coords/@points xpath_LineColor="#FFFF00" xpath_LineWidth=2 xpath_incr="-2" -enabled=1 +enabled=0 [row:B] type=DecoClosedPolyLine @@ -519,7 +551,7 @@ xpath_lxy=./pg:Coords/@points xpath_LineColor="#FF0000" xpath_LineWidth=2 xpath_incr="-2" -enabled=1 +enabled=0 [row:I] type=DecoClosedPolyLine @@ -528,7 +560,7 @@ xpath_lxy=./pg:Coords/@points xpath_LineColor="#99ff33" xpath_LineWidth=2 xpath_incr="-2" -enabled=1 +enabled=0 [row:E] type=DecoClosedPolyLine @@ -537,7 +569,7 @@ xpath_lxy=./pg:Coords/@points xpath_LineColor="LIGHT BLUE" xpath_LineWidth=2 xpath_incr="-2" -enabled=1 +enabled=0 [row:T] type=DecoClosedPolyLine @@ -546,7 +578,7 @@ xpath_lxy=./pg:Coords/@points 
xpath_LineColor="#0000FF" xpath_LineWidth=2 xpath_incr="-2" -enabled=1 +enabled=0 [row:M] type=DecoClosedPolyLine @@ -555,7 +587,7 @@ xpath_lxy=./pg:Coords/@points xpath_LineColor="#FFFFFF" xpath_LineWidth=2 xpath_incr="-2" -enabled=1 +enabled=0 [col:S] @@ -601,7 +633,7 @@ xpath_lxy=./pg:Coords/@points xpath_LineColor="#ffcc00" xpath_LineWidth=3 xpath_incr="0" -enabled=1 +enabled=0 [header:D] type=DecoClosedPolyLine @@ -621,7 +653,7 @@ xpath_lxy=./pg:Coords/@points xpath_LineColor="purple" xpath_LineWidth=2 xpath_incr="2" -enabled=1 +enabled=0 [sep:o] type=DecoClosedPolyLine @@ -630,7 +662,7 @@ xpath_lxy=./pg:Coords/@points xpath_LineColor="LIGHT GREY" xpath_LineWidth=2 xpath_incr="2" -enabled=1 +enabled=0 [type:heading] type=DecoClosedPolyLine @@ -639,7 +671,7 @@ xpath_lxy=./pg:Coords/@points xpath_LineColor="#FFFF00" xpath_LineWidth=2 xpath_incr="-2" -enabled=1 +enabled=0 [type:page-number] type=DecoClosedPolyLine @@ -648,7 +680,7 @@ xpath_lxy=./pg:Coords/@points xpath_LineColor="#ffcc00" xpath_LineWidth=4 xpath_incr="-2" -enabled=1 +enabled=0 [type:marginalia] type=DecoClosedPolyLine @@ -657,7 +689,7 @@ xpath_lxy=./pg:Coords/@points xpath_LineColor="#99ff33" xpath_LineWidth=2 xpath_incr="-2" -enabled=1 +enabled=0 [type:header] type=DecoClosedPolyLine @@ -666,7 +698,7 @@ xpath_lxy=./pg:Coords/@points xpath_LineColor="#ffcc99" xpath_LineWidth=4 xpath_incr="-2" -enabled=1 +enabled=0 [type:catch-word] type=DecoClosedPolyLine @@ -675,7 +707,7 @@ xpath_lxy=./pg:Coords/@points xpath_LineColor="#ff99cc" xpath_LineWidth=2 xpath_incr="-2" -enabled=1 +enabled=0 [type:UNKNOWN] type=DecoClosedPolyLine @@ -787,7 +819,7 @@ xpath=.//pg:Edge xpath_lxy=./@points xpath_LineColor="PINK" xpath_FillStyle="Transparent" -xpath_LineWidth=1 +xpath_LineWidth=2 xpath_incr="0" enabled=0 @@ -797,10 +829,11 @@ xpath=.//pg:Edge[@label="continue"] xpath_lxy=./@points xpath_LineColor="PURPLE" xpath_FillStyle="Transparent" -xpath_LineWidth=1 +xpath_LineWidth=2 xpath_incr="0" enabled=0 +# ----------------------------------------------------------- [Cluster] type=DecoClosedPolyLine xpath=.//pg:Cluster @@ -809,8 +842,138 @@ xpath_LineColor="Orange" xpath_FillStyle="Transparent" xpath_LineWidth=2 xpath_incr="0" -enabled=1 +enabled=0 + +[ClusterColor] +type=DecoClusterCircle +xpath=.//pg:Cluster +xpath_content=@content +xpath_radius=100 +xpath_LineWidth="2" +xpath_FillStyle="Solid" +# REM no line color list => same line and fill color +# REM LineColors="" +LineColors=BLUE RED PINK TURQUOISE ORANGE PURPLE YELLOW FIREBRICK GREEN MAROON +FillColors=BLUE RED PINK TURQUOISE ORANGE PURPLE YELLOW FIREBRICK GREEN MAROON +enabled=0 + +[ClusterEdge] +type=DecoPolyLine +xpath=.//pg:ClusterEdge +xpath_lxy=./@points +xpath_LineColor="VIOLET" +xpath_FillStyle="Transparent" +xpath_LineWidth=2 +xpath_incr="0" +enabled=0 +[ClusterEdge_H] +type=DecoPolyLine +xpath=.//pg:ClusterEdge[@type="HorizontalEdge"] +xpath_lxy=./@points +xpath_LineColor="VIOLET" +xpath_FillStyle="Transparent" +xpath_LineWidth=2 +xpath_incr="0" +enabled=0 +[ClusterEdge_V] +type=DecoPolyLine +xpath=.//pg:ClusterEdge[@type="VerticalEdge"] +xpath_lxy=./@points +xpath_LineColor="VIOLET" +xpath_FillStyle="Transparent" +xpath_LineWidth=2 +xpath_incr="0" +enabled=0 + +[Cluster_cut] +type=DecoClosedPolyLine +xpath=.//pg:Cluster[@algo="cut"] +xpath_lxy=./pg:Coords/@points +xpath_LineColor="RED" +xpath_FillStyle="Transparent" +xpath_LineWidth=2 +xpath_incr="0" +enabled=0 + +[ClusterColor_cut] +type=DecoClusterCircle +xpath=.//pg:Cluster[@algo="cut"] +xpath_content=@content 
+xpath_radius=120 +xpath_LineWidth="2" +xpath_FillStyle="Solid" +LineColors=BLUE RED PINK TURQUOISE ORANGE PURPLE YELLOW FIREBRICK GREEN MAROON +FillColors=BLUE RED PINK TURQUOISE ORANGE PURPLE YELLOW FIREBRICK GREEN MAROON +enabled=0 + + +[Cluster_edge] +type=DecoClosedPolyLine +xpath=.//pg:Cluster[@algo!="cut"] +xpath_lxy=./pg:Coords/@points +xpath_LineColor="GREEN" +xpath_FillStyle="Transparent" +xpath_LineWidth=2 +xpath_incr="0" +enabled=0 + +[ClusterColor_edge] +type=DecoClusterCircle +xpath=.//pg:Cluster[@algo!="cut"] +xpath_content=@content +xpath_radius=70 +xpath_LineWidth="2" +xpath_FillStyle="Solid" +LineColors=BLUE RED PINK TURQUOISE ORANGE PURPLE YELLOW FIREBRICK GREEN MAROON +FillColors=BLUE RED PINK TURQUOISE ORANGE PURPLE YELLOW FIREBRICK GREEN MAROON +enabled=0 + + +[Cluster_I] +type=DecoClosedPolyLine +xpath=.//pg:Cluster[contains(@algo,"_I_")] +xpath_lxy=./pg:Coords/@points +xpath_LineColor="#2E8B57" +xpath_FillStyle="Transparent" +xpath_LineWidth=2 +xpath_incr="0" +enabled=0 + +[ClusterColor_I] +type=DecoClusterCircle +xpath=.//pg:Cluster[contains(@algo, "_I_")] +xpath_content=@content +xpath_radius=70 +xpath_LineWidth="2" +xpath_FillStyle="Solid" +LineColors=BLUE RED PINK TURQUOISE ORANGE PURPLE YELLOW FIREBRICK GREEN MAROON +# no line color list => same line and fill color +FillColors=BLUE RED PINK TURQUOISE ORANGE PURPLE YELLOW FIREBRICK GREEN MAROON +enabled=0 + +[Cluster_agglo] +type=DecoClosedPolyLine +xpath=.//pg:Cluster[@algo="agglo"] +xpath_lxy=./pg:Coords/@points +xpath_LineColor="#808000" +xpath_FillStyle="Transparent" +xpath_LineWidth=2 +xpath_incr="0" +enabled=0 + +[ClusterColor_agglo] +type=DecoClusterCircle +xpath=.//pg:Cluster[@algo="agglo"] +xpath_content=@content +xpath_radius=80 +xpath_LineWidth="3" +xpath_FillStyle="Solid" +LineColors=BLUE RED PINK TURQUOISE ORANGE PURPLE YELLOW FIREBRICK GREEN MAROON +FillColors=BLUE RED PINK TURQUOISE ORANGE PURPLE YELLOW FIREBRICK GREEN MAROON +enabled=0 + +# ------------------------------------------------------------- [HorizontalEdge] type=DecoPolyLine xpath=.//pg:Edge[@DU_type="HorizontalEdge"] @@ -860,7 +1023,7 @@ xpath_LineColor="#DCDCDC" xpath_FillStyle="Transparent" xpath_LineWidth=2 xpath_incr="-6" -enabled=1 +enabled=0 [cut:S] type=DecoClosedPolyLine @@ -910,4 +1073,51 @@ xpath_LineColor="purple" xpath_FillStyle="Transparent" xpath_LineWidth=2 xpath_incr="-6" -enabled=0 \ No newline at end of file +enabled=0 + + + +[dbgTableRow] +type=DecoREADTextLine +#xpath=.//pg:Cluster +xpath=.//pg:Cluster[@algo="(cut_I_agglo)"] +# xpath_fit_text_size indicate how to fit the text to the polygon. It is one of: x y xy +xpath_fit_text_size=36 +xpath_lxy=./pg:Coords/@points +xpath_content=concat(@row, " (", @rowSpan, ")") +xpath_font_color="BLUE" +xpath_incr="0" +enabled=0 + +[dbgTableCol] +type=DecoREADTextLine +# xpath=.//pg:Cluster +xpath=.//pg:Cluster[@algo="(cut_I_agglo)"] +# xpath_fit_text_size indicate how to fit the text to the polygon. It is one of: x y xy +xpath_fit_text_size=36 +xpath_lxy=./pg:Coords/@points +xpath_content=concat(@col, " (", @colSpan, ")") +xpath_font_color="RED" +xpath_incr="0" +enabled=0 + +[dbgTableRow_agglo] +type=DecoREADTextLine +xpath=.//pg:Cluster[@algo="agglo"] +# xpath_fit_text_size indicate how to fit the text to the polygon. 
It is one of: x y xy +xpath_fit_text_size=36 +xpath_lxy=./pg:Coords/@points +xpath_content=concat(@row, " (", @rowSpan, ")") +xpath_font_color="BLUE" +xpath_incr="0" +enabled=1 + +[dbgTableCol_agglo] +type=DecoREADTextLine +xpath=.//pg:Cluster[@algo="cut"] +# xpath_fit_text_size indicate how to fit the text to the polygon. It is one of: x y xy +xpath_fit_text_size=36 +xpath_lxy=./pg:Coords/@points +xpath_content=concat(@col, " (", @colSpan, ")") +xpath_font_color="RED" +xpath_incr="0" diff --git a/TranskribusDU/xml_formats/DS2PageXml.py b/TranskribusDU/xml_formats/DS2PageXml.py index f47dfb3..0f16b95 100644 --- a/TranskribusDU/xml_formats/DS2PageXml.py +++ b/TranskribusDU/xml_formats/DS2PageXml.py @@ -107,7 +107,8 @@ def DSPoint2PagePoints(self,sPoints): 451,246 451,1094 781,1094 781,246 """ - lPoints = [x for xx in sPoints.split(' ') for x in xx.split(',')] + #print (sPoints) + lPoints = sPoints.replace(",", " ").split() lx= list(map(lambda x:1.0*float(x)*self.dpi/72.0, lPoints)) # order left right xx = list(zip(lx[0::2], lx[1::2])) @@ -163,7 +164,7 @@ def convertDSObject(self,DSObject,pageXmlParentNode): if DSObject.hasAttribute('points'): coordsNode.set('points',self.DSPoint2PagePoints(DSObject.getAttribute('points'))) else: - coordsNode.set('points', self.BB2Polylines(DSObject.getX(),DSObject.getY(), DSObject.getHeight(),DSObject.getWidth())) + coordsNode.set('points', self.BB2Polylines(DSObject.getX(),DSObject.getY(), DSObject.getHeight(),DSObject.getWidth())) domNode.append(coordsNode) for attr in ['custom', 'structure','col','type','DU_row','DU_header','DU_col']: @@ -288,7 +289,7 @@ def run(self,domDoc): conversion """ ODoc =XMLDSDocument() -# ODoc.lastPage=1 + # ODoc.lastPage=1 ODoc.loadFromDom(domDoc) lPageXmlDoc=[] lPages= ODoc.getPages() diff --git a/TranskribusDU/xml_formats/Page2DS.py b/TranskribusDU/xml_formats/Page2DS.py index 2a7c6ab..e85f93e 100644 --- a/TranskribusDU/xml_formats/Page2DS.py +++ b/TranskribusDU/xml_formats/Page2DS.py @@ -104,10 +104,10 @@ def regionBoundingBox(self,sList): lList = sList.split(' ') for x,y in [x.split(',') for x in lList]: - minx = min(minx,float(x)) - maxx = max(maxx,float(x)) - miny = min(miny,float(y)) - maxy = max(maxy,float(y)) + minx = min(minx,int(x)) + maxx = max(maxx,int(x)) + miny = min(miny,int(y)) + maxy = max(maxy,int(y)) return [minx,miny,maxy-miny,maxx-minx] def regionBoundingBox2010(self,lList): @@ -350,24 +350,7 @@ def getTable(self,tableNode): return dstable - def copyEdge(self,child): - """ - - """ - node = etree.Element('EDGE') - node.set('src',child.get('src')) - node.set('tgt',child.get('tgt')) - node.set('type',child.get('type')) - node.set('w',child.get('proba')) - node.set('label',child.get('label')) - lPoints = child.get('points') - lP = lPoints.split(' ') - if lP != []: - scaledP= [ list(map(lambda x: 72.0* float(x) / self.dpi , xy.split(','))) for xy in lP] - scaledP = " ".join([ "%.2f,%.2f"% (x,y) for (x,y) in scaledP]) - node.set('points',scaledP) - return node - + def createRegion(self,pnode): """ create REGION @@ -498,8 +481,6 @@ def convert2DS(self,mprimedoc,sDocID): imageHeight = 72 * (float(ipage.get("imageHeight")) / self.dpi) page.set("width",str(imageWidth)) page.set("height",str(imageHeight)) - page.set("imageWidth",str(imageWidth)) - page.set("imageHeight",str(imageHeight)) self.convertPage(ipage, page) self.addTagProcessToMetadata(dsdom) @@ -539,8 +520,6 @@ def run(self): imageHeight = 72 * (float(ipage.get("imageHeight")) / self.dpi) page.set("width",str(imageWidth)) page.set("height",str(imageHeight)) -
page.set("imageWidth",str(imageWidth)) - page.set("imageHeight",str(imageHeight)) imgNode = etree.Element("IMAGE") imgNode.set("href",ipage.get("imageFilename")) imgNode.set("x","0") @@ -549,8 +528,8 @@ def run(self): imgNode.set("width",str(imageWidth)) page.append(imgNode) self.convertPage(ipage, page) -# except StopIteration, e: -# traceln("=== done.") + + self.addTagProcessToMetadata(dsdom) return dsdom diff --git a/TranskribusDU/xml_formats/PageXml.py b/TranskribusDU/xml_formats/PageXml.py index a806db8..ca9d115 100644 --- a/TranskribusDU/xml_formats/PageXml.py +++ b/TranskribusDU/xml_formats/PageXml.py @@ -9,6 +9,10 @@ @author: meunier ''' + + + + import os import datetime from copy import deepcopy @@ -128,18 +132,25 @@ def setMetadata(cls, doc, domNd, Creator, Comments=None): return the Metadata DOM node """ ndMetadata, ndCreator, _ndCreated, ndLastChange, ndComments = cls._getMetadataNodes(doc, domNd) - ndCreator.text = Creator + if bool(Creator): + if ndCreator.text: + ndCreator.text = ndCreator.text + "\n" + Creator + else: + ndCreator.text = Creator #The schema seems to call for GMT date&time (IMU) #ISO 8601 says: "If the time is in UTC, add a Z directly after the time without a space. Z is the zone designator for the zero UTC offset." #Python seems to break the standard unless one specifies properly a timezone by sub-classing tzinfo. But too complex stuff #So, I simply add a 'Z' ndLastChange.text = datetime.datetime.utcnow().isoformat()+"Z" - if Comments != None: + if bool(Comments): ## if not ndComments: #we need to add one! ## FutureWarning: The behavior of this method will change in future versions. Use specific 'len(elem)' or 'elem is not None' test instead. if ndComments is None : #we need to add one! ndComments = etree.SubElement(ndMetadata, cls.sCOMMENTS_ELT) - ndComments.text = Comments + if bool(ndComments.text): + ndComments.text = ndComments.text + "\n" + Comments + else: + ndComments.text = Comments return ndMetadata setMetadata = classmethod(setMetadata) @@ -346,7 +357,12 @@ def _getMetadataNodes(cls, doc=None, domNd=None): nd4 = nd3.getnext() if nd4 is not None: - if etree.QName(nd4.tag).localname not in [cls.sCOMMENTS_ELT,cls.sTranskribusMetadata_ELT] : raise ValueError("PageXMl mal-formed Metadata: LastChange element must be 3rd element") + if etree.QName(nd4.tag).localname not in [cls.sCOMMENTS_ELT,cls.sTranskribusMetadata_ELT] : raise ValueError("PageXMl mal-formed Metadata: expected a Transkribus metadata or some comment as 4th element") + if etree.QName(nd4.tag).localname == cls.sTranskribusMetadata_ELT: + nd4 = nd4.getnext() + if nd4 is not None: + if etree.QName(nd4.tag).localname != cls.sCOMMENTS_ELT : raise ValueError("PageXMl mal-formed Metadata: expected a comment element") + return domNd, nd1, nd2, nd3, nd4 _getMetadataNodes = classmethod(_getMetadataNodes) @@ -651,29 +667,44 @@ def _iter_splitMultiPageXml(cls, doc, bInPlace=True): #to jump to the PAGE sibling node (we do it now, defore possibly unlink...) 
node = metadataNd.getnext() + xmlPAGERoot.append(metadataNd) +# node = metadataNd.getnext() + xmlPAGERoot.append(node) + + """ + Hervé 28/05/2019: I comment since I don't understand + """ # #Add a copy of the METADATA node and sub-tree - if bInPlace: - metadataNd.getparent().remove(metadataNd) - xmlPAGERoot.append(metadataNd) - else: - newMetadataNd=deepcopy(metadataNd) - xmlPAGERoot.append(newMetadataNd) +# if bInPlace: +# # metadataNd.unlinkNode() +# metadataNd.getparent().remove(metadataNd) +# newRootNd.append(metadataNd) +# else: +# # newMetadataNd = metadataNd.copyNode(1) +# newMetadataNd=deepcopy(metadataNd) +# metadataNd.getparent().remove(metadataNd) +# newRootNd.append(newMetadataNd) # #jump to the PAGE sibling node # node = metadataNd.next - + while node is not None: # if node.type == "element": break # node = node.next if node.tag != etree.Comment: break node = node.getnext() if etree.QName(node.tag).localname != "Page": raise ValueError("Input multi-page PageXml for page %d should have a PAGE node after the METADATA node."%pnum) + #Add a copy of the PAGE node and sub-tree if bInPlace: - xmlPAGERoot.append(node) +# node.unlinkNode() +# newNode = newRootNd.addChild(node) + newRootNd.append(node) newNode= node else: +# newPageNd = node.copyNode(1) +# newNode = newRootNd.addChild(newPageNd) newNode = deepcopy(node) newRootNd.append(newNode) #Remove the prefix on the "id" attributes @@ -691,10 +722,8 @@ def _iter_splitMultiPageXml(cls, doc, bInPlace=True): # ctxt.xpathFreeContext() # for doc in lDocToBeFreed: doc.freeDoc() - raise StopIteration + return _iter_splitMultiPageXml = classmethod(_iter_splitMultiPageXml) - - # --- Metadata of PageXml -------------------------------- class Metadata: @@ -785,4 +814,4 @@ def __init__(self, Creator, Created, LastChange, Comments=None): print ("\t done: %s"%filename) print ("DONE") - \ No newline at end of file + diff --git a/TranskribusDU/xml_formats/PageXmlExtractor.py b/TranskribusDU/xml_formats/PageXmlExtractor.py index 63b7909..fb3a222 100755 --- a/TranskribusDU/xml_formats/PageXmlExtractor.py +++ b/TranskribusDU/xml_formats/PageXmlExtractor.py @@ -9,6 +9,10 @@ @author: meunier ''' + + + + import os from io import open import json @@ -70,7 +74,7 @@ def iterPageNumber(self): for a,b in self._ltiRange: for n in range(a,b+1): yield n - raise StopIteration + return # ----- def __str__(self): diff --git a/TranskribusDU/xml_formats/mpxml2pxml.py b/TranskribusDU/xml_formats/mpxml2pxml.py index fe69bb5..492df9e 100644 --- a/TranskribusDU/xml_formats/mpxml2pxml.py +++ b/TranskribusDU/xml_formats/mpxml2pxml.py @@ -31,18 +31,22 @@ (options, args) = parser.parse_args() try: - dir = args[0] - docid= args[1] + _dir = args[0] + _docid= args[1] except: parser.print_help() parser.exit(1, "") - sDocFilename = "%s%scol%s%s.mpxml" % (dir,os.sep,os.sep,docid) + sDocFilename = "%s%scol%s%s.mpxml" % (_dir,os.sep,os.sep,_docid) doc = etree.parse(sDocFilename) - for pnum, pageDoc in PageXml.MultiPageXml._iter_splitMultiPageXml(doc, bInPlace=False): - outfilename = "%s%s%s%s%s_%03d.pxml" % (dir,os.sep,options.destdir,os.sep,docid,pnum) + ## sDocFilename = "%s%scol%s%s.bar_mpxml" % (_dir,os.sep,os.sep,_docid) + ## doc = etree.parse(sDocFilename) + ## for pnum, pageDoc in PageXml.MultiPageXml._iter_splitMultiPageXml(doc, bInPlace=False): + for pnum, pageDoc in PageXml.MultiPageXml._iter_splitMultiPageXml(doc, bInPlace=True): + outfilename = "%s%s%s%s%s_%03d.pxml" % (_dir,os.sep,options.destdir,os.sep,_docid,pnum) print(outfilename) pageDoc.write(outfilename, 
xml_declaration ='UTF-8',encoding="utf-8", pretty_print = bool(options.bIndent)) + doc.freeDoc() print ("DONE") \ No newline at end of file diff --git a/TranskribusDU/xml_formats/tests/test_PageXml.py b/TranskribusDU/xml_formats/tests/test_PageXml.py index 6356f17..46413e8 100644 --- a/TranskribusDU/xml_formats/tests/test_PageXml.py +++ b/TranskribusDU/xml_formats/tests/test_PageXml.py @@ -149,6 +149,48 @@ def test_CreationPageXmlDocument(): doc= PageXml.createPageXmlDocument(creatorName='HerveforTest', filename='hervefortest.jpg', imgW=100, imgH=100) print(doc) +def test_countTextLineWithText(): + sXml = b""" + + + Tilla + 2016-08-18T13:35:08.252+07:00 + 2016-12-01T09:53:39.610+01:00 + + + + + + + + + 52. + + + + + + + + + + + + + + + + + + + + + """ + doc = etree.parse(BytesIO(sXml)) + + assert (1, 2) == PageXml.countTextLineWithText(doc) + return doc + if __name__ == "__main__": test_setMetadata() test_CreationPageXmlDocument() \ No newline at end of file diff --git a/usecases/ABP/src/ABPCSV2XML.py b/usecases/ABP/src/ABPCSV2XML.py index 069b5ca..db2aa31 100644 --- a/usecases/ABP/src/ABPCSV2XML.py +++ b/usecases/ABP/src/ABPCSV2XML.py @@ -13,18 +13,7 @@ copyright Xerox 2017 READ project - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. The READ project has received funding diff --git a/usecases/ABP/src/ABPIEOntology.py b/usecases/ABP/src/ABPIEOntology.py index 35e7fed..3565e6c 100644 --- a/usecases/ABP/src/ABPIEOntology.py +++ b/usecases/ABP/src/ABPIEOntology.py @@ -1,28 +1,15 @@ # -*- coding: utf-8 -*- """ - ABP records IEOntology - + ABP Death records IEOntology Hervé Déjean cpy Xerox 2017, NLE 2017 death record - wedding record (for test) READ project - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - You should have received a copy of the GNU General Public License - along with this program. If not, see . Developed for the EU project READ. The READ project has received funding @@ -40,68 +27,6 @@ from lxml import etree - - -class BaptismRecord(recordClass): - """ - name (firstname only possible) +death date sometimes - hebamme , was the birth easy , date? - vater + info - mutter + info - location ; instead date of the uxor! 
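
Two notes on the xml_formats hunks above. The raise StopIteration -> return edits in _iter_splitMultiPageXml and iterPageNumber are required on Python 3.7+: under PEP 479, a StopIteration escaping a generator body is re-raised as RuntimeError instead of silently ending the iteration. Also, the doc.freeDoc() call added to mpxml2pxml.py is a libxml2-era API: an lxml tree has no freeDoc method and is garbage-collected automatically, so that line will raise AttributeError and can simply be dropped. Minimal illustration of the generator rule (hypothetical generators):

    def gen_old(items):
        for x in items:
            yield x
        raise StopIteration    # RuntimeError once exhausted, on Python 3.7+ (PEP 479)

    def gen_new(items):
        for x in items:
            yield x
        return                 # the correct way to end a generator
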
- birth date - baptism date - priester + info - """ - -class weddingRecord(recordClass): - sName = 'weddingrecord' - def __init__(self,sModelName,sModelDir): - recordClass.__init__(self,deathRecord.sName) - - myTagger = ABPTagger() - myTagger.loadResources(sModelName ,sModelDir ) - - #bride - bfnField = firstNameField() - bfnField.setLabelMapping( ['firstNameGenerator']) - bfnField.addTagger(myTagger) - bfnField.setMandatory() - self.addField(bfnField) - - gnfield = lastNameField() - gnfield.addTagger(myTagger) - gnfield.setLabelMapping(['lastNameGenerator']) - gnfield.setMandatory() - self.addField(gnfield) - - #groom - gfnField = firstNameField() - gfnField.setLabelMapping( ['firstNameGenerator']) - gfnField.addTagger(myTagger) - gfnField.setMandatory() - self.addField(gfnField) - - gnfield = lastNameField() - gnfield.addTagger(myTagger) - gnfield.setLabelMapping(['lastNameGenerator']) - gnfield.setMandatory() - self.addField(gnfield) - - lfield= locationField() - lfield.addTagger(myTagger) - lfield.setLabelMapping(['locationGenerator']) - self.addField(lfield) - - wDate= weddingDate() - wDate.addTagger(myTagger) -# dDate.setLabelMapping(['weekDayDateGenerator','MonthDayDateGenerator','MonthDateGenerator']) - xDate.setLabelMapping(['MonthDateGenerator']) - self.addField(dDate) - - - - class deathRecord(recordClass): sName = 'deathrecord' def __init__(self,sModelName,sModelDir): @@ -130,7 +55,7 @@ def __init__(self,sModelName,sModelDir): lfield= locationField() lfield.addTagger(myTagger) - lfield.setLabelMapping(['location2Generator']) + lfield.setLabelMapping(['locationGenerator']) self.addField(lfield) ofield= occupationField() @@ -142,45 +67,28 @@ def __init__(self,sModelName,sModelDir): sfield.addTagger(myTagger) sfield.setLabelMapping(['familyStatus']) self.addField(sfield) -# - - # specific tagger for dates ? +# dDate= deathDate() dDate.addTagger(myTagger) # dDate.setLabelMapping(['weekDayDateGenerator','MonthDayDateGenerator','MonthDateGenerator']) dDate.setLabelMapping(['MonthDateGenerator']) self.addField(dDate) - ddDate= deathDateDay() - ddDate.addTagger(myTagger) -# dDate.setLabelMapping(['weekDayDateGenerator','MonthDayDateGenerator','MonthDateGenerator']) - ddDate.setLabelMapping(['MonthDayDateGenerator']) - self.addField(ddDate) - bDate= burialDate() bDate.addTagger(myTagger) # bDate.setLabelMapping(['weekDayDateGenerator','MonthDayDateGenerator','MonthDateGenerator']) bDate.setLabelMapping(['MonthDateGenerator']) self.addField(bDate) - year=deathYear() - year.addTagger(myTagger) - year.setLabelMapping(['yearGenerator']) - self.addField(year) agefield=age() agefield.addTagger(myTagger) agefield.setLabelMapping(['ageValueGenerator']) self.addField(agefield) - - ageUnitfield=ageUnit() - ageUnitfield.addTagger(myTagger) - ageUnitfield.setLabelMapping(['AgeUnitGenerator']) - self.addField(ageUnitfield) blfield= burialLocation() blfield.addTagger(myTagger) - blfield.setLabelMapping(['location2Generator']) + blfield.setLabelMapping(['locationGenerator']) self.addField(blfield) reasonField = deathreasonField() @@ -193,29 +101,6 @@ def __init__(self,sModelName,sModelDir): drField.setLabelMapping(['lastNameGenerator']) #lastNameGenerator self.addField(drField) -# def decoratePageXml(self): -# """ -# ONGOING.... 
-# add in @custom the field name -# -# -# -# currenlty -# """ -# lPages={} -# for cand in self.getCandidates(): -# try:lPages[cand.getPage()].append(cand) -# except:lPages[cand.getPage()]=[cand] -# -# for page in sorted(lPages): -# sortedRows = lPages[page] -# sortedRows.sort(key=lambda x:int(x.getIndex())) -# for cand in sortedRows: -# for field in cand.getAllFields(): -# if field.getName() is not None and field.getBestValue() is not None: -# print (field, field.getOffset() - def generateOutput(self,outDom): """ generateOutput @@ -252,10 +137,7 @@ def generateOutput(self,outDom): key=key[2:] domp.set('pagenum',key) - ## -> page has now a year attribute (X-X) - if page.getAttribute('computedyear') is None: - page.addAttribute('computedyear','') - domp.set('years',str(page.getAttribute('computedyear'))) + domp.set('years','NA') root.append(domp) sortedRows = lPages[page] sortedRows.sort(key=lambda x:int(x.getIndex())) @@ -264,17 +146,13 @@ def generateOutput(self,outDom): record = etree.Element('RECORD') # record fields nbRecords = 0 - lSeenField=[] for field in cand.getAllFields(): - # take the best one if field.getName() is not None and field.getBestValue() is not None: - record.set(field.getName().lower(),field.getBestValue()) - lSeenField.append(field.getName().lower()) - nbRecords=1 - elif field.getName().lower() not in lSeenField:record.set(field.getName().lower(),"") + record.set(field.getName(),field.getBestValue()) + nbRecords+=1 if nbRecords > 0: domp.append(record) - domp.set('nbrecords',str(len(domp))) + domp.set('nbrecords',str(nbRecords)) return outDom @@ -287,27 +165,12 @@ class locationField(fieldClass): sName='location' def __init__(self): fieldClass.__init__(self, locationField.sName) - - -class weddingDate(fieldClass): - sName='weddingDate' - def __init__(self): - fieldClass.__init__(self, weddingDate.sName) - -class deathYear(fieldClass): - sName='deathYear' - def __init__(self): - fieldClass.__init__(self, deathYear.sName) + class deathDate(fieldClass): sName='deathDate' def __init__(self): fieldClass.__init__(self, deathDate.sName) - -class deathDateDay(fieldClass): - sName='MonthDayDateGenerator' - def __init__(self): - fieldClass.__init__(self, deathDateDay.sName) - + class burialDate(fieldClass): sName='burialDate' def __init__(self): @@ -322,12 +185,6 @@ class age(fieldClass): sName='age' def __init__(self): fieldClass.__init__(self, age.sName) - -class ageUnit(fieldClass): - sName='ageUnit' - def __init__(self): - fieldClass.__init__(self, ageUnit.sName) - class firstNameField(fieldClass): sName = 'firstname' diff --git a/usecases/ABP/src/ABPRefFromBigCVS.py b/usecases/ABP/src/ABPRefFromBigCVS.py index de25e9e..f35b59f 100644 --- a/usecases/ABP/src/ABPRefFromBigCVS.py +++ b/usecases/ABP/src/ABPRefFromBigCVS.py @@ -11,18 +11,7 @@ copyright NLE 2017 READ project - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. 
The READ project has received funding diff --git a/usecases/ABP/src/ABPResourceGeneration.py b/usecases/ABP/src/ABPResourceGeneration.py index bc43c40..438595a 100644 --- a/usecases/ABP/src/ABPResourceGeneration.py +++ b/usecases/ABP/src/ABPResourceGeneration.py @@ -10,18 +10,7 @@ copyright Naverlabs 2017 READ project - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. The READ project has received funding diff --git a/usecases/ABP/src/ABPWorkflow.py b/usecases/ABP/src/ABPWorkflow.py index d2ae22a..57348ff 100644 --- a/usecases/ABP/src/ABPWorkflow.py +++ b/usecases/ABP/src/ABPWorkflow.py @@ -12,18 +12,7 @@ copyright Naver LAbs Europe 2017 READ project - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. The READ project has received funding diff --git a/usecases/ABP/src/ABP_IE.py b/usecases/ABP/src/ABP_IE.py index 77034aa..8c9a654 100644 --- a/usecases/ABP/src/ABP_IE.py +++ b/usecases/ABP/src/ABP_IE.py @@ -11,18 +11,7 @@ READ project - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - - You should have received a copy of the GNU General Public License - along with this program. If not, see . + Developed for the EU project READ. The READ project has received funding @@ -134,36 +123,6 @@ def labelTable(self,table): - def findNameColumn(self,table): - """ - find the column which corresponds to the people names c - """ - self.bDebug=False - #tag fields with template - lColPos = {} - lColInvName = {} - for cell in table.getCells(): - try:lColPos[cell.getIndex()[1]] - except: lColPos[cell.getIndex()[1]]=[] - if cell.getIndex()[1] < 5: - res = field.applyTaggers(cell) - for field in cell.getFields(): - if field is not None: - # res [ (token,label,score) ...] 
- extractedValues = field.extractLabel(res) - if extractedValues != []: - # extractedValues = map(lambda offset,value,label,score:(value,score),extractedValues) - extractedValues = list(map(lambda x:(x[1],x[3]),extractedValues)) - field.setOffset(res[0]) - field.setValue(extractedValues) - # field.addValue(extractedValues) - lColPos[cell.getIndex()[1]].append(field.getName()) - try:lColInvName[field.getName()].append(cell.getIndex()[1]) - except: lColInvName[field.getName()] = [cell.getIndex()[1]] - if self.bDebug: print ('foundXX:',field.getName(), field.getValue()) - cell.resetField() - return max(lColInvName['firstname'],key=lColInvName['firstname'].count) - def extractData(self,table,myRecord, lTemplate): """ layout @@ -178,28 +137,26 @@ def extractData(self,table,myRecord, lTemplate): find layout level for record completion extract data/record -inference if IEOnto + + """ -# self.bDebug = False -# table.buildNDARRAY() + self.bDebug = False + table.buildNDARRAY() if lTemplate is not None: # convert string to tableTemplateObject template = tableTemplateClass() template.buildFromPattern(lTemplate) template.labelTable(table) else: return None -# firstNameColIndex =self.findNameColumn(table) - - # create a batch for the full page #tag fields with template for cell in table.getCells(): if cell.getFields() != []: if self.bDebug:print(table.getPage(),cell.getIndex(), cell.getFields(), cell.getContent()) - res = myRecord.applyTaggers(cell) for field in cell.getFields(): if field is not None: - #res = field.applyTaggers(cell) + res = field.applyTaggers(cell) # res [ (token,label,score) ...] extractedValues = field.extractLabel(res) if extractedValues != []: @@ -207,8 +164,8 @@ def extractData(self,table,myRecord, lTemplate): extractedValues = list(map(lambda x:(x[1],x[3]),extractedValues)) field.setOffset(res[0]) field.setValue(extractedValues) -# field.addValue(extractedValues) if self.bDebug: print ('found:',field, field.getValue()) + ### now at record level ? ### scope = propagation using only docObject (hardcoded ?) @@ -277,6 +234,19 @@ def testGTText(self,page): + def htrWithTemplate(self,table,template,htrModelId): + """ + perform an HTR with dictionaries specific to each column + + need: docid, pageid + """ + + # for the current column: need to get tablecells ids + # more efficient(?why more efficient?) 
to have it at column level: not cell ; so just after table template tool + for col in table.getColumns(): + lCellsID = map(lambda x:x.getID(),col.getCells()) + for id in lCellsID: print(id) + def mineTable(self,tabel,dr): @@ -298,66 +268,46 @@ def processWithTemplate(self,table,dr): """ # selection of the dictionaries per columns # template 5,10: first col = numbering + lTemplateIE2 = [ + ((slice(1,None),slice(0,1)) ,[ 'numbering'],[ dr.getFieldByName('numbering') ]) + , ((slice(1,None),slice(1,2)) ,[ 'abp_names', 'names_aux','numbering','religion'],[ dr.getFieldByName('lastname'), dr.getFieldByName('firstname'),dr.getFieldByName('religion') ]) + , ((slice(1,None),slice(2,3)) ,[ 'abp_profession','religion' ] ,[ dr.getFieldByName('occupation'), dr.getFieldByName('religion') ]) + , ((slice(1,None),slice(3,4)) ,[ 'abp_location' ] ,[ dr.getFieldByName('location') ]) + , ((slice(1,None),slice(4,5)) ,[ 'abp_family' ] ,[ dr.getFieldByName('situation') ]) + ,((slice(1,None),slice(5,6)) ,[ 'deathreason','artz'] ,[ dr.getFieldByName('deathreason'),dr.getFieldByName('doktor')]) + , ((slice(1,None),slice(6,7)) ,[] , [ ]) #binding + , ((slice(1,None),slice(7,8)) ,[ 'abp_dates' ] ,[ dr.getFieldByName('deathDate') ]) + , ((slice(1,None),slice(8,9)) ,[ 'abp_dates','abp_location' ] ,[ dr.getFieldByName('burialDate'),dr.getFieldByName('burialLocation') ]) + , ((slice(1,None),slice(9,10)) ,[ 'abp_age'] ,[ dr.getFieldByName('age')]) +# , ((slice(1,None),slice(9,10)) ,[ dr.getFieldByName('priester')]) +# , ((slice(1,None),slice(10,11)),[ dr.getFieldByName('notes')]) + ] - # find calibration column: abp_names - table.buildNDARRAY() -# print (self.findNameColumn(table)) -# lTemplateIE2 = [ -# ((slice(1,None),slice(0,1)) ,[ 'numbering'],[ dr.getFieldByName('numbering') ]) -# , ((slice(1,None),slice(1,2)) ,[ 'abp_names', 'names_aux','numbering','religion'],[ dr.getFieldByName('lastname'), dr.getFieldByName('firstname'),dr.getFieldByName('religion') ]) -# , ((slice(1,None),slice(2,3)) ,[ 'abp_profession','religion' ] ,[ dr.getFieldByName('occupation'), dr.getFieldByName('religion') ]) -# , ((slice(1,None),slice(3,4)) ,[ 'abp_location' ] ,[ dr.getFieldByName('location') ]) -# , ((slice(1,None),slice(4,5)) ,[ 'abp_family' ] ,[ dr.getFieldByName('situation') ]) -# ,((slice(1,None),slice(5,6)) ,[ 'deathreason','artz'] ,[ dr.getFieldByName('deathreason'),dr.getFieldByName('doktor')]) -# , ((slice(1,None),slice(6,7)) ,[] , [ ]) #binding -# , ((slice(1,None),slice(7,8)) ,['abp_dates', 'abp_dates' ,'abp_year'] ,[,dr.getFieldByName('deathDate'),dr.getFieldByName('deathYear') ]) -# , ((slice(1,None),slice(8,9)) ,[ 'abp_dates','abp_location' ] ,[ dr.getFieldByName('burialDate'),dr.getFieldByName('burialLocation') ]) -# , ((slice(1,None),slice(9,10)) ,[ 'abp_age','abp_ageunit'] ,[ dr.getFieldByName('age'), dr.getFieldByName('ageUnit')]) -# # , ((slice(1,None),slice(9,10)) ,[ dr.getFieldByName('priester')]) -# # , ((slice(1,None),slice(10,11)),[ dr.getFieldByName('notes')]) -# ] - - - - - #fuzzy - lTemplateIECAL = [ - ((slice(1,None),slice(0,4)) ,[ 'abp_names', 'names_aux','numbering','religion'],[ dr.getFieldByName('lastname'), dr.getFieldByName('firstname') ,dr.getFieldByName('religion')]) - , ((slice(1,None),slice(1,4)) ,[ 'abp_profession','religion' ] ,[ dr.getFieldByName('occupation'), dr.getFieldByName('religion') ]) - ] - - #detect empty left columns ? 
- template = tableTemplateClass() - template.buildFromPattern(lTemplateIECAL) - template.labelTable(table) - - iRef = self.findNameColumn(table) - print ("=============",iRef) lTemplateIE = [ - ((slice(1,None),slice(iRef,iRef+1)) ,[ 'abp_names', 'names_aux','numbering','religion'],[ dr.getFieldByName('lastname'), dr.getFieldByName('firstname') ,dr.getFieldByName('religion')]) - , ((slice(1,None),slice(iRef+1,iRef+2)) ,[ 'abp_profession','religion' ] ,[ dr.getFieldByName('occupation'), dr.getFieldByName('religion') ]) - , ((slice(1,None),slice(iRef+2,iRef+3)) ,[ 'abp_location' ] ,[ dr.getFieldByName('location') ]) - , ((slice(1,None),slice(iRef+3,iRef+4)) ,[ 'abp_family' ] ,[ dr.getFieldByName('situation') ]) + ((slice(1,None),slice(0,1)) ,[ 'abp_names', 'names_aux','numbering','religion'],[ dr.getFieldByName('lastname'), dr.getFieldByName('firstname') ,dr.getFieldByName('religion')]) + , ((slice(1,None),slice(1,2)) ,[ 'abp_profession','religion' ] ,[ dr.getFieldByName('occupation'), dr.getFieldByName('religion') ]) + , ((slice(1,None),slice(2,3)) ,[ 'abp_location' ] ,[ dr.getFieldByName('location') ]) + , ((slice(1,None),slice(3,4)) ,[ 'abp_family' ] ,[ dr.getFieldByName('situation') ]) #[] binding - # 4 6 - # 5 7 - # 6 8 - , ((slice(1,None),slice(iRef+4,iRef+6)) ,[ 'abp_deathreason','artz'] ,[ dr.getFieldByName('deathreason'),dr.getFieldByName('doktor')]) - , ((slice(1,None),slice(iRef+5,iRef+9)) ,[ 'abp_dates','abp_year' ] ,[ dr.getFieldByName('MonthDayDateGenerator'), dr.getFieldByName('deathDate') ,dr.getFieldByName('deathYear')]) - , ((slice(1,None),slice(iRef+6,iRef+9)) ,[ 'abp_dates','abp_year','abp_location' ] ,[ dr.getFieldByName('burialDate'),dr.getFieldByName('deathYear'),dr.getFieldByName('burialLocation') ]) - , ((slice(1,None),slice(iRef+8,iRef+10)) ,[ 'abp_age','abp_ageunit'] ,[ dr.getFieldByName('age'), dr.getFieldByName('ageUnit')]) + , ((slice(1,None),slice(4,6)) ,[ 'deathreason','artz'] ,[ dr.getFieldByName('deathreason'),dr.getFieldByName('doktor')]) + , ((slice(1,None),slice(6,7)) ,[ 'abp_dates' ] ,[ dr.getFieldByName('deathDate') ]) + , ((slice(1,None),slice(7,8)) ,[ 'abp_dates','abp_location' ] ,[ dr.getFieldByName('burialDate'),dr.getFieldByName('burialLocation') ]) + , ((slice(1,None),slice(8,9)) ,[ 'abp_age'] ,[ dr.getFieldByName('age')]) # , ((slice(1,None),slice(9,10)) ,[ dr.getFieldByName('priester')]) # , ((slice(1,None),slice(10,11)),[ dr.getFieldByName('notes')]) - ] - # recalibrate template + ] -# # lTemplate = lTemplateIE -# if table.getNbColumns() >= 12: -# lTemplate = lTemplateIE2 -# else: -# lTemplate = lTemplateIE - self.extractData(table,dr,lTemplateIE) + +# lTemplate = lTemplateIE + if table.getNbColumns() == 12: + lTemplate = lTemplateIE2 + else: + lTemplate = lTemplateIE + +# if self.htrModelID is not None: self.htrWithTemplate(table, lTemplate, self.htrModelID) + + self.extractData(table,dr,lTemplate) # select best solutions # store inthe proper final format @@ -390,7 +340,7 @@ def run(self,doc): ### for page in self.lPages: - print("page: ", page.getNumber()) +# print("page: ", page.getNumber()) # self.testGTText(page) # continue lTables = page.getAllNamedObjects(XMLDSTABLEClass) @@ -401,11 +351,9 @@ def run(self,doc): continue if self.BuseStoredTemplate: self.processWithTemplate(table, dr) - #try:self.processWithTemplate(table, dr) - #except: print('issue with page %s'%page) else: self.mineTable(table,dr) - + self.evalData = dr.generateOutput(self.evalData) # print self.evalData.serialize('utf-8',True) @@ -536,7 +484,7 @@ def 
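
For readers of the lTemplateIE/lTemplateIE2 structures above: each entry is a triple (cell slice, lexicon names, record fields), where the first element is a pair of Python slice objects applied to the 2-D cell array built by table.buildNDARRAY() (all body rows of one column span), the string lists name the tagging lexicons, and the fields receive the extracted values. A minimal illustration of the slice part (a hypothetical stand-in for the cell grid):

    import numpy as np

    cells = np.arange(20).reshape(5, 4)        # 5 rows x 4 columns of "cells"
    body_col1 = (slice(1, None), slice(1, 2))  # same shape as in lTemplateIE
    print(cells[body_col1])                    # every row but the header, column 1
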
@@ -536,7 +484,7 @@ def testFirstNameLastNameRecord(self,srefData,srunData, bVisual):
                 lCovered=[]
                 for a,i in enumerate(r2):
                     # print (key,a,r1[a],i,rows[r1[a]][2],cols[i][2], 1/cost_matrix[r1[a],i])
-                    if 1 / cost_matrix[r1[a],i] > lcsTH:
+                    if 1 / cost_matrix[r1[a,],i] > lcsTH:
                         cntOk += 1
                         if bT:
                             ltisRefsRunbErrbMiss.append( (runElt[1],int(runElt[0]), cols[i], rows[r1[a]],False, False) )
@@ -622,8 +570,8 @@ def testRecordField(self,lfieldName,lfieldInRef,srefData,srunData, bVisual):
             key=page.get('pagenum')
             xpath = "./%s" % ("RECORD")
             lrecord = page.xpath(xpath)
-            if len(lrecord)==0:
-                lRef.append([])
+            if len(lrecord) == 0:
+                pass
             else:
                 for record in lrecord:
                     lf =[]
@@ -835,8 +783,6 @@ def testCompare(self, srefData, srunData, bVisual=False):
         dicTestByTask['location']= self.testRecordField(['location'],[None],srefData, srunData,bVisual)
         dicTestByTask['deathreason']= self.testRecordField(['deathreason'],[None],srefData, srunData,bVisual)
         dicTestByTask['names']= self.testRecordField(['firstname','lastname'],[None,None],srefData, srunData,bVisual)
-        dicTestByTask['doktor']= self.testRecordField(['doktor'],['helfer_name'],srefData, srunData,bVisual)
-        # dicTestByTask['namedeathlocationoccupation']= self.testRecordField(['firstname','lastname','deathreason','location','occupation'],[None,None,None,None,None],srefData, srunData,bVisual)
         dicTestByTask['situation']= self.testRecordField(['situation'],['family'],srefData, srunData,bVisual)
         # dicTestByTask['Year']= self.testYear(srefData, srunData,bVisual)
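The record-level tests above align run records to reference records one-to-one over a cost matrix, and keep a pair only when its similarity (1/cost, an LCS-based score here) clears the lcsTH threshold. The same pattern sketched with scipy's assignment solver and difflib as a stand-in for the original LCS similarity:

import numpy as np
from difflib import SequenceMatcher
from scipy.optimize import linear_sum_assignment

def match_records(lRef, lRun, fTH=0.8):
    # similarity in (0,1]; the epsilon avoids division by zero in the cost matrix
    aSim = np.array([[max(SequenceMatcher(None, ref, run).ratio(), 1e-6)
                      for run in lRun] for ref in lRef])
    r1, r2 = linear_sum_assignment(1.0 / aSim)   # minimise total cost
    # keep only confident pairs, as the lcsTH test does above
    return [(i, j) for i, j in zip(r1, r2) if aSim[i, j] > fTH]

match_records(["Maria Huber", "Johann Mayr"], ["Joh. Mayr", "Maria Hueber"])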
diff --git a/usecases/ABP/src/PageCellToRegion.py b/usecases/ABP/src/PageCellToRegion.py
index deabdaa..f622ace 100644
--- a/usecases/ABP/src/PageCellToRegion.py
+++ b/usecases/ABP/src/PageCellToRegion.py
@@ -41,8 +41,8 @@ def __init__(self):
         self.sPttrn = None
         self.dpi = 300
-#       self.xmlns='http://schema.primaresearch.org/PAGE/gts/pagecontent/2013-07-15'
-        self.ns={'a':PageXml.NS_PAGE_XML}
+        self.xmlns='http://schema.primaresearch.org/PAGE/gts/pagecontent/2013-07-15'
+        self.ns={'a':self.xmlns}
         self.id=1
         self.HeightTH=0.5
@@ -63,7 +63,7 @@ def resizeCell(self,cell,ns):
         replace the cell region by a BB for textlines: better for transcriber
         """
         xpath  = "./a:%s" % ("TextLine")
-        lTextLines = cell.xpath(xpath, namespaces={'a':PageXml.NS_PAGE_XML})
+        lTextLines = cell.xpath(xpath, namespaces={'a':self.xmlns})
         if lTextLines == []:
             return True
@@ -149,16 +149,16 @@ def convertTableCells(self,document):
         xpath  = "//a:%s" % ("ReadingOrder")
         lRO = document.getroot().xpath(xpath,namespaces = self.ns)
         if lRO == []:
-            ro = PageXml.createPageXmlNode('ReadingOrder')
+            ro = PageXml.createPageXmlNode('ReadingOrder', self.xmlns)
             #addPrevSibling
         else:
             ro =lRO[0]
         for table in lTables:
-            orderGroup = PageXml.createPageXmlNode('OrderedGroup')
+            orderGroup = PageXml.createPageXmlNode('OrderedGroup',self.xmlns)
             ro.append(orderGroup)
-            orderGroup.set('{%s}id'%PageXml.NS_PAGE_XML,table.get('id'))
-            orderGroup.set('{%s}caption'%PageXml.NS_PAGE_XML,'Cell2TextRegion')
+            orderGroup.set('{%s}id'%self.xmlns,table.get('id'))
+            orderGroup.set('{%s}caption'%self.xmlns,'Cell2TextRegion')
             xpath  = "./a:%s" % ("TableCell")
             lCells = table.xpath(xpath,namespaces = self.ns)
@@ -169,7 +169,7 @@ def convertTableCells(self,document):
                 # cell.unlinkNode()
                 # print cell
                 table.getparent().append(cell)
-                cell.tag = '{%s}TextRegion'%(PageXml.NS_PAGE_XML)
+                cell.tag = '{%s}TextRegion'%(self.xmlns)
                 cell.set('custom',"readingOrder {index:%d;}"%i)
                 # delete cell props
                 for propname in ['row','col','rowSpan','colSpan']:
@@ -181,10 +181,10 @@ def convertTableCells(self,document):
                 lCorner = cell.xpath(xpath,namespaces = self.ns)
                 for c in lCorner:
                     c.getparent().remove(c)
-                reind = PageXml.createPageXmlNode('RegionRefIndexed')
+                reind = PageXml.createPageXmlNode('RegionRefIndexed', self.xmlns)
                 orderGroup.append(reind)
-                reind.set('{%s}index'%PageXml.NS_PAGE_XML,str(i))
-                reind.set('{%s}regionRef'%PageXml.NS_PAGE_XML,cell.get('id'))
+                reind.set('{%s}index'%self.xmlns,str(i))
+                reind.set('{%s}regionRef'%self.xmlns,cell.get('id'))
                 ## resize cell/region:
                 if self.resizeCell(cell,self.ns):
diff --git a/usecases/ABP/src/contentGenerator.py b/usecases/ABP/src/contentGenerator.py
index 78ccabe..67a4a4a 100644
--- a/usecases/ABP/src/contentGenerator.py
+++ b/usecases/ABP/src/contentGenerator.py
@@ -11,18 +11,7 @@
     copyright NLE 2017
     READ project
-    This program is free software: you can redistribute it and/or modify
-    it under the terms of the GNU General Public License as published by
-    the Free Software Foundation, either version 3 of the License, or
-    (at your option) any later version.
-    This program is distributed in the hope that it will be useful,
-    but WITHOUT ANY WARRANTY; without even the implied warranty of
-    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-    GNU General Public License for more details.
-
-    You should have received a copy of the GNU General Public License
-    along with this program. If not, see <http://www.gnu.org/licenses/>.
     Developed for the EU project READ. The READ project has received funding
@@ -68,7 +57,7 @@ def generate(self):
 class AgeUnitGenerator(textGenerator):
     def __init__(self):
         textGenerator.__init__(self,lang=None)
-        self.loadResourcesFromList( [[('Jahre',50),('Ja',10),('Monate',10),('M.',10),('W',5),('Wochen',5),('T',3),('Tag',6),('Stunde',10)]])
+        self.loadResourcesFromList( [[('Jahre',50),('Ja',10),('Monate',20),('Wochen',10),('Tag',10),('Stunde',10)]])
 class ageValueGenerator(integerGenerator):
     """
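The (token, weight) pairs handed to loadResourcesFromList, as in the rebalanced AgeUnitGenerator above, read naturally as relative sampling frequencies. Assuming that reading (inferred from usage, not from the textGenerator code), the behaviour boils down to:

import random

lAgeUnits = [('Jahre',50),('Ja',10),('Monate',20),('Wochen',10),('Tag',10),('Stunde',10)]

def sample_tokens(lWeighted, k=1):
    # draw with replacement, proportionally to the weights (they need not sum to 100)
    lTokens  = [tok for tok, w in lWeighted]
    lWeights = [w   for tok, w in lWeighted]
    return random.choices(lTokens, weights=lWeights, k=k)

print(sample_tokens(lAgeUnits, k=5))   # e.g. ['Jahre', 'Monate', 'Jahre', 'Tag', 'Wochen']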
@@ -108,44 +97,18 @@ def generate(self):
         return Generator.generate(self)
 class legitimGenerator(textGenerator):
-    """
-    ID   name            namelabel  kuerzel
-    1_1  legitim         leg.       l
-    1_2  illegitim       ill.       i
-    1_3  adoptiert       adopt.     a
-    1_4  durch nachfolge p.m.s.l.   vor
-
-    """
     def __init__(self):
         textGenerator.__init__(self,lang=None)
#        self._value = ['leg','legitim','illeg','illegitim']
         self.loadResourcesFromList( [[('leg',60),('legitim',20),('illeg',10),('illegitim',20)]])
 class religionGenerator(textGenerator):
-    """
-    2_1 katholisch     kath.      rk
-2_2 evangelisch    ev.        ev
-2_3 orthodox       orth.      or
-2_4 sonstige       sonst.     ss
-2_5 altkatholisch  altkath.   alt
-2_6 christlich     christlich ch
-2_7 Konvertit      Konvertit  kon
-2_8 protestantisch prot.      pr
-
-    """
     def __init__(self):
         textGenerator.__init__(self,lang=None)
         self.loadResourcesFromList( [[('K',30),('kath',40),('katholic',5),('katho',5),('K. R.',5),("evangelist",5),('evang.',5),("evg.",5)]])
#        self._value = ['k','kath','katholic','katho','k. R.','evangelist','evang.','evg.']
 class familyStatus(textGenerator):
-    """
-    3_1 ledig       ledig  ld
-3_2 verheiratet verh.  vh
-3_3 verwitwet   verw.  vw
-
-    children not covered
-    """
     def __init__(self):
         textGenerator.__init__(self,lang=None)
         self.loadResourcesFromList( [[('knabe',5),('mädchen',5),('kind',30),('Säugling',5),('ledig',20), ('verehelichet.',10),('erehelicht',10),('wittwe',20),('wittwer',10),('verwitwet',5),('verw.',5),('verheirathet',10),('verhei',10)]])
@@ -166,7 +129,7 @@ class deathreasonGenerator(textGenerator):
     def __init__(self):
         textGenerator.__init__(self,lang=None)
         self._name = 'deathreason'
-        self._lpath=[os.path.abspath('../resources/old/deathreason.pkl')]
+        self._lpath=[os.path.abspath('../resources/deathreason.pkl')]
         self._value = list(map(lambda x:x[0],self.loadResources(self._lpath)))
         self._lenRes= len(self._lresources)
@@ -182,22 +145,7 @@ def __init__(self):
         two locations
     """
-class NGenerator(textGenerator):
-    def __init__(self):
-        textGenerator.__init__(self,lang=None)
-        self.loadResourcesFromList( [[('N',50),('Num',10)]])
-
-class HausnumberGenerator(textGenerator):
-    def __init__(self,mean,std):
-        textGenerator.__init__(self,lang=None)
-        self._structure = [ ( (NGenerator(),1,80) ,(positiveIntegerGenerator(mean,std),1,100),100 ) ]
-    def generate(self):
-        return Generator.generate(self)
-
 class location2Generator(textGenerator):
-    """
-    missing Rothsmansdorf Nr̳ 12
-    """
     def __init__(self):
         textGenerator.__init__(self,lang=None)
         self._name = 'location'
@@ -205,7 +153,7 @@ def __init__(self):
         self.location2 = locationGenerator()
         self.prep = locationPrepositionGenerator()
         self._structure = [
-                ( (self.location2,1,20),(self.prep,1,10), (self.location,1,100),(HausnumberGenerator(50,10),1,20),(legitimGenerator(),1,10),100)
+                ( (self.location2,1,20),(self.prep,1,10), (self.location,1,100),(legitimGenerator(),1,10),100)
                 ]
     def generate(self):
         return Generator.generate(self)
@@ -214,7 +162,7 @@ class locationGenerator(textGenerator):
     def __init__(self):
         textGenerator.__init__(self,lang=None)
         self._name = 'location'
-        self._lpath=[os.path.abspath('../resources/old/location.pkl')]
+        self._lpath=[os.path.abspath('../resources/location.pkl')]
         self._value = list(map(lambda x:x[0],self.loadResources(self._lpath)))
         self._lenRes= len(self._lresources)
@@ -232,7 +180,7 @@ class professionGenerator(textGenerator):
     def __init__(self):
         textGenerator.__init__(self,lang=None)
         self._name = 'profession'
-        self._lpath=[os.path.abspath('../resources/old/profession.pkl')]
+        self._lpath=[os.path.abspath('../resources/profession.pkl')]
         self._value = list(map(lambda x:x[0],self.loadResources(self._lpath)))
         self._lenRes= len(self._lresources)
@@ -241,7 +189,7 @@ class firstNameGenerator(textGenerator):
     def __init__(self):
         textGenerator.__init__(self,lang=None)
         self._name = 'firstName'
-        self._lpath=[os.path.abspath('../resources/old/firstname.pkl')]
+        self._lpath=[os.path.abspath('../resources/firstname.pkl')]
         self._value = list(map(lambda x:x[0],self.loadResources(self._lpath)))
         self._lenRes= len(self._lresources)
@@ -250,7 +198,7 @@ class lastNameGenerator(textGenerator):
     def __init__(self):
         textGenerator.__init__(self,lang=None)
         self._name = 'firstName'
-        self._lpath=[os.path.abspath('../resources/old/lastname.pkl')]
+        self._lpath=[os.path.abspath('../resources/lastname.pkl')]
         self._value = list(map(lambda x:x[0],self.loadResources(self._lpath)))
         self._lenRes= len(self._lresources)
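The _structure lists seen in location2Generator, and in the date generators further down, are built from (sub-generator, count, probability-in-percent) terms followed by a weight for the whole alternative. A hedged sketch of that sampling scheme, inferred from the patterns in this file rather than from the Generator base class:

import random

def instantiate_structure(lStructure):
    # lStructure: [ ((gen, n, pct), ..., alternative_weight), ... ]
    # pick one alternative, then emit each term with its own percent probability
    # (a fuller version would weight the choice by alternative_weight)
    alternative = random.choice(lStructure)
    *lTerms, _iAltWeight = alternative
    lOut = []
    for gen, n, pct in lTerms:
        if random.randint(1, 100) <= pct:
            lOut.extend(gen() for _ in range(n))   # gen() stands in for gen.generate()
    return lOut

# toy run with plain callables as stand-ins for generators:
instantiate_structure([((lambda: 'bei', 1, 10), (lambda: 'Passau', 1, 100), 100)])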
- """ def __init__(self,lang,value=None): textGenerator.__init__(self,lang) self._value = [value] @@ -392,7 +334,7 @@ class HourDateGenerator(textGenerator): def __init__(self,lang,value=None): self._fulldate = None textGenerator.__init__(self,lang) - #self._value = [value] + self._value = [value] self.realization=['H','I'] def setValue(self,d): @@ -400,8 +342,8 @@ def setValue(self,d): self._value = [d.hour] def generate(self): - try:self._generation = u""+str(int(self._fulldate.strftime('%'+ '%s'%self.getRandomElt(self.realization)))) - except UnicodeDecodeError: self._generation = u""+self._fulldate.strftime('%'+ '%d'%self.getRandomElt(self.realization)).decode('latin-1') + try:self._generation = u""+self._fulldate.strftime('%'+ '%s'%self.getRandomElt(self.realization)) + except UnicodeDecodeError: self._generation = u""+self._fulldate.strftime('%'+ '%s'%self.getRandomElt(self.realization)).decode('latin-1') return self class yearGenerator(textGenerator): @@ -429,7 +371,7 @@ def generate(self): class DayPartsGenerator(textGenerator): def __init__(self,lang,value=None): textGenerator.__init__(self,lang) - self._value=['abends','morgens','vormittags','nachmittags','mittags','nacht','fruh','früh'] + self._value=['abends','morgens','nachmittags','mittags','nacht','fruh'] class FullHourDateGenerator(textGenerator): @@ -458,7 +400,7 @@ def __init__(self,lang): self.hourGen = FullHourDateGenerator(lang) self.yearGen = yearGenerator(lang) self._structure = [ - ((self.yearGen,1,90),(self.weekdayGen,1,90),(self.monthdayGen,1,90),(self.monthGen,1,90),(self.hourGen,1,100), 100) + ((self.yearGen,1,90),(self.weekdayGen,1,90),(self.monthdayGen,1,90),(self.monthGen,1,90),(self.hourGen,1,100), 75) ] def setValue(self,v): """ @@ -513,11 +455,8 @@ def __init__(self): self._structure = [ - ( (self.monthdayGen,1,90),(self.monthGen,1,100), 100) - , ( (self.weekdayGen,1,90),(self.monthdayGen,1,90),(self.monthGen,1,90),(self.yearGen,1,40),(self.hourGen,1,100), 100) - , ( (DENGenerator(self.lang),1,100),(self.monthdayGen,1,100),(self.monthGen,1,90), (self.hourGen,1,10) ,100) - # ?? - ,( (self.yearGen,1,100),50) + ( (self.weekdayGen,1,90),(self.monthdayGen,1,90),(self.monthGen,1,90),(self.yearGen,1,40),(self.hourGen,1,100), 100) + ,( (DENGenerator(self.lang),1,100),(self.monthdayGen,1,100),(self.monthGen,1,90), (self.hourGen,1,10) ,100) ] @@ -550,46 +489,6 @@ class ABPRecordGenerator(textGenerator): else: lang='de-DE' - # per type as wel!! - lClassesToBeLearnt = [[],[]] - lClassesToBeLearnt[1] = [ - 'deathreasonGenerator' - ,'doktorGenerator' - ,'legitemGenerator' - ,'doktorTitleGenerator' - ,'lastNameGenerator' - ,'firstNameGenerator' - ,'professionGenerator' - ,'religionGenerator' - ,'familyStatus' - ,'textletterRandomGenerator' - ,'numberedItems' - ,'location2Generator' - ,'ageValueGenerator' - ,'AgeUnitGenerator' - ,'DENGeneratornum' - ,'MonthDayDateGenerator' - ,'weekDayDateGenerator' - ,'MonthDateGenerator' - ,'UMGenerator' - ,'HourDateGenerator' - ,'UHRGenerator' - ,'yearGenerator' - ,'numericalGenerator' - ,'textRandomGenerator' - ,'integerGenerator' - ,'textletterRandomGenerator' - ,'legitimGenerator' - ] - - lClassesToBeLearnt[0]= [ - 'deathreasonGenerator' - ,'doktorGenerator' - ,'PersonName2' - ,'AgeGenerator' - ,'ABPGermanDateGenerator' - ] - # method level otherwise loadresources for each sample!! person= PersonName2(lang) date= ABPGermanDateGenerator() @@ -610,6 +509,31 @@ class ABPRecordGenerator(textGenerator): noise2 = textletterRandomGenerator(10,5) + # per type as wel!! 
@@ -610,6 +509,31 @@ class ABPRecordGenerator(textGenerator):
+    # per type as well!
+    lClassesToBeLearnt =['deathreasonGenerator'
+        ,'doktorGenerator'
+        ,'doktorTitleGenerator'
+        ,'PersonName2'
+        ,'lastNameGenerator'
+        ,'firstNameGenerator'
+        ,'professionGenerator'
+        ,'religionGenerator'
+        ,'familyStatus'
+        ,'textletterRandomGenerator'
+        ,'locationGenerator'
+        ,'AgeGenerator'
+        ,'ageValueGenerator'
+        ,'AgeUnitGenerator'
+        ,'ABPGermanDateGenerator'
+        ,'DENGeneratornum'
+        ,'MonthDayDateGenerator'
+        ,'weekDayDateGenerator'
+        ,'MonthDateGenerator'
+        ,'UMGenerator'
+        ,'HourDateGenerator'
+        ,'UHRGenerator'
+        ,'yearGenerator'
+        ]
     def __init__(self):
         textGenerator.__init__(self,self.lang)
@@ -620,8 +544,7 @@ def __init__(self):
                 self.noise2,self.person, self.date,self.deathreasons,self.doktor,self.location,self.profession,self.status, self.age, self.misc]
-#       myList=[self.person]
-        for g in myList: g.setClassesToBeLearnt(self.lClassesToBeLearnt)
+
         self._structure = []
@@ -674,7 +597,7 @@ class ABPRecordGeneratorTOK(textGenerator):
     # method level otherwise loadresources for each sample!!
     person= PersonName2(lang)
     date= ABPGermanDateGenerator()
-    date.defineRange(1700, 1900)
+    date.defineRange(1700, 2000)
     deathreasons = deathReasonColumnGenerator(lang)
     doktor= doktorGenerator(lang)
     location= location2Generator()
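ABPRecordGeneratorTOK now widens defineRange(1700, 2000). Assuming defineRange simply bounds the years from which dates are drawn (its implementation is not shown in this diff), a uniform draw over such a range looks like:

import random
from datetime import date, timedelta

def random_date(year_min=1700, year_max=2000):
    d0, d1 = date(year_min, 1, 1), date(year_max, 12, 31)
    return d0 + timedelta(days=random.randint(0, (d1 - d0).days))

print(random_date())   # e.g. 1873-11-04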
@@ -732,55 +655,37 @@ def ABP(options,args):
             g.GTForTokenization()
     else:
         if options.bLoad:
-            with gzip.open(os.path.join(options.dirname,options.name+".pkl"), "rb") as fd:
-                g = pickle.load(fd)
-            print('generator loaded:%s'%(os.path.join(options.dirname,options.name+".pkl")))
-            print (g.__class__.__name__)
-            print (g.getNoiseLevel())
-        else:
-            g = ABPRecordGenerator()
-            g.setNoiseType(options.noiseType)
-            g.setNoiseLevel(options.noiseLevel)
-
-        if options.bconll:
-            lReport={}
-
-            lvlrange = [0,10]
-            lfd=[None for i in range(len(lvlrange))]
-            for i,lvl in enumerate(lvlrange):
-                lfd[i] = open(os.path.join(options.dirname,options.name+"_%s_%s.txt"%(lvl,g.getNoiseType())), "w",encoding='utf-8')
-
+            pass
+
+        g = ABPRecordGenerator()
+        g.setNoiseType(options.noiseType)
+        lReport={}
+        fd= open(os.path.join(options.dirname,options.name+".txt"), "w",encoding='utf-8')
         for i in range(options.nbX):
             g.instantiate()
-            # store the history?
             g.generate()
             try:lReport[tuple(g._instance)] +=1
             except KeyError: lReport[tuple(g._instance)] = 1
-
             if options.bFairseq:
                 sS,sT =g.formatFairSeqWord(g.exportAnnotatedData([]))
                 if len(sS.strip()) > 0:
                     iosource.write("%s\n"%sS)
                     iotarget.write("%s\n"%sT)
-
-            elif options.bconll:
-                for i,lvl in enumerate(lvlrange):
-                    g.setNoiseLevel(lvl)
-                    sGen = g.formatAnnotatedData(g.exportAnnotatedData([ "None","None" ,"None"]),mode=2)
-                    lfd[i].write(sGen)
-
-        if options.bconll:
-            [lfd[i].write("# %s %s\n"%(lReport[inst],inst)) for i in range(len(lvlrange)) for inst in lReport]
-            [fd.close() for fd in lfd]
+            else:
+                sGen = g.formatAnnotatedData(g.exportAnnotatedData([]),mode=2)
+                fd.write(sGen)
+        for inst in lReport:
+            fd.write("# %s %s\n"%(lReport[inst],inst))
+        fd.close()
         if options.bFairseq:
             iosource.close()
             iotarget.close()
-#       elif options.bconll:
-#           if g is not None and not options.bLoad:
-#               with gzip.open(os.path.join(options.dirname,options.name+".pkl"), "wb") as fd:
-#                   pickle.dump(g, fd, protocol=2)
+        elif options.bconll:
+            if g is not None:
+                with gzip.open(os.path.join(options.dirname,options.name+".pkl"), "wb") as fd:
+                    pickle.dump(g, fd, protocol=2)
 if __name__ == "__main__":
@@ -794,13 +699,11 @@ def ABP(options,args):
     parser.add_option("--model", dest="name", action="store", type="string",default="test.pkl", help="model name")
     parser.add_option("--dir", dest="dirname", action="store", type="string", default=".",help="directory to store model")
     parser.add_option("--noise", dest="noiseType", action="store", type=int, default=0, help="add noise of type N")
-    parser.add_option("--noiselvl", dest="noiseLevel", action="store", type=int, default=10, help="noise level (percentage) NN")
-
-    parser.add_option("--load", dest="bLoad", action="store_true", default=False, help="load model")
+    parser.add_option("--load", dest="bLoad", action="store_true", default=False, help="model name")
     parser.add_option("--number", dest="nbX", action="store", type=int, default=10,help="number of samples")
-    parser.add_option("--tok", dest="bTok", action="store", type=int,default=False, help="correct tokenisation GT")
+    parser.add_option("--tok", dest="bTok", action="store", type=int,default=False, help="correct tokination GT")
     parser.add_option("--fairseq", dest="bFairseq", action="store", type=int, default=False,help="seq2seq GT")
-    parser.add_option("--conll", dest="bconll", action="store_true", default=True,help="conll like GT")
+    parser.add_option("--conll", dest="bconll", action="store", type=int, default=True,help="conll like GT")
     (options, args) = parser.parse_args()
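The rewritten ABP() driver tallies how often each instantiated record structure occurs (lReport keyed by tuple(g._instance)) and appends one "# count instance" comment line per distinct structure to the output file. The same bookkeeping, sketched with collections.Counter:

from collections import Counter

def write_report(fd, lInstances):
    # lInstances: one tuple per generated record, e.g. accumulated from g._instance
    cnt = Counter(tuple(inst) for inst in lInstances)
    for inst, n in cnt.items():
        fd.write("# %s %s\n" % (n, inst))

A typical invocation then writes options.nbX records to <dir>/<model>.txt, e.g. (paths illustrative): python contentGenerator.py --dir out --model abp_gt --number 1000 --noise 2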
-
-
     Developed for the EU project READ. The READ project has received funding
     from the European Union's Horizon 2020 research and innovation programme
     under grant agreement No 674943.
diff --git a/usecases/BAR/mpxml_viewer.bat b/usecases/BAR/mpxml_viewer.bat
index 862bde1..0515c08 100644
--- a/usecases/BAR/mpxml_viewer.bat
+++ b/usecases/BAR/mpxml_viewer.bat
@@ -1 +1 @@
-C:\Anaconda2\python.exe c:\Local\meunier\git\TranskribusDU\TranskribusDU\visu\mpxml_viewer.bat.py %0.ini
+C:\Anaconda\python.exe c:\Local\TranskribusDU\src\visu\mpxml_viewer.bat.py C:\Local\TranskribusDU\usecases\BAR\mpxml_viewer.bat.ini
diff --git a/usecases/StAZH/DU_StAZH.py b/usecases/StAZH/DU_StAZH.py
new file mode 100644
index 0000000..33bf279
--- /dev/null
+++ b/usecases/StAZH/DU_StAZH.py
@@ -0,0 +1,101 @@
+# -*- coding: utf-8 -*-
+
+"""
+    First DU task for StAZH
+
+    Copyright Xerox(C) 2016 JL. Meunier
+    Copyright Naver (C) 2019 H. Déjean
+"""
+import sys, os
+
+import TranskribusDU_version  # if import error, update the PYTHONPATH environment variable
+
+from common.trace import traceln
+from tasks.DU_Task_Factory import DU_Task_Factory
+from graph.Graph_Multi_SinglePageXml import Graph_MultiSinglePageXml
+from graph.NodeType_PageXml import NodeType_PageXml_type
+from graph.FeatureDefinition_PageXml_std import FeatureDefinition_PageXml_StandardOnes
+from graph.NodeType_PageXml import NodeType_PageXml_type_woText, NodeType_PageXml_type
+from graph.FeatureDefinition_PageXml_std_noText import FeatureDefinition_PageXml_StandardOnes_noText
+
+
+def getConfiguredGraphClass(doer):
+    """
+    In this class method, we must return a configured graph class
+    """
+    #DU_GRAPH = ConjugateSegmenterGraph_MultiSinglePageXml  # consider each page as if indep from each other
+    DU_GRAPH = Graph_MultiSinglePageXml
+
+    ntClass = NodeType_PageXml_type_woText  #NodeType_PageXml_type
+
+    #lIgnoredLabels = ['menu-section-heading','Item-number']
+
+    lLabels = ['catch-word', 'header', 'heading', 'marginalia', 'page-number']
+
+    nt = ntClass("TR"          #some short prefix because labels below are prefixed with it
+                  , lLabels    # in conjugate, we accept all labels, and None becomes "none"
+                  , []
+                  , True       # unused
+                  , BBoxDeltaFun=lambda v: max(v * 0.066, min(5, v/3))  #we reduce overlap in this way
+                  )
+    nt.setLabelAttribute("type")
+    #DU_GRAPH.sEdgeLabelAttribute="TR"
+    nt.setXpathExpr((".//pc:TextRegion"
+                      , ".//pc:TextEquiv")  #how to get their text
+                    )
+    DU_GRAPH.addNodeType(nt)
+
+    return DU_GRAPH
+
+
+if __name__ == "__main__":
+    # import better_exceptions
+    # better_exceptions.MAX_LENGTH = None
+
+    # standard command line options for CRF- ECN- GAT-based methods
+    usage, parser = DU_Task_Factory.getStandardOptionsParser(sys.argv[0])
+
+    traceln("VERSION: %s" % DU_Task_Factory.getVersion())
+
+    # ---
+    #parse the command line
+    (options, args) = parser.parse_args()
+
+    cFeatureDefinition = FeatureDefinition_PageXml_StandardOnes_noText
+#    dFeatureConfig = {
+#        'n_tfidf_node':400, 't_ngrams_node':(1,3), 'b_tfidf_node_lc':False
+#        , 'n_tfidf_edge':400, 't_ngrams_edge':(1,3), 'b_tfidf_edge_lc':False }
+    try:
+        sModelDir, sModelName = args
+    except Exception as e:
+        traceln("Specify a model folder and a model name!")
+        DU_Task_Factory.exit(usage, 1, e)
+
+    doer = DU_Task_Factory.getDoer(sModelDir, sModelName
+                                   , options = options
+                                   , fun_getConfiguredGraphClass= getConfiguredGraphClass
+                                   , cFeatureDefinition = cFeatureDefinition
+                                   , dFeatureConfig = {}
+                                   )
+
+    # setting the learner configuration, in a standard way
+    # (from command line options, or from a JSON configuration file)
+    dLearnerConfig = doer.getStandardLearnerConfig(options)
+
+
+    # of course, you can put yours here instead.
+    doer.setLearnerConfiguration(dLearnerConfig)
+
+    #doer.setConjugateMode()
+
+    # act as per specified in the command line (--trn , --fold-run, ...)
+    doer.standardDo(options)
+
+    del doer
+
+
+
+
+
+
+
diff --git a/usecases/__init__.py b/usecases/__init__.py
new file mode 100644
index 0000000..d52e42e
--- /dev/null
+++ b/usecases/__init__.py
@@ -0,0 +1,41 @@
+# -*- coding: utf-8 -*-
+
+#REMOVE THIS annoying warning saying:
+# /usr/lib/python2.7/site-packages/requests-2.12.1-py2.7.egg/requests/packages/urllib3/connectionpool.py:843: InsecureRequestWarning: Unverified HTTPS request is being made.
+# Adding certificate verification is strongly advised. See: https://urllib3.readthedocs.io/en/latest/advanced-usage.html#ssl-warnings InsecureRequestWarning)
+
+import sys, os
+
+DEBUG=0
+
+sCOL = "col"
+
+def _exit(usage, status, exc=None):
+    if usage:
+        sys.stderr.write("ERROR: usage : %s\n"%usage)
+    if exc != None:
+        sys.stderr.write(str(exc))  #any exception?
+    sys.exit(status)
+
+def _checkFindColDir(lsDir, sColName=sCOL, bAbsolute=True):
+    """
+    For each directory in the input list, check if it is a "col" directory, or look for a 'col' sub-directory
+    If a string is given instead of a list, make it a list
+    If None is given, just return an empty list
+    return the list of "col" directory absolute path
+    or raise an exception
+    """
+    if lsDir == None: return list()
+    if type(lsDir) != list: lsDir = [lsDir]
+    lsColDir = list()
+    for sDir in lsDir:
+        if not(sDir.endswith(sColName) or sDir.endswith(sColName+os.path.sep)):
+            sColDir = os.path.join(sDir, sColName)
+        else:
+            sColDir = sDir
+        if bAbsolute:
+            sColDir = os.path.abspath(sColDir)
+        if not( os.path.exists(sColDir) and os.path.isdir(sColDir) ):
+            raise ValueError("Non-existing folder: %s"%sColDir)
+        lsColDir.append(sColDir)
+    return lsColDir
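The new usecases/__init__.py helper normalises Transkribus 'col' collection folders: given either a collection folder or its parent, it returns absolute 'col' paths and raises ValueError on anything missing. Typical use (paths illustrative):

from usecases import _checkFindColDir

# both forms resolve to the same absolute .../MyCollection/col folder
lsColDir = _checkFindColDir(["/data/MyCollection", "/data/MyCollection/col"])
print(lsColDir)

The new DU_StAZH.py entry point is then driven in the standard way its comments indicate, e.g. python DU_StAZH.py <modelDir> <modelName> --trn <colDir>.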