diff --git a/.github/workflows/cd.yml b/.github/workflows/cd.yml new file mode 100644 index 0000000..bcd7ce0 --- /dev/null +++ b/.github/workflows/cd.yml @@ -0,0 +1,103 @@ +name: CD + +on: + workflow_run: + workflows: ["CI"] + branches: [main] + tags: ["*"] + types: + - completed + +permissions: + contents: read + +jobs: + linux: + if: ${{ github.event.workflow_run.conclusion == 'success' }} + runs-on: ubuntu-latest + strategy: + matrix: + target: [x86_64, aarch64] + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: "3.9" + - name: Build wheels + uses: PyO3/maturin-action@v1 + with: + rust-toolchain: nightly-2024-02-03 + target: ${{ matrix.target }} + args: --release --out dist --find-interpreter + sccache: "true" + manylinux: auto + - name: Upload wheels + uses: actions/upload-artifact@v3 + with: + name: wheels + path: dist + + windows: + if: ${{ github.event.workflow_run.conclusion == 'success' }} + runs-on: windows-latest + strategy: + matrix: + target: [x64, x86] + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: "3.9" + architecture: ${{ matrix.target }} + - name: Build wheels + uses: PyO3/maturin-action@v1 + with: + rust-toolchain: nightly-2024-02-03 + target: ${{ matrix.target }} + args: --release --out dist --find-interpreter + sccache: "true" + - name: Upload wheels + uses: actions/upload-artifact@v3 + with: + name: wheels + path: dist + + macos: + if: ${{ github.event.workflow_run.conclusion == 'success' }} + runs-on: macos-latest + strategy: + matrix: + target: [x86_64, aarch64] + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: "3.9" + - name: Build wheels + uses: PyO3/maturin-action@v1 + with: + rust-toolchain: nightly-2024-02-03 + target: ${{ matrix.target }} + args: --release --out dist --find-interpreter + sccache: "true" + - name: Upload wheels + uses: actions/upload-artifact@v3 + with: + name: wheels + path: dist + + sdist: + if: ${{ github.event.workflow_run.conclusion == 'success' }} + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Build sdist + uses: PyO3/maturin-action@v1 + with: + command: sdist + args: --out dist + - name: Upload sdist + uses: actions/upload-artifact@v3 + with: + name: wheels + path: dist diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..590b75d --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,58 @@ +name: CI + +on: + push: + branches: + - main + tags: + - "*" + pull_request: + workflow_dispatch: + +env: + CARGO_TERM_COLOR: always + +jobs: + test: + name: Unit tests + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: actions/setup-python@v4 + with: + python-version: "3.9" + - uses: dtolnay/rust-toolchain@master + with: + toolchain: nightly-2024-02-03 + - uses: Swatinem/rust-cache@v2 + - run: pip install -e '.[dev]' + - run: pytest -s tests + - run: cargo test --all + + lint_rust: + name: Lint Rust + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: actions/setup-python@v4 + with: + python-version: "3.9" + - uses: dtolnay/rust-toolchain@master + with: + toolchain: nightly-2024-02-03 + components: rustfmt, clippy + - uses: Swatinem/rust-cache@v2 + - run: cargo fmt --all -- --check + - run: cargo clippy --all -- -D warnings + + lint_python: + name: Lint Python + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: actions/setup-python@v4 + with: + python-version: "3.9" + - run: pip install black~=24.0 + - run: black --check python + - run: black --check tests diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..fbe4648 --- /dev/null +++ b/.gitignore @@ -0,0 +1,74 @@ +/target +/venv +python/blended_dataset_loop/blended_dataset_loop + +# Byte-compiled / optimized / DLL files +__pycache__/ +.pytest_cache/ +*.py[cod] + +# C extensions +*.so + +# Distribution / packaging +.Python +.venv/ +env/ +bin/ +build/ +develop-eggs/ +dist/ +eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +include/ +man/ +venv/ +*.egg-info/ +.installed.cfg +*.egg + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt +pip-selfcheck.json + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.cache +nosetests.xml +coverage.xml + +# Translations +*.mo + +# Mr Developer +.mr.developer.cfg +.project +.pydevproject + +# Rope +.ropeproject + +# Django stuff: +*.log +*.pot + +.DS_Store + +# Sphinx documentation +docs/_build/ + +# PyCharm +.idea/ + +# VSCode +.vscode/ + +# Pyenv +.python-version diff --git a/Cargo.lock b/Cargo.lock new file mode 100644 index 0000000..56e2e7f --- /dev/null +++ b/Cargo.lock @@ -0,0 +1,377 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "anyhow" +version = "1.0.75" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4668cab20f66d8d020e1fbc0ebe47217433c1b6c8f2040faf858554e394ace6" + +[[package]] +name = "argminmax" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "202108b46429b765ef483f8a24d5c46f48c14acfdacc086dd4ab6dddf6bcdbd2" +dependencies = [ + "num-traits", +] + +[[package]] +name = "autocfg" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" + +[[package]] +name = "bitflags" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" + +[[package]] +name = "blended_dataset_loop" +version = "0.1.0" +dependencies = [ + "argminmax", + "serde", + "serde_json", + "tqdm", +] + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "crossterm" +version = "0.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e64e6c0fbe2c17357405f7c758c1ef960fce08bdfb2c03d88d2a18d7e09c4b67" +dependencies = [ + "bitflags", + "crossterm_winapi", + "libc", + "mio", + "parking_lot", + "signal-hook", + "signal-hook-mio", + "winapi", +] + +[[package]] +name = "crossterm_winapi" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "acdd7c62a3665c7f6830a51635d9ac9b23ed385797f70a83bb8bafe9c572ab2b" +dependencies = [ + "winapi", +] + +[[package]] +name = "itoa" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af150ab688ff2122fcef229be89cb50dd66af9e01a4ff320cc137eecc9bacc38" + +[[package]] +name = "lazy_static" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" + +[[package]] +name = "libc" +version = "0.2.150" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "89d92a4743f9a61002fae18374ed11e7973f530cb3a3255fb354818118b2203c" + +[[package]] +name = "lock_api" +version = "0.4.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c168f8615b12bc01f9c17e2eb0cc07dcae1940121185446edc3744920e8ef45" +dependencies = [ + "autocfg", + "scopeguard", +] + +[[package]] +name = "log" +version = "0.4.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" + +[[package]] +name = "mio" +version = "0.8.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f3d0b296e374a4e6f3c7b0a1f5a51d748a0d34c85e7dc48fc3fa9a87657fe09" +dependencies = [ + "libc", + "log", + "wasi", + "windows-sys", +] + +[[package]] +name = "num-traits" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39e3200413f237f41ab11ad6d161bc7239c84dcb631773ccd7de3dfe4b5c267c" +dependencies = [ + "autocfg", +] + +[[package]] +name = "parking_lot" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c42a9226546d68acdd9c0a280d17ce19bfe27a46bf68784e4066115788d008e" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-targets", +] + +[[package]] +name = "proc-macro2" +version = "1.0.70" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39278fbbf5fb4f646ce651690877f89d1c5811a3d4acb27700c1cb3cdb78fd3b" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "redox_syscall" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4722d768eff46b75989dd134e5c353f0d6296e5aaa3132e776cbdb56be7731aa" +dependencies = [ + "bitflags", +] + +[[package]] +name = "ryu" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ad4cc8da4ef723ed60bced201181d83791ad433213d8c24efffda1eec85d741" + +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "serde" +version = "1.0.193" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25dd9975e68d0cb5aa1120c288333fc98731bd1dd12f561e468ea4728c042b89" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.193" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43576ca501357b9b071ac53cdc7da8ef0cbd9493d8df094cd821777ea6e894d3" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.108" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d1c7e3eac408d115102c4c24ad393e0821bb3a5df4d506a80f85f7a742a526b" +dependencies = [ + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "signal-hook" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8621587d4798caf8eb44879d42e56b9a93ea5dcd315a6487c357130095b62801" +dependencies = [ + "libc", + "signal-hook-registry", +] + +[[package]] +name = "signal-hook-mio" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29ad2e15f37ec9a6cc544097b78a1ec90001e9f71b81338ca39f430adaca99af" +dependencies = [ + "libc", + "mio", + "signal-hook", +] + +[[package]] +name = "signal-hook-registry" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8229b473baa5980ac72ef434c4415e70c4b5e71b423043adb4ba059f89c99a1" +dependencies = [ + "libc", +] + +[[package]] +name = "smallvec" +version = "1.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4dccd0940a2dcdf68d092b8cbab7dc0ad8fa938bf95787e1b916b0e3d0e8e970" + +[[package]] +name = "syn" +version = "2.0.39" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23e78b90f2fcf45d3e842032ce32e3f2d1545ba6636271dcbf24fa306d87be7a" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "tqdm" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c9e4aea46830eb68bbb272f637baa91318b94aee47d172c68e4c43495fd521f" +dependencies = [ + "anyhow", + "crossterm", + "lazy_static", +] + +[[package]] +name = "unicode-ident" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" + +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + +[[package]] +name = "windows-sys" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-targets" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" + +[[package]] +name = "windows_i686_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" + +[[package]] +name = "windows_i686_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..93a7257 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "blended_dataset_loop" +version = "0.1.0" +edition = "2021" + +[lib] +name = "blended_dataset_loop" +crate-type = ["cdylib"] + +[dependencies] +argminmax = "0.6" +serde = {version="1.0", features=["derive"]} +serde_json = "1.0" +tqdm = "0.6" diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..0d959b9 --- /dev/null +++ b/LICENSE @@ -0,0 +1,31 @@ +The following applies to all files in this repository, unless otherwise noted: + +Copyright (c) 2024 IPAI Aleph Alpha Research GmbH. All rights reserved. + +This project is licensed under the terms of the Open Aleph License 1.0, available at +https://github.com/Aleph-Alpha/.github/blob/main/oal.pdf + +--- +Excerpt from the license text: + +Subject to the terms and conditions of this License, the Licensor grants you a non-exclusive, worldwide, +non-transferable, non-sublicensable, and royalty-free limited right to use, copy, modify, distribute, make +otherwise publicly available, and reproduce the Works and Derivative Works under Licensor’s copyright, +for any Non-Commercial and Non-Administrative purpose. +You may not use, copy, modify, distribute, make otherwise publicly available, reproduce, or sublicense the +Works or Derivative Works except as expressly provided under and in accordance with this License. +Your rights granted under this License will automatically terminate if you fail to comply with any of the +terms of this License. + +EXCEPT FOR DAMAGES CAUSED BY INTENT OR FRAUDULENTLY CONCEALED +DEFECTS, AND EXCEPT FOR DAMAGES RESULTING FROM BREACH OF ANY +WARRANTY OR GUARANTEE EXPRESSLY GIVEN BY LICENSOR IN THE OPEN ALEPH LICENSE, +IN NO EVENT WILL LICENSOR BE LIABLE TO YOU ON ANY LEGAL THEORY FOR ANY +DAMAGES ARISING OUT OF THE OPEN ALEPH LICENSE OR THE USE OF THE WORK. ANY +MANDATORY STATUTORY LIABILITY UNDER APPLICABLE LAW REMAINS +UNAFFECTED. + +EXCEPT AS EXPRESSLY STATED IN THIS LICENSE OR REQUIRED BY APPLICABLE +LAW, THE WORKS ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES +OF ANY KIND INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES REGARDING +THE CONTENTS, ACCURACY, OR FITNESS FOR A PARTICULAR PURPOSE. \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..cb5760d --- /dev/null +++ b/README.md @@ -0,0 +1,48 @@ +# Blended Dataset Loop + +This repository contains a simple loop to compute a balanced ordering of dataset indices to train on. +The resulting ordering ensures the data distribution is similar among batches. +The loop is implemented in Rust for performance reasons and can be consumed as part of a Python package. + +The package uses `cffi` (rather than e.g. `PyO3`) in order to be compatible with different Python versions. + +## Requirements + +* Conda (for Python) +* Cargo with nightly Rust + +## Setup + +Create a conda environment as follows: + +```sh +conda create -n blended_dataset_loop python=3.9 -y +conda activate blended_dataset_loop +``` + +Install Rust nightly + +```sh +rustup override set nightly-2024-02-03 +``` + +Install the Python dev-dependencies: + +```sh +pip install 'maturin[patchelf]' +pip install '.[dev]' +``` + +## Develop + +After changing the Rust code, run: + +```sh +maturin develop +``` + +or, for release mode + +```sh +maturin develop --release +``` diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..bfe5e56 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,21 @@ +[build-system] +requires = ["maturin>=1.4,<2.0"] +build-backend = "maturin" + +[project] +name = "blended_dataset_loop" +requires-python = ">=3.8" +classifiers = [ + "Programming Language :: Rust", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", +] +dynamic = ["version"] +dependencies = ["numpy", "cffi"] + +[project.optional-dependencies] +dev = ["pytest", "black~=24.0", "mypy"] + +[tool.maturin] +bindings = "cffi" +python-source = "python" diff --git a/python/blended_dataset_loop/__init__.py b/python/blended_dataset_loop/__init__.py new file mode 100644 index 0000000..7546af8 --- /dev/null +++ b/python/blended_dataset_loop/__init__.py @@ -0,0 +1,30 @@ +from pathlib import PurePath +import numpy as np +from typing import Union + +from .blended_dataset_loop import ffi, lib + + +def sample(number_to_sample: np.ndarray, cache_filename_stem: Union[PurePath, str]): + assert isinstance( + number_to_sample, np.ndarray + ), "expected `number_to_sample` to be a numpy array" + assert ( + len(number_to_sample.shape) == 1 + ), "expected `number_to_sample` to be one-dimensional" + assert isinstance(cache_filename_stem, PurePath) or isinstance( + cache_filename_stem, str + ), "expected `cache_filename_stem` to be a PurePath or str" + + number_to_sample_len = len(number_to_sample) + number_to_sample_buf = ffi.from_buffer(number_to_sample) + number_to_sample_c = ffi.cast("uint64_t *", number_to_sample_buf) + path_utf8 = str(cache_filename_stem).encode("utf-8") + cache_filename_stem_c = ffi.cast("uint8_t *", ffi.from_buffer(path_utf8)) + + lib.sample( + number_to_sample_c, number_to_sample_len, cache_filename_stem_c, len(path_utf8) + ) + + +__all__ = ["sample"] diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..b660c85 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,103 @@ +use std::f64; +use std::fs::File; +use std::io::{BufWriter, Write}; + +use argminmax::ArgMinMax; +use serde::Serialize; +use tqdm::tqdm; + +#[derive(Serialize)] +struct Result { + number_to_sample: Vec, + result: Vec>, +} + +#[derive(serde::Serialize)] +struct Metadata { + dtype: String, + total_count: u64, + shape: [usize; 2], +} + +/// SAFETY: The following invariants must be ensured from the Python wrapper: +/// * `number_to_sample` must point to a buffer of length `n_datasets`. +/// * `cache_filename_stem` must point to a buffer of length `cache_filename_stem_len` containing valid UTF-8 data +#[no_mangle] +extern "C" fn sample( + number_to_sample: *const u64, + n_datasets: usize, + cache_filename_stem: *const u8, + cache_filename_stem_len: usize, +) { + let number_to_sample = unsafe { std::slice::from_raw_parts(number_to_sample, n_datasets) }; + let cache_filename_stem_slice = + unsafe { std::slice::from_raw_parts(cache_filename_stem, cache_filename_stem_len) }; + let cache_filename_stem = unsafe { std::str::from_utf8_unchecked(cache_filename_stem_slice) }; + sample_impl(number_to_sample, cache_filename_stem); +} + +fn sample_impl(number_to_sample: &[u64], cache_filename_stem: &str) { + let input_json_filepath = format!("{cache_filename_stem}.input.json"); + let input_json_writer = BufWriter::new( + File::create(input_json_filepath) + .expect("Failed to create blended data set index input.json file."), + ); + serde_json::to_writer(input_json_writer, number_to_sample) + .expect("Failed to write to blended data set index input.json file."); + + let mut number_sampled: Vec = vec![0; number_to_sample.len()]; + let mut proportion_sampled: Vec = vec![0.0; number_to_sample.len()]; + + let total_count = number_to_sample.iter().copied().sum::() as usize; + + let bin_filename = format!("{cache_filename_stem}.bin"); + let mut result_writer = BufWriter::new( + File::create(bin_filename).expect("Failed to create blended data set index bin file."), + ); + + for _ in tqdm(0..total_count) { + // find smallest represenation + let (dataset_index, _argmax) = proportion_sampled.argminmax(); + + // add representation from dataset + let sample_index = number_sampled[dataset_index]; + + // update proportions + number_sampled[dataset_index] += 1; + proportion_sampled[dataset_index] = + (number_sampled[dataset_index] as f64) / number_to_sample[dataset_index] as f64; + + result_writer + .write_all(&(dataset_index as u64).to_le_bytes()) + .expect("Failed to write blended data set index bin file."); + result_writer + .write_all(&sample_index.to_le_bytes()) + .expect("Failed to write blended data set index bin file."); + } + + assert!(total_count as u64 == number_sampled.iter().copied().sum::()); + + number_to_sample + .iter() + .zip(number_sampled.iter()) + .for_each(|(a, b)| assert_eq!(a, b)); + + result_writer + .flush() + .expect("Failed to flush blended data set index bin file."); + + // write metadata + let meta_filename = format!("{cache_filename_stem}.meta.json"); + let meta_writer = BufWriter::new( + File::create(meta_filename).expect("Failed to create blended data set meta.json file"), + ); + serde_json::to_writer( + meta_writer, + &Metadata { + total_count: total_count as u64, + dtype: "uint64".to_string(), + shape: [total_count, 2], + }, + ) + .expect("Failed to write blended data set meta.json file"); +} diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_blended_dataset_loop.py b/tests/test_blended_dataset_loop.py new file mode 100644 index 0000000..85d7911 --- /dev/null +++ b/tests/test_blended_dataset_loop.py @@ -0,0 +1,51 @@ +import numpy as np +import blended_dataset_loop +import json + +from pathlib import Path + + +def expect_input(cache_filename_stem: Path, expected): + with open(f"{cache_filename_stem}.input.json") as file: + actual = json.load(file) + assert actual == expected + + +def expect_meta(cache_filename_stem: Path, expected): + with open(f"{cache_filename_stem}.meta.json") as file: + actual = json.load(file) + assert actual == expected + + +def expect_bin(cache_filename_stem: Path, expected): + with open(f"{cache_filename_stem}.meta.json") as file: + meta = json.load(file) + contents = np.fromfile(f"{cache_filename_stem}.bin", dtype=meta["dtype"]).reshape( + tuple(meta["shape"]) + ) + assert np.array_equal(contents, expected) + + +def test_simple(tmp_path: Path): + cache_filename_stem = tmp_path / "simple" + number_to_sample = np.array([1, 2, 3, 4], dtype="uint64") + blended_dataset_loop.sample(number_to_sample, str(cache_filename_stem)) + expect_input(cache_filename_stem, [1, 2, 3, 4]) + expect_meta( + cache_filename_stem, {"dtype": "uint64", "total_count": 10, "shape": [10, 2]} + ) + expect_bin( + cache_filename_stem, + [ + [0, 0], + [1, 0], + [2, 0], + [3, 0], + [3, 1], + [2, 1], + [1, 1], + [3, 2], + [2, 2], + [3, 3], + ], + )