diff --git a/gtfsdb/model/gtfs.py b/gtfsdb/model/gtfs.py index 3a8a867..c872a33 100644 --- a/gtfsdb/model/gtfs.py +++ b/gtfsdb/model/gtfs.py @@ -1,8 +1,6 @@ from contextlib import closing import logging -import time import shutil -import sys import tempfile from urllib import urlretrieve import zipfile @@ -21,7 +19,7 @@ def __init__(self, filename): def load(self, db, **kwargs): '''Load GTFS into database''' - log.debug('begin load') + log.debug('begin GTFS.load') '''load known GTFS files, derived tables & lookup tables''' gtfs_directory = self.unzip() @@ -33,21 +31,9 @@ def load(self, db, **kwargs): cls.load(db, **load_kwargs) shutil.rmtree(gtfs_directory) - '''load derived geometries, currently only written for PostgreSQL''' - if db.is_geospatial and db.is_postgresql: - s = ' - %s geom' % (Route.__tablename__) - sys.stdout.write(s) - start_seconds = time.time() - session = db.session - q = session.query(Route) - for route in q: - route.load_geometry(session) - session.merge(route) - session.commit() - session.close() - process_time = time.time() - start_seconds - print ' (%.0f seconds)' % (process_time) - log.debug('end load') + '''load route geometries derived from shapes.txt''' + Route.load_geoms(db) + log.debug('end GTFS.load') def unzip(self, path=None): '''Unzip GTFS files from URL/directory to path.''' diff --git a/gtfsdb/model/route.py b/gtfsdb/model/route.py index 9ec6e46..53cedae 100644 --- a/gtfsdb/model/route.py +++ b/gtfsdb/model/route.py @@ -32,8 +32,7 @@ class Route(Base): route_short_name = Column(String(255)) route_long_name = Column(String(255)) route_desc = Column(String(255)) - route_type = Column(Integer, - ForeignKey('route_type.route_type'), index=True, nullable=False) + route_type = Column(Integer, index=True, nullable=False) route_url = Column(String(255)) route_color = Column(String(6)) route_text_color = Column(String(6)) @@ -42,17 +41,25 @@ class Route(Base): stop_times = relationship('StopTime', secondary='trips') trips = relationship('Trip') - def load_geometry(self, session): + @classmethod + def load_geoms(cls, db): from gtfsdb.model.shape import Pattern from gtfsdb.model.trip import Trip - if hasattr(self, 'geom'): - s = func.st_collect(Pattern.geom) - s = func.st_multi(s) - s = func.st_astext(s).label('geom') - q = session.query(s) - q = q.filter(Pattern.trips.any((Trip.route == self))) - self.geom = q.first().geom + '''load derived geometries, currently only written for PostgreSQL''' + if db.is_geospatial and db.is_postgresql: + session = db.session + routes = session.query(Route).all() + for route in routes: + s = func.st_collect(Pattern.geom) + s = func.st_multi(s) + s = func.st_astext(s).label('geom') + q = session.query(s) + q = q.filter(Pattern.trips.any((Trip.route == route))) + route.geom = q.first().geom + session.merge(route) + session.commit() + session.close() @classmethod def add_geometry_column(cls):