diff --git a/gtfsdb/model/db.py b/gtfsdb/model/db.py index d96f0ed..fc17f14 100644 --- a/gtfsdb/model/db.py +++ b/gtfsdb/model/db.py @@ -7,6 +7,12 @@ class Database(object): def __init__(self, **kwargs): + ''' + keyword arguments: + url: SQLAlchemy database url + schema: database schema name + is_geospatial: if database supports geo functions + ''' self.url = kwargs.get('url', config.DEFAULT_DATABASE_URL) self.schema = kwargs.get('schema', config.DEFAULT_SCHEMA) self.is_geospatial = kwargs.get('is_geospatial', @@ -60,6 +66,16 @@ def schema(self, val): for cls in self.classes: cls.__table__.schema = val + @property + def sorted_classes(self): + classes = [] + for t in self.metadata.sorted_tables: + cls = next((c for c in self.classes + if c.__table__ == t), None) + if cls: + classes.append(cls) + return classes + @property def url(self): return self._url diff --git a/gtfsdb/model/gtfs.py b/gtfsdb/model/gtfs.py index c872a33..15fab23 100644 --- a/gtfsdb/model/gtfs.py +++ b/gtfsdb/model/gtfs.py @@ -27,7 +27,7 @@ def load(self, db, **kwargs): batch_size=kwargs.get('batch_size', config.DEFAULT_BATCH_SIZE), gtfs_directory=gtfs_directory, ) - for cls in db.classes: + for cls in db.sorted_classes: cls.load(db, **load_kwargs) shutil.rmtree(gtfs_directory)