Skip to content

Commit

Permalink
Invert write and get functions
Browse files Browse the repository at this point in the history
should probably be reflected in the readme... eventually (not critical, no AP change)
  • Loading branch information
Gouvernathor committed May 10, 2024
1 parent a9ad909 commit 7459903
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 44 deletions.
12 changes: 6 additions & 6 deletions src/parliamentarch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@
from io import TextIOBase

from .geometry import get_seats_centers
from .svg import SeatData, dispatch_seats, write_grouped_svg
from ._util import filter_kwargs, get_from_write
from .svg import SeatData, dispatch_seats, get_grouped_svg
from ._util import filter_kwargs, write_from_get

__all__ = ("get_svg_from_attribution", "write_svg_from_attribution", "SeatData")

_GET_SEATS_CENTERS_PARAM_NAMES = {k for k, p in signature(get_seats_centers).parameters.items() if p.kind==p.KEYWORD_ONLY}
_WRITE_GROUPED_SVG_PARAM_NAMES = {k for k, p in signature(write_grouped_svg).parameters.items() if p.kind==p.KEYWORD_ONLY}
_WRITE_GROUPED_SVG_PARAM_NAMES = {k for k, p in signature(get_grouped_svg).parameters.items() if p.kind==p.KEYWORD_ONLY}

def write_svg_from_attribution(file: TextIOBase, attrib: dict[SeatData, int], **kwargs) -> None:
def get_svg_from_attribution(attrib: dict[SeatData, int], **kwargs) -> str:
nseats = sum(attrib.values())
get_seats_centers_kwargs, write_grouped_svg_kwargs, kwargs = filter_kwargs(_GET_SEATS_CENTERS_PARAM_NAMES, _WRITE_GROUPED_SVG_PARAM_NAMES, **kwargs)

Expand All @@ -19,6 +19,6 @@ def write_svg_from_attribution(file: TextIOBase, attrib: dict[SeatData, int], **

results = get_seats_centers(nseats, **get_seats_centers_kwargs)
seat_centers_by_group = dispatch_seats(attrib, sorted(results, key=results.__getitem__, reverse=True))
write_grouped_svg(file, seat_centers_by_group, results.seat_actual_radius, **write_grouped_svg_kwargs)
return get_grouped_svg(seat_centers_by_group, results.seat_actual_radius, **write_grouped_svg_kwargs)

get_svg_from_attribution = get_from_write(write_svg_from_attribution)
write_svg_from_attribution = write_from_get(get_svg_from_attribution)
31 changes: 19 additions & 12 deletions src/parliamentarch/_util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections.abc import Container, Sequence
import inspect
from io import StringIO
from collections.abc import Callable, Container, Sequence
from inspect import Parameter, signature
from io import TextIOBase
from typing import NamedTuple

class FactoryDict(dict):
Expand Down Expand Up @@ -39,16 +39,23 @@ def from_any(cls, o, /):
return cls(*o)
raise ValueError(f"Cannot convert {o!r} to a {cls.__name__}")

def get_from_write(write_func):
def get(*args, **kwargs) -> str:
sio = StringIO()
write_func(sio, *args, **kwargs)
return sio.getvalue()
_file_parameter = Parameter("file",
Parameter.POSITIONAL_ONLY,
annotation=str | TextIOBase)

write_sig = inspect.signature(write_func)
_, *params = write_sig.parameters.values()
get.__signature__ = write_sig.replace(parameters=params)
return get
def write_from_get(get_func: Callable[..., str]) -> Callable[..., None]:
def write_func(file, /, *args, **kwargs):
if isinstance(file, str):
with open(file, "w") as f:
return write_func(f, *args, **kwargs)
print(get_func(*args, **kwargs), file=file)

get_sig = signature(get_func)
write_sig = get_sig.replace(
parameters=[_file_parameter] + list(get_sig.parameters.values()))
write_func.__signature__ = write_sig
write_func.__name__ = get_func.__name__.replace("get_", "write_")
return write_func

def filter_kwargs[V](
*sets: Container[str],
Expand Down
53 changes: 27 additions & 26 deletions src/parliamentarch/svg.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import re
import warnings

from ._util import Color, UnPicklable, get_from_write
from ._util import Color, UnPicklable, write_from_get

__all__ = ("SeatData", "dispatch_seats", "write_svg", "write_grouped_svg", "get_svg", "get_grouped_svg")

Expand Down Expand Up @@ -64,25 +64,23 @@ def dispatch_seats[S](
return rv


def write_svg(
file: TextIOBase,
def get_svg(
seat_centers: dict[tuple[float, float], SeatData],
*args, **kwargs,
) -> None:
) -> str:
seat_centers_by_group = {}
for seat, group in seat_centers.items():
seat_centers_by_group.setdefault(group, []).append(seat)
write_grouped_svg(file, seat_centers_by_group, *args, **kwargs)
return get_grouped_svg(seat_centers_by_group, *args, **kwargs)

def write_grouped_svg(
file: TextIOBase,
def get_grouped_svg(
seat_centers_by_group: dict[SeatData, list[tuple[float, float]]],
seat_actual_radius: float, *,
canvas_size: float = 175,
margins: float|tuple[float, float]|tuple[float, float, float, float] = 5.,
write_number_of_seats: bool = True,
font_size_factor: float = 36/175,
) -> None:
) -> str:
"""
The margins is either a single value for all four sides,
or a (horizontal, vertical) tuple,
Expand All @@ -91,44 +89,47 @@ def write_grouped_svg(
canvas_size is the height and half of the width
of the canvas 2:1 rectangle to which to add the margins.
"""
buffer = []

if isinstance(margins, (int, float)):
margins = (margins, margins, margins, margins)
elif len(margins) == 2:
margins = margins + margins
left_margin, top_margin, right_margin, bottom_margin = margins

_write_svg_header(file,
_append_svg_header(buffer,
width=left_margin+2*canvas_size+right_margin,
height=top_margin+canvas_size+bottom_margin)
if write_number_of_seats:
font_size = round(font_size_factor * canvas_size)
_write_svg_number_of_seats(file, sum(map(len, seat_centers_by_group.values())),
_append_svg_number_of_seats(buffer, sum(map(len, seat_centers_by_group.values())),
x=left_margin+canvas_size, y=top_margin+(canvas_size*170/175), font_size=font_size)
_write_grouped_svg_seats(file, seat_centers_by_group, seat_actual_radius,
_append_grouped_svg_seats(buffer, seat_centers_by_group, seat_actual_radius,
canvas_size=canvas_size, left_margin=left_margin, top_margin=top_margin)
_write_svg_footer(file)
_append_svg_footer(buffer)

return "".join(buffer)

def _write_svg_header(file: TextIOBase, width: float, height: float) -> None:
file.write(f"""\
def _append_svg_header(buffer: list[str], width: float, height: float) -> None:
buffer.append(f"""\
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg xmlns:svg="http://www.w3.org/2000/svg"
xmlns="http://www.w3.org/2000/svg" version="1.1"
width="{width}" height="{height}">
<!-- Created with parliamentarch (https://github.com/Gouvernathor/parliamentarch/) -->""")

def _write_svg_number_of_seats(
file: TextIOBase,
def _append_svg_number_of_seats(
buffer: list[str],
nseats: int,
x: float, y: float,
font_size: int,
) -> None:
file.write(f"""
buffer.append(f"""
<text x="{x}" y="{y}"
style="font-size:{font_size}px;font-weight:bold;text-align:center;text-anchor:middle;font-family:sans-serif">{nseats}</text>""")

def _write_grouped_svg_seats(
file: TextIOBase,
def _append_grouped_svg_seats(
buffer: list[str],
seat_centers_by_group: dict[SeatData, list[tuple[float, float]]],
seat_actual_radius: float,
canvas_size: float,
Expand All @@ -154,7 +155,7 @@ def _write_grouped_svg_seats(
if isinstance(group_border_color, Color):
group_border_color = group_border_color.hexcode

file.write(f"""
buffer.append(f"""
<g style="fill:{group_color}; stroke-width:{group_border_width:.2f}; stroke:{group_border_color}"
id="{block_id}">
<title>{group.data}</title>""")
Expand All @@ -163,17 +164,17 @@ def _write_grouped_svg_seats(
actual_x = left_margin + canvas_size * x
actual_y = top_margin + canvas_size * (1 - y)
actual_radius = seat_actual_radius * canvas_size - group_border_width/2
file.write(f"""
buffer.append(f"""
<circle cx="{actual_x:.2f}" cy="{actual_y:.2f}" r="{actual_radius:.2f}"/>""")

file.write("""
buffer.append("""
</g>""")

def _write_svg_footer(file: TextIOBase) -> None:
file.write("""
def _append_svg_footer(buffer: list[str]) -> None:
buffer.append("""
</svg>
""")


get_svg = get_from_write(write_svg)
get_grouped_svg = get_from_write(write_grouped_svg)
write_svg = write_from_get(get_svg)
write_grouped_svg = write_from_get(get_grouped_svg)

0 comments on commit 7459903

Please sign in to comment.