Skip to content

Commit

Permalink
Added properties to Database for setting schema, url, etc and appropr…
Browse files Browse the repository at this point in the history
…iate tasks needed after each

Database load now takes keyword args
  • Loading branch information
mgilligan committed Feb 28, 2014
1 parent f509c6d commit e33fdab
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 21 deletions.
9 changes: 6 additions & 3 deletions gtfsdb/api.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
from gtfsdb import Database, GTFS


def database_load(filename, database_url='sqlite://',
schema=None, is_geospatial=False):
def database_load(filename, **kwargs):
'''Basic API to load a GTFS zip file into a database
arguments:
filename: URL or local path to GTFS zip file
keyword arguments:
batch_size: record batch size for memory management
database_url: SQLAlchemy database url
schema: database schema name
is_geospatial: if database is support geo functions
'''
db = Database(database_url, schema, is_geospatial)

db = Database(**kwargs)
db.create()
gtfs = GTFS(filename)
gtfs.load(db)
60 changes: 47 additions & 13 deletions gtfsdb/model/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,11 @@

class Database(object):

def __init__(self, url, schema=None, is_geospatial=False):
self.url = url
self.schema = schema
self.is_geospatial = is_geospatial
for cls in self.classes:
cls.__table__.schema = schema
if is_geospatial and hasattr(cls, 'add_geometry_column'):
cls.add_geometry_column()
self.engine = create_engine(url)
if 'sqlite' in url:
self.engine.connect().connection.connection.text_factory = str
session_factory = sessionmaker(self.engine)
self.session = scoped_session(session_factory)
def __init__(self, **kwargs):
self.url = kwargs.get('url', config.DEFAULT_DATABASE_URL)
self.schema = kwargs.get('schema', config.DEFAULT_SCHEMA)
self.is_geospatial = kwargs.get('is_geospatial',
config.DEFAULT_IS_GEOSPATIAL)

@property
def classes(self):
Expand All @@ -38,3 +30,45 @@ def dialect_name(self):
def metadata(self):
from gtfsdb.model.base import Base
return Base.metadata

@property
def is_geospatial(self):
return self._is_geospatial

@is_geospatial.setter
def is_geospatial(self, val):
self._is_geospatial = val
for cls in self.classes:
if val and hasattr(cls, 'add_geometry_column'):
cls.add_geometry_column()

@property
def is_postgresql(self):
return 'postgres' in self.dialect_name

@property
def is_sqlite(self):
return 'sqlite' in self.dialect_name

@property
def schema(self):
return self._schema

@schema.setter
def schema(self, val):
self._schema = val
for cls in self.classes:
cls.__table__.schema = val

@property
def url(self):
return self._url

@url.setter
def url(self, val):
self._url = val
self.engine = create_engine(val)
if self.is_sqlite:
self.engine.connect().connection.connection.text_factory = str
session_factory = sessionmaker(self.engine)
self.session = scoped_session(session_factory)
4 changes: 1 addition & 3 deletions gtfsdb/model/gtfs.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from contextlib import closing
import logging
import time
import pkg_resources
import shutil
import sys
import tempfile
Expand Down Expand Up @@ -61,8 +60,7 @@ def load(self, db):
UniversalCalendar.load(db)

'''load derived geometries, currently only written for PostgreSQL'''
dialect_name = db.engine.url.get_dialect().name
if db.is_geospatial and 'postgres' in dialect_name:
if db.is_geospatial and db.is_postgresql:
s = ' - %s geom' % (Route.__tablename__)
sys.stdout.write(s)
start_seconds = time.time()
Expand Down
9 changes: 7 additions & 2 deletions gtfsdb/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,10 @@ def gtfsdb_load():
help='Database SCHEMA name')
args = parser.parse_args()

database_load(
args.file, args.database_url, args.schema, args.is_geospatial)
kwargs = dict(
batch_size=args.batch_size,
database_url=args.database_url,
schema=args.schema,
is_geospatial=args.is_geospatial,
)
database_load(args.file, **kwargs)

0 comments on commit e33fdab

Please sign in to comment.