Skip to content

Commit

Permalink
- Added length to String column definitions, required for MySQL support
Browse files Browse the repository at this point in the history
- Added optional & proposed columns
  - fare_attributes.txt: agency_id
  - feed_info.txt: feed_license
  - routes.txt: route_desc
  - stops.txt: stop_timezone, platform_code
- Simplified validation process
  • Loading branch information
mgilligan committed Jul 26, 2013
1 parent a0eaec3 commit c167b2e
Show file tree
Hide file tree
Showing 14 changed files with 112 additions and 213 deletions.
21 changes: 9 additions & 12 deletions gtfsdb/model/agency.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,12 @@
class Agency(Base):
__tablename__ = 'agency'

required_fields = ['agency_name', 'agency_url', 'agency_timezone']
optional_fields = ['agency_id', 'agency_lang', 'agency_phone',
'agency_fare_url']

id = Column(Integer, Sequence(None, optional=True), primary_key=True)
agency_id = Column(String, index=True, unique=True)
agency_name = Column(String, nullable=False)
agency_url = Column(String, nullable=False)
agency_timezone = Column(String, nullable=False)
agency_lang = Column(String)
agency_phone = Column(String)
agency_fare_url = Column(String)
id = Column(Integer,
Sequence(None, optional=True), primary_key=True, nullable=True)
agency_id = Column(String(255), index=True, unique=True)
agency_name = Column(String(255), nullable=False)
agency_url = Column(String(255), nullable=False)
agency_timezone = Column(String(50), nullable=False)
agency_lang = Column(String(10))
agency_phone = Column(String(50))
agency_fare_url = Column(String(255))
57 changes: 23 additions & 34 deletions gtfsdb/model/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import csv
import datetime
import logging
import os
import sys
import time
Expand All @@ -9,11 +10,10 @@
from gtfsdb import util


class _Base(object):
log = logging.getLogger(__name__)


required_fields = []
optional_fields = []
proposed_fields = []
class _Base(object):

@classmethod
def from_dict(cls, attrs):
Expand Down Expand Up @@ -50,13 +50,11 @@ def load(cls, engine, directory=None, validate=True):
file_path = '%s/%s' % (directory, cls.get_filename())
if os.path.exists(file_path):
start_time = time.time()
file = open(file_path, 'r')
utf8_file = util.UTF8Recoder(file, 'utf-8-sig')
f = open(file_path, 'r')
utf8_file = util.UTF8Recoder(f, 'utf-8-sig')
reader = csv.DictReader(utf8_file)
if validate:
cls.validate(reader.fieldnames)
s = ' - %s ' % (cls.get_filename())
sys.stdout.write(s)
table = cls.__table__
engine.execute(table.delete())
i = 0
Expand All @@ -71,9 +69,10 @@ def load(cls, engine, directory=None, validate=True):
i = 0
if len(records) > 0:
engine.execute(table.insert(), records)
file.close()
f.close()
processing_time = time.time() - start_time
print ' (%.0f seconds)' % (processing_time)
log.debug('{0} ({1:.0f} seconds)'.format(
cls.get_filename(), processing_time))

@classmethod
def make_record(cls, row):
Expand All @@ -92,31 +91,21 @@ def make_record(cls, row):
cls.add_geom_to_dict(row)
return row

@classmethod
def set_schema(cls, schema):
cls.__table__.schema = schema

@classmethod
def validate(cls, fieldnames):
all_fields = cls.required_fields + cls.optional_fields + cls.proposed_fields

# required fields
fields = None
if cls.required_fields and fieldnames:
fields = set(cls.required_fields) - set(fieldnames)
if fields:
missing_required_fields = list(fields)
if missing_required_fields:
print ' %s missing fields: %s' % (cls.get_filename(), missing_required_fields)

# all fields
fields = None
if all_fields and fieldnames:
fields = set(fieldnames) - set(all_fields)
if fields:
unknown_fields = list(fields)
if unknown_fields:
print ' %s unknown fields: %s' % (cls.get_filename(), unknown_fields)

if not fieldnames:
return
cols = cls.__table__.columns
all_fields = [c.name for c in cols]
required_fields = [c.name for c in cols if c.nullable == False]
missing_fields = list(set(required_fields) - set(fieldnames))
unknown_fields = list(set(fieldnames) - set(all_fields))

if missing_fields:
log.debug('{0} missing fields: {1}'.format(
cls.get_filename(), missing_fields))
if unknown_fields:
log.debug('{0} unknown fields: {1}'.format(
cls.get_filename(), unknown_fields))

Base = declarative_base(cls=_Base)
33 changes: 7 additions & 26 deletions gtfsdb/model/calendar.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,7 @@ class Calendar(Base):
__tablename__ = 'calendar'
__table_args__ = (Index('calendar_ix1', 'start_date', 'end_date'),)

required_fields = [
'service_id',
'monday',
'tuesday',
'wednesday',
'thursday',
'friday',
'saturday',
'sunday',
'start_date',
'end_date'
]

service_id = Column(String, primary_key=True, nullable=False)
service_id = Column(String(255), primary_key=True, nullable=False)
monday = Column(Boolean, nullable=False)
tuesday = Column(Boolean, nullable=False)
wednesday = Column(Boolean, nullable=False)
Expand Down Expand Up @@ -79,10 +66,8 @@ def to_date_list(self):
class CalendarDate(Base):
__tablename__ = 'calendar_dates'

required_fields = ['service_id', 'date', 'exception_type']

service_id = Column(String, primary_key=True)
date = Column(Date, primary_key=True, index=True)
service_id = Column(String(255), primary_key=True, nullable=False)
date = Column(Date, primary_key=True, index=True, nullable=False)
exception_type = Column(Integer, nullable=False)

@property
Expand All @@ -97,10 +82,8 @@ def is_removal(self):
class UniversalCalendar(Base):
__tablename__ = 'universal_calendar'

required_fields = ['service_id', 'date']

service_id = Column(String, primary_key=True)
date = Column(Date, primary_key=True, index=True)
service_id = Column(String(255), primary_key=True, nullable=False)
date = Column(Date, primary_key=True, index=True, nullable=False)

trips = relationship('Trip',
primaryjoin='Trip.service_id==UniversalCalendar.service_id',
Expand All @@ -126,10 +109,8 @@ def load(cls, engine):
session = Session()
q = session.query(Calendar)
for calendar in q:
rows = calendar.to_date_list()
for row in rows:
uc = cls(**row)
session.add(uc)
for row in calendar.to_date_list():
session.add(cls(**row))
session.commit()
q = session.query(CalendarDate)
for calendar_date in q:
Expand Down
32 changes: 9 additions & 23 deletions gtfsdb/model/fare.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,37 +10,23 @@
class FareAttribute(Base):
__tablename__ = 'fare_attributes'

required_fields = [
'fare_id',
'price',
'currency_type',
'payment_method',
'transfers'
]
optional_fields = ['transfer_duration']
proposed_fields = ['agency_id']

fare_id = Column(String, primary_key=True)
fare_id = Column(String(255), primary_key=True)
price = Column(Numeric(10, 2), nullable=False)
currency_type = Column(String, nullable=False)
currency_type = Column(String(255), nullable=False)
payment_method = Column(Integer, nullable=False)
transfers = Column(Integer)
transfer_duration = Column(Integer)
agency_id = Column(String(255))


class FareRule(Base):
__tablename__ = 'fare_rules'

required_fields = ['fare_id']
optional_fields = ['route_id', 'origin_id',
'destination_id', 'contains_id']
proposed_fields = ['service_id']

id = Column(Integer, Sequence(None, optional=True), primary_key=True)
fare_id = Column(String,
fare_id = Column(String(255),
ForeignKey('fare_attributes.fare_id'), index=True, nullable=False)
route_id = Column(String)
origin_id = Column(String)
destination_id = Column(String)
contains_id = Column(String)
service_id = Column(String)
route_id = Column(String(255))
origin_id = Column(String(255))
destination_id = Column(String(255))
contains_id = Column(String(255))
service_id = Column(String(255))
13 changes: 5 additions & 8 deletions gtfsdb/model/feed_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,10 @@
class FeedInfo(Base):
__tablename__ = 'feed_info'

required_fields = ['feed_publisher_name', 'feed_publisher_url',
'feed_lang']
optional_fields = ['feed_start_date', 'feed_end_date', 'feed_version']

feed_publisher_name = Column(String, primary_key=True)
feed_publisher_url = Column(String, nullable=False)
feed_lang = Column(String, nullable=False)
feed_publisher_name = Column(String(255), primary_key=True)
feed_publisher_url = Column(String(255), nullable=False)
feed_lang = Column(String(255), nullable=False)
feed_start_date = Column(Date)
feed_end_date = Column(Date)
feed_version = Column(String)
feed_version = Column(String(255))
feed_license = Column(String(255))
10 changes: 4 additions & 6 deletions gtfsdb/model/frequency.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,10 @@
class Frequency(Base):
__tablename__ = 'frequencies'

required_fields = ['trip_id', 'start_time', 'end_time', 'headway_secs']
proposed_fields = ['exact_times']

trip_id = Column(String, ForeignKey('trips.trip_id'), primary_key=True)
start_time = Column(String, primary_key=True)
end_time = Column(String)
trip_id = Column(
String(255), ForeignKey('trips.trip_id'), primary_key=True)
start_time = Column(String(8), primary_key=True)
end_time = Column(String(8))
headway_secs = Column(Integer)
exact_times = Column(Integer)

Expand Down
3 changes: 1 addition & 2 deletions gtfsdb/model/gtfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,7 @@ def validate(self):
"""Run transitfeed.feedvalidator"""
path = os.path.join(
pkg_resources.get_distribution('transitfeed').egg_info,
'scripts/feedvalidator.py'
)
'scripts/feedvalidator.py')

stdout, stderr = subprocess.Popen(
[sys.executable, path, '--output=CONSOLE', self.local_file],
Expand Down
27 changes: 7 additions & 20 deletions gtfsdb/model/route.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,35 +13,22 @@ class RouteType(Base):
__tablename__ = 'route_type'

route_type = Column(Integer, primary_key=True)
route_type_name = Column(String)
route_type_desc = Column(String)
route_type_name = Column(String(255))
route_type_desc = Column(String(255))


class Route(Base):
__tablename__ = 'routes'

required_fields = [
'route_id',
'route_short_name',
'route_long_name',
'route_type'
]
optional_fields = [
'agency_id',
'route_desc',
'route_url',
'route_color',
'route_text_color'
]

route_id = Column(String, primary_key=True, nullable=False)
route_id = Column(String(255), primary_key=True, nullable=False)
agency_id = Column(
String, ForeignKey('agency.agency_id'), index=True, nullable=True)
route_short_name = Column(String)
route_long_name = Column(String)
route_short_name = Column(String(255))
route_long_name = Column(String(255))
route_desc = Column(String(255))
route_type = Column(Integer,
ForeignKey('route_type.route_type'), index=True, nullable=False)
route_url = Column(String)
route_url = Column(String(255))
route_color = Column(String(6))
route_text_color = Column(String(6))

Expand Down
30 changes: 12 additions & 18 deletions gtfsdb/model/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
class Pattern(Base):
__tablename__ = 'patterns'

shape_id = Column(String, primary_key=True)
shape_id = Column(String(255), primary_key=True)
pattern_dist = Column(Numeric(20, 10))

trips = relationship('Trip')
Expand Down Expand Up @@ -72,15 +72,7 @@ def load(cls, engine):
class Shape(Base):
__tablename__ = 'shapes'

required_fields = [
'shape_id',
'shape_pt_lat',
'shape_pt_lon',
'shape_pt_sequence'
]
optional_fields = ['shape_dist_traveled']

shape_id = Column(String, primary_key=True)
shape_id = Column(String(255), primary_key=True)
shape_pt_lat = Column(Numeric(12, 9))
shape_pt_lon = Column(Numeric(12, 9))
shape_pt_sequence = Column(Integer, primary_key=True)
Expand All @@ -95,11 +87,13 @@ def add_geometry_column(cls):

@classmethod
def add_geom_to_dict(cls, row):
from geoalchemy import WKTSpatialElement

wkt = 'SRID=%s;POINT(%s %s)' % (
SRID,
row['shape_pt_lon'],
row['shape_pt_lat']
)
row['geom'] = WKTSpatialElement(wkt)
try:
from geoalchemy import WKTSpatialElement
wkt = 'SRID=%s;POINT(%s %s)' % (
SRID,
row['shape_pt_lon'],
row['shape_pt_lat']
)
row['geom'] = WKTSpatialElement(wkt)
except ImportError:
pass
21 changes: 9 additions & 12 deletions gtfsdb/model/stop.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,19 @@
class Stop(Base):
__tablename__ = 'stops'

required_fields = ['stop_id', 'stop_name', 'stop_lat', 'stop_lon']
optional_fields = ['stop_code', 'stop_desc', 'zone_id', 'stop_url',
'location_type', 'parent_station',
'wheelchair_boarding']

stop_id = Column(String, primary_key=True, nullable=False)
stop_code = Column(String)
stop_name = Column(String, nullable=False)
stop_desc = Column(String)
stop_id = Column(String(255), primary_key=True, nullable=False)
stop_code = Column(String(50))
stop_name = Column(String(255), nullable=False)
stop_desc = Column(String(255))
stop_lat = Column(Numeric(12, 9), nullable=False)
stop_lon = Column(Numeric(12, 9), nullable=False)
zone_id = Column(String)
stop_url = Column(String)
zone_id = Column(String(50))
stop_url = Column(String(255))
location_type = Column(Integer, index=True, default=0)
parent_station = Column(String)
parent_station = Column(String(255))
stop_timezone = Column(String(50))
wheelchair_boarding = Column(Integer, default=0)
platform_code = Column(String(50))

stop_features = relationship('StopFeature')
stop_times = relationship('StopTime')
Expand Down
Loading

0 comments on commit c167b2e

Please sign in to comment.