From 2077fe0bfe4e4e71554d6d364c2b4a30c9729863 Mon Sep 17 00:00:00 2001 From: Cyrill Raccaud Date: Mon, 20 Nov 2023 22:12:21 +0100 Subject: [PATCH] Extend api capabilities (#16) * Extend the existing api capabilities - connections: add all missing parameters - stationboard: add all missing parameters - urlencode: Add "True" flag for correct formatting * Add locations endpoint and update example * update version to 0.4.0 and update changelog * fix black linting * Fix AttributeError * Update changelog entry --------- Co-authored-by: Fabian Affolter --- CHANGES.txt | 14 +++ example.py | 23 +++++ opendata_transport/__init__.py | 173 +++++++++++++++++++++++++++++++-- 3 files changed, 202 insertions(+), 8 deletions(-) diff --git a/CHANGES.txt b/CHANGES.txt index 851c04e..cecf799 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -1,6 +1,20 @@ Changes ======= +20231120 - 0.4.0 +---------------- + +- Add support locations API (thanks @miaucl) +- Add missing connections parameters (thanks @miaucl) +- Add missing stationboard parameters (thanks @miaucl) +- Add "True" flag for correct formatting of lists (ex. via parameter `via[]=foo1&via[]=foo2`) (thanks @miaucl) + +20211124 - 0.3.0 +---------------- + +- Don't use async timeout (thanks @agners) +- Remove loop + 20210317 - 0.2.2 ---------------- diff --git a/example.py b/example.py index 13272e9..55166bd 100644 --- a/example.py +++ b/example.py @@ -5,11 +5,34 @@ from opendata_transport import OpendataTransport from opendata_transport import OpendataTransportStationboard +from opendata_transport import OpendataTransportLocation async def main(): """Example for getting the data.""" async with aiohttp.ClientSession() as session: + # Search a station by query + locations = OpendataTransportLocation(session, query="Stettb") + await locations.async_get_data() + + # Print the locations data + print(locations.locations) + + # Print as list + print(list(map(lambda x: x["name"], locations.locations))) + + # Search a station by coordinates + locations = OpendataTransportLocation(session, x=47.2, y=8.7) + await locations.async_get_data() + + # Print the locations data + print(locations.locations) + + # Print as list + print(list(map(lambda x: x["name"], locations.locations))) + + print() + # Get the connection for a defined route connection = OpendataTransport( "Zürich, Blumenfeldstrasse", "Zürich Oerlikon, Bahnhof", session, 4 diff --git a/opendata_transport/__init__.py b/opendata_transport/__init__.py index 24b915d..b154ce8 100644 --- a/opendata_transport/__init__.py +++ b/opendata_transport/__init__.py @@ -20,22 +20,110 @@ def __init__(self, session): @staticmethod def get_url(resource, params): """Generate the URL for the request.""" - param = urllib.parse.urlencode(params) + param = urllib.parse.urlencode(params, True) url = "{resource_url}{resource}?{param}".format( resource_url=_RESOURCE_URL, resource=resource, param=param ) + print(url) return url +class OpendataTransportLocation(OpendataTransportBase): + """A class for handling locations from Opendata Transport.""" + + def __init__(self, session, query=None, x=None, y=None, type_="all", fields=None): + """Initialize the location.""" + super().__init__(session) + + self.query = query + self.x = x + self.y = y + self.type = type_ + self.fields = ( + fields if fields is not None and isinstance(fields, list) else None + ) + + self.from_name = self.from_id = self.to_name = self.to_id = None + + self.locations = [] + + @staticmethod + def get_station(station): + """Get the station details.""" + return { + "name": station["name"], + "score": station["score"], + "coordinate_type": station["coordinate"]["type"], + "x": station["coordinate"]["x"], + "y": station["coordinate"]["y"], + "distance": station["distance"], + } + + async def async_get_data(self): + """Retrieve the data for the location.""" + params = {} + if self.query is not None: + params["query"] = self.query + else: + params["x"] = self.x + params["y"] = self.y + + if self.fields: + params["fields"] = self.fields + + url = self.get_url("locations", params) + + try: + response = await self._session.get(url, raise_for_status=True) + + _LOGGER.debug("Response from transport.opendata.ch: %s", response.status) + data = await response.json() + _LOGGER.debug(data) + except asyncio.TimeoutError: + _LOGGER.error("Can not load data from transport.opendata.ch") + raise exceptions.OpendataTransportConnectionError() + except aiohttp.ClientError as aiohttpClientError: + _LOGGER.error("Response from transport.opendata.ch: %s", aiohttpClientError) + raise exceptions.OpendataTransportConnectionError() + + try: + for station in data["stations"]: + self.locations.append(self.get_station(station)) + except (TypeError, IndexError): + raise exceptions.OpendataTransportError() + + class OpendataTransportStationboard(OpendataTransportBase): """A class for handling stationsboards from Opendata Transport.""" - def __init__(self, station, session, limit=5): + def __init__( + self, + station, + session, + limit=5, + transportations=None, + datetime=None, + type_="departure", + fields=None, + ): """Initialize the journey.""" super().__init__(session) + self.station = station self.limit = limit + self.datetime = datetime + self.transportations = ( + transportations + if transportations is not None and isinstance(transportations, list) + else None + ) + self.type = type_ + self.fields = ( + fields if fields is not None and isinstance(fields, list) else None + ) + self.from_name = self.from_id = self.to_name = self.to_id = None + self.journeys = [] @staticmethod @@ -53,11 +141,20 @@ def get_journey(journey): async def __async_get_data(self, station): """Retrieve the data for the station.""" - params = {"limit": self.limit} + params = { + "limit": self.limit, + "type": self.type, + } if str.isdigit(station): params["id"] = station else: params["station"] = station + if self.datetime: + params["datetime"] = self.date + if self.transportations: + params["transportations"] = self.transportations + if self.fields: + params["fields"] = self.fields url = self.get_url("stationboard", params) @@ -94,13 +191,52 @@ async def async_get_data(self): class OpendataTransport(OpendataTransportBase): """A class for handling connections from Opendata Transport.""" - def __init__(self, start, destination, session, limit=3): + def __init__( + self, + start, + destination, + session, + limit=3, + page=0, + date=None, + time=None, + isArrivalTime=False, + transportations=None, + direct=False, + sleeper=False, + couchette=False, + bike=False, + accessibility=None, + via=None, + fields=None, + ): """Initialize the connection.""" super().__init__(session) + self.limit = limit + self.page = page self.start = start self.destination = destination + self.via = via[:5] if via is not None and isinstance(via, list) else None + self.date = date + self.time = time + self.isArrivalTime = 1 if isArrivalTime else 0 + self.transportations = ( + transportations + if transportations is not None and isinstance(transportations, list) + else None + ) + self.direct = 1 if direct else 0 + self.sleeper = 1 if sleeper else 0 + self.couchette = 1 if couchette else 0 + self.bike = 1 if bike else 0 + self.accessibility = accessibility + self.fields = ( + fields if fields is not None and isinstance(fields, list) else None + ) + self.from_name = self.from_id = self.to_name = self.to_id = None + self.connections = dict() @staticmethod @@ -125,10 +261,31 @@ def get_connection(connection): async def async_get_data(self): """Retrieve the data for the connection.""" - url = self.get_url( - "connections", - {"from": self.start, "to": self.destination, "limit": self.limit}, - ) + params = { + "from": self.start, + "to": self.destination, + "limit": self.limit, + "page": self.page, + "isArrivalTime": self.isArrivalTime, + "direct": self.direct, + "sleeper": self.sleeper, + "couchette": self.couchette, + "bike": self.bike, + } + if self.via: + params["via"] = self.via + if self.time: + params["time"] = self.time + if self.date: + params["date"] = self.date + if self.transportations: + params["transportations"] = self.transportations + if self.accessibility: + params["accessibility"] = self.accessibility + if self.fields: + params["fields"] = self.fields + + url = self.get_url("connections", params) try: response = await self._session.get(url, raise_for_status=True)