Skip to content

Commit

Permalink
chore: better train search (#40)
Browse files Browse the repository at this point in the history
  • Loading branch information
2mal3 authored Nov 23, 2023
1 parent 227f717 commit 47a2e8d
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 7 deletions.
2 changes: 1 addition & 1 deletion templates/api/trains_search.html
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
{% for train in trains %}
<option value="{{ train.number }}">{{ train.type }} {{ train.number }}</option>
<option>{{ train.type }} {{ train.number }}</option>
{% endfor %}
2 changes: 1 addition & 1 deletion templates/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ <h1>VerspätungsOrakel</h1>
<div>
<label for="train">Zugnummer</label>
<input
type="number"
type="text"
id="train"
name="train"
list="trains"
Expand Down
2 changes: 2 additions & 0 deletions verspaetungsorakel/bahn.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ def get_important_stations():
stations = list(filter(lambda s: s.Verkehr == "FV", stations))
stations = list(filter(lambda s: "Hbf" in s.NAME, stations))

stations = list(filter(lambda s: "Berlin" in s.NAME, stations))

return stations


Expand Down
4 changes: 3 additions & 1 deletion verspaetungsorakel/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@ def sqlite_wal_mode(db, connection: sqlite3.Connection):


class Train(db.Entity):
number = PrimaryKey(str)
number = Required(str)
type = Required(str)

PrimaryKey(number, type)

trips = Set(lambda: Trip)


Expand Down
23 changes: 19 additions & 4 deletions verspaetungsorakel/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,17 +56,21 @@ def get_stations(station: str, request: Request):

@app.get("/api/v1/trains")
@limiter.limit("60/minute")
def get_trains(train: int, request: Request):
def get_trains(train: str, request: Request):
t_type, t_number = upack_train(train)

with db_session:
trains = to_dicts(Train.select(lambda t: str(train) in t.number).limit(100)[:])
trains = to_dicts(Train.select(lambda t: t_number in t.number and t_type in t.type).limit(100)[:])
return templates.TemplateResponse("api/trains_search.html", {"request": request, "trains": trains})


@app.get("/api/v1/data")
@limiter.limit("1/second")
def search(station: str, train: int, request: Request):
def search(station: str, train: str, request: Request):
t_type, t_number = upack_train(train)

with db_session:
db_train = Train[str(train)]
db_train = Train.get(number=t_number, type=t_type)
if not db_train:
raise HTTPException(status_code=404, detail="Train not found")
db_station = Station.get(name=station)
Expand Down Expand Up @@ -97,3 +101,14 @@ def search(station: str, train: int, request: Request):

def to_dicts(entities: list) -> list:
return [e.to_dict() for e in entities]


def upack_train(train: str) -> tuple:
t_args = train.split(" ")
if len(t_args) != 2:
raise HTTPException(status_code=400, detail="Invalid train number")

t_type = t_args[0].upper()
t_number = t_args[1]

return t_type, t_number

0 comments on commit 47a2e8d

Please sign in to comment.