update to use pytest-benchmark
* publish plot/graph for `CatchAll` benchmark tests
* also move benchmark dependencies into `requirements-bench.txt`
rnag committed Dec 13, 2024
1 parent 6477ad3 commit dba8877
Showing 8 changed files with 149 additions and 64 deletions.
Binary file added benchmarks/catch_all.png
87 changes: 35 additions & 52 deletions benchmarks/catch_all.py
@@ -1,16 +1,11 @@
import logging
from dataclasses import dataclass
from timeit import timeit
from typing import Any

import pytest

from dataclasses_json import (dataclass_json,
Undefined,
CatchAll as CatchAllDJ)

from dataclass_wizard import (JSONWizard,
CatchAll as CatchAllWizard)
from dataclasses_json import (dataclass_json, Undefined, CatchAll as CatchAllDJ)
from dataclass_wizard import (JSONWizard, CatchAll as CatchAllWizard)


log = logging.getLogger(__name__)
@@ -37,6 +32,7 @@ class _(JSONWizard.Meta):
unknown_things: CatchAllWizard


# Fixtures for test data
@pytest.fixture(scope='session')
def data():
return {"endpoint": "some_api_endpoint",
@@ -50,72 +46,59 @@ def data_no_extras():
"data": {"foo": 1, "bar": "2"}}


def test_load(data, n):
"""
[ RESULTS ON MAC OS X ]
# Benchmark for deserialization (from_dict)
@pytest.mark.benchmark(group="deserialization")
def test_deserialize_wizard(benchmark, data):
benchmark(lambda: DontCareAPIDumpWizard.from_dict(data))

benchmarks.catch_all.catch_all - [INFO] dataclass-wizard 0.060889
benchmarks.catch_all.catch_all - [INFO] dataclasses-json 11.469157

"""
g = globals().copy()
g.update(locals())
@pytest.mark.benchmark(group="deserialization")
def test_deserialize_json(benchmark, data):
benchmark(lambda: DontCareAPIDumpDJ.from_dict(data))

log.info('dataclass-wizard %f',
timeit('DontCareAPIDumpWizard.from_dict(data)', globals=g, number=n))
log.info('dataclasses-json %f',
timeit('DontCareAPIDumpDJ.from_dict(data)', globals=g, number=n))

dump1 = DontCareAPIDumpDJ.from_dict(data) # DontCareAPIDump(endpoint='some_api_endpoint', data={'foo': 1, 'bar': '2'})
dump2 = DontCareAPIDumpWizard.from_dict(data) # DontCareAPIDump(endpoint='some_api_endpoint', data={'foo': 1, 'bar': '2'})
# Benchmark for deserialization with no extra data
@pytest.mark.benchmark(group="deserialization_no_extra_data")
def test_deserialize_wizard_no_extras(benchmark, data_no_extras):
benchmark(lambda: DontCareAPIDumpWizard.from_dict(data_no_extras))

assert dump1.endpoint == dump2.endpoint
assert dump1.data == dump2.data
assert dump1.unknown_things == dump2.unknown_things

@pytest.mark.benchmark(group="deserialization_no_extra_data")
def test_deserialize_json_no_extras(benchmark, data_no_extras):
benchmark(lambda: DontCareAPIDumpDJ.from_dict(data_no_extras))

def test_load_with_no_extra_data(data_no_extras, n):
"""
[ RESULTS ON MAC OS X ]

benchmarks.catch_all.catch_all - [INFO] dataclass-wizard 0.045790
benchmarks.catch_all.catch_all - [INFO] dataclasses-json 11.031206
# Benchmark for serialization (to_dict)
@pytest.mark.benchmark(group="serialization")
def test_serialize_wizard(benchmark, data):
dump1 = DontCareAPIDumpWizard.from_dict(data)
benchmark(lambda: dump1.to_dict())

"""
g = globals().copy()
g.update(locals())

log.info('dataclass-wizard %f',
timeit('DontCareAPIDumpWizard.from_dict(data_no_extras)', globals=g, number=n))
log.info('dataclasses-json %f',
timeit('DontCareAPIDumpDJ.from_dict(data_no_extras)', globals=g, number=n))
@pytest.mark.benchmark(group="serialization")
def test_serialize_json(benchmark, data):
dump2 = DontCareAPIDumpDJ.from_dict(data)
benchmark(lambda: dump2.to_dict())


def test_validate(data, data_no_extras):
dump1 = DontCareAPIDumpDJ.from_dict(data_no_extras) # DontCareAPIDump(endpoint='some_api_endpoint', data={'foo': 1, 'bar': '2'})
dump2 = DontCareAPIDumpWizard.from_dict(data_no_extras) # DontCareAPIDump(endpoint='some_api_endpoint', data={'foo': 1, 'bar': '2'})

assert dump1.endpoint == dump2.endpoint
assert dump1.data == dump2.data
assert dump1.unknown_things == dump2.unknown_things == {}

expected = {'endpoint': 'some_api_endpoint', 'data': {'foo': 1, 'bar': '2'}}

def test_dump(data):
"""
[ RESULTS ON MAC OS X ]
benchmarks.catch_all.catch_all - [INFO] dataclass-wizard 0.317555
benchmarks.catch_all.catch_all - [INFO] dataclasses-json 3.970232
"""
dump1 = DontCareAPIDumpWizard.from_dict(data)
dump2 = DontCareAPIDumpDJ.from_dict(data)
assert dump1.to_dict() == dump2.to_dict() == expected

g = globals().copy()
g.update(locals())
dump1 = DontCareAPIDumpDJ.from_dict(data) # DontCareAPIDump(endpoint='some_api_endpoint', data={'foo': 1, 'bar': '2'})
dump2 = DontCareAPIDumpWizard.from_dict(data) # DontCareAPIDump(endpoint='some_api_endpoint', data={'foo': 1, 'bar': '2'})

log.info('dataclass-wizard %f',
timeit('dump1.to_dict()', globals=g, number=n))
log.info('dataclasses-json %f',
timeit('dump2.to_dict()', globals=g, number=n))
assert dump1.endpoint == dump2.endpoint
assert dump1.data == dump2.data
assert dump1.unknown_things == dump2.unknown_things

expected = {'endpoint': 'some_api_endpoint', 'data': {'foo': 1, 'bar': '2'}, 'undefined_field_name': [1, 2, 3]}

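For context, here is a minimal sketch of the pytest-benchmark pattern the rewritten tests follow (not part of the commit; the dataclass and payload below are illustrative). The benchmark fixture times a callable over many rounds and returns its result, replacing the manual timeit() loops removed above; grouped results can then be compared, e.g. via `pytest --benchmark-group-by=group`.

from dataclasses import dataclass

import pytest


@dataclass
class Dump:
    endpoint: str


@pytest.fixture(scope='session')
def payload():
    # Illustrative payload, mirroring the shape of the `data` fixture above
    return {'endpoint': 'some_api_endpoint'}


@pytest.mark.benchmark(group="deserialization")
def test_deserialize(benchmark, payload):
    # `benchmark` runs the callable repeatedly, records timing stats
    # (min/mean/ops), and returns the callable's result
    result = benchmark(lambda: Dump(**payload))
    assert result.endpoint == 'some_api_endpoint'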
4 changes: 0 additions & 4 deletions dataclass_wizard/__init__.py
@@ -114,9 +114,6 @@
'IS_NOT',
'IS_TRUTHY',
'IS_FALSY',
# V1
'Alias',
'AliasPath',
]

import logging
@@ -130,7 +127,6 @@
Pattern, DatePattern, TimePattern, DateTimePattern,
CatchAll, SkipIf, SkipIfNone,
EQ, NE, LT, LE, GT, GE, IS, IS_NOT, IS_TRUTHY, IS_FALSY)
from .v1.models import Alias, AliasPath
from .environ.wizard import EnvWizard
from .property_wizard import property_wizard
from .serial_json import JSONWizard, JSONPyWizard, JSONSerializable
4 changes: 4 additions & 0 deletions dataclass_wizard/v1/__init__.py
@@ -0,0 +1,4 @@
__all__ = ['Alias',
'AliasPath']

from .models import Alias, AliasPath
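With the top-level re-exports removed from dataclass_wizard/__init__.py above, the opt-in V1 names are now imported from this new subpackage; a quick sketch, assuming this commit is installed:

# Alias and AliasPath now live under the v1 subpackage
from dataclass_wizard.v1 import Alias, AliasPath

print(Alias, AliasPath)  # both names resolve from dataclass_wizard.v1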
9 changes: 9 additions & 0 deletions requirements-bench.txt
@@ -0,0 +1,9 @@
# Benchmark tests
matplotlib
pytest-benchmark[histogram]
dataclasses-json==0.6.7
jsons==1.6.3
dataclass-factory==2.16 # pyup: ignore
dacite==1.8.1
mashumaro==3.15
pydantic==2.10.2
7 changes: 0 additions & 7 deletions requirements-test.txt
@@ -2,10 +2,3 @@ pytest==8.3.3
pytest-mock>=3.6.1
pytest-cov==6.0.0
# pytest-runner==5.3.1
# Benchmark tests
dataclasses-json==0.6.7
jsons==1.6.3
dataclass-factory==2.16 # pyup: ignore
dacite==1.8.1
mashumaro==3.15
pydantic==2.10.2
94 changes: 94 additions & 0 deletions run_bench.py
@@ -0,0 +1,94 @@
import glob
import json
import os
import shutil
import subprocess
import matplotlib.pyplot as plt


def run_benchmarks():
# Ensure the `.benchmarks` folder exists
os.makedirs(".benchmarks", exist_ok=True)

# Run pytest benchmarks and save results
print("Running benchmarks...")
result = subprocess.run(
["pytest", "benchmarks/catch_all.py", "--benchmark-save=benchmark_results"],
capture_output=True,
text=True
)
print(result.stdout)


def load_benchmark_results(file_path):
"""Load the benchmark results from the provided JSON file."""
with open(file_path, "r") as f:
return json.load(f)


def plot_relative_performance(results):
"""Plot relative performance for different benchmark groups."""
benchmarks = results["benchmarks"]

# Extract and format data
names = []
ops = []
for bm in benchmarks:
group = bm.get("group", "")
library = "dataclass-wizard" if "wizard" in bm["name"] else "dataclasses-json"
formatted_name = f"{group} ({library})"
names.append(formatted_name)
ops.append(bm["stats"]["ops"])

# Calculate relative performance (ratio of each ops to the slowest ops)
baseline = min(ops)
relative_performance = [op / baseline for op in ops]

# Plot bar chart
plt.figure(figsize=(10, 6))
bars = plt.barh(names, relative_performance, color="skyblue")
plt.xlabel("Performance Relative to Slowest (times faster)")
plt.title("Catch All: Relative Performance of dataclass-wizard vs dataclasses-json")
plt.tight_layout()

# Add data labels to the bars
for bar, rel_perf in zip(bars, relative_performance):
plt.text(bar.get_width() + 0.1, bar.get_y() + bar.get_height() / 2,
f"{rel_perf:.1f}x", va="center")

# Save and display the plot
plt.savefig("catch_all.png")
plt.show()


def find_latest_benchmark_file():
"""Find the most recent benchmark result file."""
benchmark_dir = ".benchmarks"
pattern = os.path.join(benchmark_dir, "**", "*.json")
files = glob.glob(pattern, recursive=True)
if not files:
raise FileNotFoundError("No benchmark files found.")
latest_file = max(files, key=os.path.getctime) # Find the most recently created file
return latest_file


if __name__ == "__main__":
# Step 1: Run benchmarks
run_benchmarks()

# Step 2: Find the latest benchmark results file
benchmark_file = find_latest_benchmark_file()
print(f"Latest benchmark file: {benchmark_file}")

# Step 3: Load the benchmark results
if os.path.exists(benchmark_file):
results = load_benchmark_results(benchmark_file)

# Step 4: Plot results
plot_relative_performance(results)

else:
print(f"Benchmark file not found: {benchmark_file}")

    # Step 5: Copy the generated image into the docs folder for easy access
    shutil.copy("catch_all.png", "docs/")
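For reference, a sketch of the saved-results JSON that plot_relative_performance() consumes. The field names ("benchmarks", "group", "stats" with "ops") follow pytest-benchmark's save format; the numbers are invented for illustration.

# Hypothetical example of the structure read from .benchmarks/**/*.json
example_results = {
    "benchmarks": [
        {
            "name": "test_deserialize_wizard",
            "group": "deserialization",
            "stats": {"ops": 1_500_000.0},  # operations per second
        },
        {
            "name": "test_deserialize_json",
            "group": "deserialization",
            "stats": {"ops": 100_000.0},
        },
    ]
}

# With min(ops) as the baseline, the wizard bar would be labeled "15.0x":
# plot_relative_performance(example_results)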
8 changes: 7 additions & 1 deletion setup.py
@@ -34,6 +34,12 @@
else: # Running on CI
test_requirements = []

if (requires_bench_file := here / 'requirements-bench.txt').exists():
with requires_bench_file.open() as requires_bench_txt:
bench_requirements = [str(req) for req in parse_requirements(requires_bench_txt)]
else: # Running on CI
bench_requirements = []

# extras_require = {
# 'dotenv': ['python-dotenv>=0.19.0'],
# }
@@ -107,7 +113,7 @@
'tomli-w>=1,<2'
],
'yaml': ['PyYAML>=6,<7'],
'dev': dev_requires + doc_requires + test_requirements,
'dev': dev_requires + doc_requires + test_requirements + bench_requirements,
},
zip_safe=False
)
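With this change, installing the project's dev extra from a checkout (e.g. `pip install -e .[dev]`) would also pull in the benchmark dependencies pinned in requirements-bench.txt.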
