Skip to content

Commit

Permalink
feat(convert): use dataclasses for embed files and requirements (#176)
Browse files Browse the repository at this point in the history
  • Loading branch information
Caceresenzo authored Nov 28, 2024
1 parent fe104fc commit 3ae7eb7
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 46 deletions.
76 changes: 51 additions & 25 deletions crunch/convert.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import dataclasses
import logging
import os
import re
import typing

import redbaron
import yaml
from requirements.requirement import Requirement

import requirements

Expand All @@ -31,6 +31,20 @@
_KV_DIVIDER = "---"


@dataclasses.dataclass
class EmbedFile:
    """A file embedded in a notebook through a markdown cell.

    Instances are created while extracting markdown cells and are handed
    back to the caller as a list by ``extract_cells``.
    """

    # Path exactly as it was written in the cell's header.
    path: str

    # ``path`` after ``os.path.normpath`` with backslashes turned into "/".
    normalized_path: str

    # Remaining lines of the cell, joined with "\n".
    content: str


@dataclasses.dataclass
class Requirement:
    """A package requirement collected from a notebook's imports.

    ``extras`` and ``specs`` are both ``None`` when no version information
    was attached to the import.
    """

    # Name of the required package.
    name: str

    # Requested extras (e.g. ``["big"]``), or ``None`` if unspecified.
    extras: typing.Optional[typing.List[str]]

    # Version specifiers (e.g. ``["==42"]``), or ``None`` if unspecified.
    specs: typing.Optional[typing.List[str]]


def strip_packages(name: str):
if name.startswith(_DOT):
return None # just in case, but should not happen
Expand Down Expand Up @@ -143,13 +157,13 @@ def _convert_import(log: typing.Callable[[str], None], node: redbaron.Node):

version = _extract_import_version(log, node)

packages = set()
names = set()
for path in paths:
name = strip_packages(path)
if name:
packages.add(name)
names.add(name)

return packages, version
return names, version


def _add_to_packages(log: typing.Callable[[str], None], packages: dict, node: redbaron.Node):
Expand Down Expand Up @@ -178,7 +192,7 @@ def _extract_code_cell(
cell_source: typing.List[str],
log: typing.Callable[[str], None],
module: typing.List[str],
packages: typing.Dict[str, Requirement],
packages: typing.Dict[str, typing.Tuple[typing.List[str], typing.List[str]]],
):
source = "\n".join(
re.sub(r"^\s*?(!|%)", r"#\1", line)
Expand Down Expand Up @@ -229,12 +243,15 @@ def _extract_code_cell(
def _extract_markdown_cell(
cell_source: typing.List[str],
log: typing.Callable[[str], None],
embed_files: typing.Dict[str, str],
embed_files: typing.Dict[str, EmbedFile],
):
if not len(cell_source):
log(f"skip since empty")
return

def get_full_source():
return "\n".join(cell_source)

iterator = iter(cell_source)

if next(iterator) != _KV_DIVIDER:
Expand Down Expand Up @@ -267,33 +284,42 @@ def _extract_markdown_cell(
raise NotebookCellParseError(
f"file not specified",
None,
source,
get_full_source(),
)

normalized_file_path = os.path.normpath(file_path).replace("\\", "/")
if normalized_file_path in embed_files:
lower_file_path = normalized_file_path.lower()

previous = embed_files.get(lower_file_path)
if previous is not None:
raise NotebookCellParseError(
f"file `{file_path}` specified multiple time",
None,
source,
f"file `{file_path}` is conflicting with `{previous.path}`",
get_full_source(),
)

content = "\n".join((
line
for line in iterator
))
content = "\n".join(iterator)

embed_files[normalized_file_path] = content
log(f"embed {normalized_file_path}: {len(content)} characters")
embed_files[lower_file_path] = EmbedFile(
file_path,
normalized_file_path,
content,
)

log(f"embed {lower_file_path}: {len(content)} characters")


def extract_cells(
cells: typing.List[typing.Any],
print: typing.Callable[[str], None] = print,
) -> typing.Tuple[str, typing.List[str]]:
packages: typing.Dict[str, Requirement] = {}
) -> typing.Tuple[
str,
typing.List[EmbedFile],
typing.List[Requirement],
]:
packages: typing.Dict[str, typing.Tuple[typing.List[str], typing.List[str]]] = {}
module: typing.List[str] = []
embed_files: typing.Dict[str, str] = {}
embed_files: typing.Dict[str, EmbedFile] = {}

for index, cell in enumerate(cells):
cell_id = cell["metadata"].get("id") or f"cell_{index}"
Expand Down Expand Up @@ -326,16 +352,16 @@ def log(message):

source_code = "\n".join(module)
requirements = [
{
"name": name,
"extras": requirement[0] if requirement is not None else None,
"specs": requirement[1] if requirement is not None else None,
}
Requirement(
name,
requirement[0] if requirement is not None else None,
requirement[1] if requirement is not None else None,
)
for name, requirement in packages.items()
]

return (
source_code,
embed_files,
list(embed_files.values()),
requirements,
)
28 changes: 7 additions & 21 deletions tests/test_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
import typing
import unittest

from crunch.convert import (InconsistantLibraryVersionError,
NotebookCellParseError,
from crunch.convert import (EmbedFile, InconsistantLibraryVersionError,
NotebookCellParseError, Requirement,
RequirementVersionParseError, extract_cells)


Expand Down Expand Up @@ -94,21 +94,9 @@ def test_normal(self):
])

self.assertEqual([
{
"name": "hello",
"extras": None,
"specs": None,
},
{
"name": "world",
"extras": [],
"specs": ["==42"],
},
{
"name": "extras",
"extras": ["big"],
"specs": [">4.2"],
}
Requirement("hello", None, None),
Requirement("world", [], ["==42"]),
Requirement("extras", ["big"], [">4.2"]),
], requirements)

def test_inconsistant_version(self):
Expand Down Expand Up @@ -139,17 +127,15 @@ def test_normal(self):
) = extract_cells([
_cell("a", "markdown", [
"---",
"file: a",
"file: ./a.txt",
"---",
"# Hello World",
"from a embed markdown file",
])
])

self.assertEqual("", source_code)
self.assertEqual({
"a": "# Hello World\nfrom a embed markdown file"
}, embed_files)
self.assertEqual([EmbedFile("./a.txt", "a.txt", "# Hello World\nfrom a embed markdown file")], embed_files)
self.assertEqual([], requirements)

def test_root_not_a_dict(self):
Expand Down

0 comments on commit 3ae7eb7

Please sign in to comment.