diff --git a/src/parliamentarch/__init__.py b/src/parliamentarch/__init__.py index 7deb213..1f866db 100644 --- a/src/parliamentarch/__init__.py +++ b/src/parliamentarch/__init__.py @@ -2,15 +2,15 @@ from io import TextIOBase from .geometry import get_seats_centers -from .svg import SeatData, dispatch_seats, write_grouped_svg -from ._util import filter_kwargs, get_from_write +from .svg import SeatData, dispatch_seats, get_grouped_svg +from ._util import filter_kwargs, write_from_get __all__ = ("get_svg_from_attribution", "write_svg_from_attribution", "SeatData") _GET_SEATS_CENTERS_PARAM_NAMES = {k for k, p in signature(get_seats_centers).parameters.items() if p.kind==p.KEYWORD_ONLY} -_WRITE_GROUPED_SVG_PARAM_NAMES = {k for k, p in signature(write_grouped_svg).parameters.items() if p.kind==p.KEYWORD_ONLY} +_WRITE_GROUPED_SVG_PARAM_NAMES = {k for k, p in signature(get_grouped_svg).parameters.items() if p.kind==p.KEYWORD_ONLY} -def write_svg_from_attribution(file: TextIOBase, attrib: dict[SeatData, int], **kwargs) -> None: +def get_svg_from_attribution(attrib: dict[SeatData, int], **kwargs) -> str: nseats = sum(attrib.values()) get_seats_centers_kwargs, write_grouped_svg_kwargs, kwargs = filter_kwargs(_GET_SEATS_CENTERS_PARAM_NAMES, _WRITE_GROUPED_SVG_PARAM_NAMES, **kwargs) @@ -19,6 +19,6 @@ def write_svg_from_attribution(file: TextIOBase, attrib: dict[SeatData, int], ** results = get_seats_centers(nseats, **get_seats_centers_kwargs) seat_centers_by_group = dispatch_seats(attrib, sorted(results, key=results.__getitem__, reverse=True)) - write_grouped_svg(file, seat_centers_by_group, results.seat_actual_radius, **write_grouped_svg_kwargs) + return get_grouped_svg(seat_centers_by_group, results.seat_actual_radius, **write_grouped_svg_kwargs) -get_svg_from_attribution = get_from_write(write_svg_from_attribution) +write_svg_from_attribution = write_from_get(get_svg_from_attribution) diff --git a/src/parliamentarch/_util.py b/src/parliamentarch/_util.py index c253b37..d255982 100644 --- a/src/parliamentarch/_util.py +++ b/src/parliamentarch/_util.py @@ -1,6 +1,6 @@ -from collections.abc import Container, Sequence -import inspect -from io import StringIO +from collections.abc import Callable, Container, Sequence +from inspect import Parameter, signature +from io import TextIOBase from typing import NamedTuple class FactoryDict(dict): @@ -39,16 +39,23 @@ def from_any(cls, o, /): return cls(*o) raise ValueError(f"Cannot convert {o!r} to a {cls.__name__}") -def get_from_write(write_func): - def get(*args, **kwargs) -> str: - sio = StringIO() - write_func(sio, *args, **kwargs) - return sio.getvalue() +_file_parameter = Parameter("file", + Parameter.POSITIONAL_ONLY, + annotation=str | TextIOBase) - write_sig = inspect.signature(write_func) - _, *params = write_sig.parameters.values() - get.__signature__ = write_sig.replace(parameters=params) - return get +def write_from_get(get_func: Callable[..., str]) -> Callable[..., None]: + def write_func(file, /, *args, **kwargs): + if isinstance(file, str): + with open(file, "w") as f: + return write_func(f, *args, **kwargs) + print(get_func(*args, **kwargs), file=file) + + get_sig = signature(get_func) + write_sig = get_sig.replace( + parameters=[_file_parameter] + list(get_sig.parameters.values())) + write_func.__signature__ = write_sig + write_func.__name__ = get_func.__name__.replace("get_", "write_") + return write_func def filter_kwargs[V]( *sets: Container[str], diff --git a/src/parliamentarch/svg.py b/src/parliamentarch/svg.py index 261666a..b29803d 100644 --- a/src/parliamentarch/svg.py +++ b/src/parliamentarch/svg.py @@ -4,7 +4,7 @@ import re import warnings -from ._util import Color, UnPicklable, get_from_write +from ._util import Color, UnPicklable, write_from_get __all__ = ("SeatData", "dispatch_seats", "write_svg", "write_grouped_svg", "get_svg", "get_grouped_svg") @@ -64,25 +64,23 @@ def dispatch_seats[S]( return rv -def write_svg( - file: TextIOBase, +def get_svg( seat_centers: dict[tuple[float, float], SeatData], *args, **kwargs, - ) -> None: + ) -> str: seat_centers_by_group = {} for seat, group in seat_centers.items(): seat_centers_by_group.setdefault(group, []).append(seat) - write_grouped_svg(file, seat_centers_by_group, *args, **kwargs) + return get_grouped_svg(seat_centers_by_group, *args, **kwargs) -def write_grouped_svg( - file: TextIOBase, +def get_grouped_svg( seat_centers_by_group: dict[SeatData, list[tuple[float, float]]], seat_actual_radius: float, *, canvas_size: float = 175, margins: float|tuple[float, float]|tuple[float, float, float, float] = 5., write_number_of_seats: bool = True, font_size_factor: float = 36/175, - ) -> None: + ) -> str: """ The margins is either a single value for all four sides, or a (horizontal, vertical) tuple, @@ -91,6 +89,7 @@ def write_grouped_svg( canvas_size is the height and half of the width of the canvas 2:1 rectangle to which to add the margins. """ + buffer = [] if isinstance(margins, (int, float)): margins = (margins, margins, margins, margins) @@ -98,37 +97,39 @@ def write_grouped_svg( margins = margins + margins left_margin, top_margin, right_margin, bottom_margin = margins - _write_svg_header(file, + _append_svg_header(buffer, width=left_margin+2*canvas_size+right_margin, height=top_margin+canvas_size+bottom_margin) if write_number_of_seats: font_size = round(font_size_factor * canvas_size) - _write_svg_number_of_seats(file, sum(map(len, seat_centers_by_group.values())), + _append_svg_number_of_seats(buffer, sum(map(len, seat_centers_by_group.values())), x=left_margin+canvas_size, y=top_margin+(canvas_size*170/175), font_size=font_size) - _write_grouped_svg_seats(file, seat_centers_by_group, seat_actual_radius, + _append_grouped_svg_seats(buffer, seat_centers_by_group, seat_actual_radius, canvas_size=canvas_size, left_margin=left_margin, top_margin=top_margin) - _write_svg_footer(file) + _append_svg_footer(buffer) + + return "".join(buffer) -def _write_svg_header(file: TextIOBase, width: float, height: float) -> None: - file.write(f"""\ +def _append_svg_header(buffer: list[str], width: float, height: float) -> None: + buffer.append(f"""\ """) -def _write_svg_number_of_seats( - file: TextIOBase, +def _append_svg_number_of_seats( + buffer: list[str], nseats: int, x: float, y: float, font_size: int, ) -> None: - file.write(f""" + buffer.append(f""" {nseats}""") -def _write_grouped_svg_seats( - file: TextIOBase, +def _append_grouped_svg_seats( + buffer: list[str], seat_centers_by_group: dict[SeatData, list[tuple[float, float]]], seat_actual_radius: float, canvas_size: float, @@ -154,7 +155,7 @@ def _write_grouped_svg_seats( if isinstance(group_border_color, Color): group_border_color = group_border_color.hexcode - file.write(f""" + buffer.append(f""" {group.data}""") @@ -163,17 +164,17 @@ def _write_grouped_svg_seats( actual_x = left_margin + canvas_size * x actual_y = top_margin + canvas_size * (1 - y) actual_radius = seat_actual_radius * canvas_size - group_border_width/2 - file.write(f""" + buffer.append(f""" """) - file.write(""" + buffer.append(""" """) -def _write_svg_footer(file: TextIOBase) -> None: - file.write(""" +def _append_svg_footer(buffer: list[str]) -> None: + buffer.append(""" """) -get_svg = get_from_write(write_svg) -get_grouped_svg = get_from_write(write_grouped_svg) +write_svg = write_from_get(get_svg) +write_grouped_svg = write_from_get(get_grouped_svg)