From 2a086d738f0d5eb17e2752c26305be09abc774d4 Mon Sep 17 00:00:00 2001 From: Mike Gilligan Date: Thu, 27 Feb 2014 16:30:48 -0800 Subject: [PATCH] Simplified GTFS.load method --- gtfsdb/api.py | 2 +- gtfsdb/model/base.py | 19 +++++++++++++++-- gtfsdb/model/calendar.py | 2 +- gtfsdb/model/gtfs.py | 46 +++++++++------------------------------- gtfsdb/model/shape.py | 2 +- 5 files changed, 30 insertions(+), 41 deletions(-) diff --git a/gtfsdb/api.py b/gtfsdb/api.py index fd13c64..6d7b1ef 100644 --- a/gtfsdb/api.py +++ b/gtfsdb/api.py @@ -17,4 +17,4 @@ def database_load(filename, **kwargs): db = Database(**kwargs) db.create() gtfs = GTFS(filename) - gtfs.load(db) + gtfs.load(db, **kwargs) diff --git a/gtfsdb/model/base.py b/gtfsdb/model/base.py index 3b35bb9..55668eb 100644 --- a/gtfsdb/model/base.py +++ b/gtfsdb/model/base.py @@ -44,7 +44,23 @@ def to_dict(self): return ret_val @classmethod - def load(cls, db, directory=None, validate=True, batch_size=10000): + def load(cls, db, **kwargs): + '''Load method for ORM + + arguments: + db: instance of gtfsdb.Database + + keyword arguments: + gtfs_directory: path to unzipped GTFS files + batch_size: batch size for memory management + ''' + batch_size = kwargs.get('batch_size', config.DEFAULT_BATCH_SIZE) + directory = None + if cls.datasource == config.DATASOURCE_GTFS: + directory = kwargs.get('gtfs_directory') + elif cls.datasource == config.DATASOURCE_LOOKUP: + directory = resource_filename('gtfsdb', 'data') + records = [] file_path = os.path.join(directory, cls.filename) if os.path.exists(file_path): @@ -60,7 +76,6 @@ def load(cls, db, directory=None, validate=True, batch_size=10000): for row in reader: records.append(cls.make_record(row)) i += 1 - # commit every `batch_size` records to manage memory if i >= batch_size: db.engine.execute(table.insert(), records) sys.stdout.write('*') diff --git a/gtfsdb/model/calendar.py b/gtfsdb/model/calendar.py index e548d02..1621e4c 100644 --- a/gtfsdb/model/calendar.py +++ b/gtfsdb/model/calendar.py @@ -108,7 +108,7 @@ def from_calendar_date(cls, calendar_date): return cls(**kwargs) @classmethod - def load(cls, db): + def load(cls, db, **kwargs): start_time = time.time() session = db.session q = session.query(Calendar) diff --git a/gtfsdb/model/gtfs.py b/gtfsdb/model/gtfs.py index 19599c7..3a8a867 100644 --- a/gtfsdb/model/gtfs.py +++ b/gtfsdb/model/gtfs.py @@ -7,19 +7,8 @@ from urllib import urlretrieve import zipfile -from .agency import Agency -from .calendar import Calendar, CalendarDate, UniversalCalendar -from .fare import FareAttribute, FareRule -from .feed_info import FeedInfo -from .frequency import Frequency -from .route import Route, RouteType -from .shape import Pattern, Shape -from .stop_time import StopTime -from .stop import Stop -from .stop_feature import StopFeature, StopFeatureType -from .transfer import Transfer -from .trip import Trip - +from gtfsdb import config +from .route import Route log = logging.getLogger(__name__) @@ -30,34 +19,19 @@ def __init__(self, filename): self.file = filename self.local_file = urlretrieve(filename)[0] - def load(self, db): + def load(self, db, **kwargs): '''Load GTFS into database''' log.debug('begin load') - '''load lookup tables from data directory''' - data_directory = pkg_resources.resource_filename('gtfsdb', 'data') - RouteType.load(db, data_directory, False) - StopFeatureType.load(db, data_directory, False) - - '''load known files & fields from GTFS''' + '''load known GTFS files, derived tables & lookup tables''' gtfs_directory = self.unzip() - FeedInfo.load(db.engine, gtfs_directory) - Agency.load(db.engine, gtfs_directory) - Calendar.load(db.engine, gtfs_directory) - CalendarDate.load(db.engine, gtfs_directory) - Route.load(db.engine, gtfs_directory) - Stop.load(db.engine, gtfs_directory) - StopFeature.load(db.engine, gtfs_directory) - Transfer.load(db.engine, gtfs_directory) - Shape.load(db.engine, gtfs_directory) - Pattern.load(db) - Trip.load(db.engine, gtfs_directory) - StopTime.load(db.engine, gtfs_directory) - Frequency.load(db.engine, gtfs_directory) - FareAttribute.load(db.engine, gtfs_directory) - FareRule.load(db.engine, gtfs_directory) + load_kwargs = dict( + batch_size=kwargs.get('batch_size', config.DEFAULT_BATCH_SIZE), + gtfs_directory=gtfs_directory, + ) + for cls in db.classes: + cls.load(db, **load_kwargs) shutil.rmtree(gtfs_directory) - UniversalCalendar.load(db) '''load derived geometries, currently only written for PostgreSQL''' if db.is_geospatial and db.is_postgresql: diff --git a/gtfsdb/model/shape.py b/gtfsdb/model/shape.py index f8c60e9..cb0f612 100644 --- a/gtfsdb/model/shape.py +++ b/gtfsdb/model/shape.py @@ -40,7 +40,7 @@ def add_geometry_column(cls): GeometryDDL(cls.__table__) @classmethod - def load(cls, db): + def load(cls, db, **kwargs): start_time = time.time() s = ' - %s' % (cls.__tablename__) sys.stdout.write(s)