Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pytest release marker #2114

Merged
merged 2 commits into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .github/workflows/build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,8 @@ jobs:
needs: build-and-push
runs-on: ["self-hosted", "${{ needs.build-and-push.outputs.runs_on }}", "multi-gpu"]
if: needs.build-and-push.outputs.runs_on != 'ubuntu-latest'
env:
PYTEST_FLAGS: ${{ github.ref == 'refs/heads/main' && '--release' || '' }}
steps:
- name: Checkout repository
uses: actions/checkout@v4
Expand All @@ -180,4 +182,4 @@ jobs:
export DOCKER_DEVICES=${{ needs.build-and-push.outputs.docker_devices }}
export HF_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
echo $DOCKER_IMAGE
pytest -s -vv integration-tests
pytest -s -vv integration-tests ${PYTEST_FLAGS}
20 changes: 20 additions & 0 deletions integration-tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,26 @@
DOCKER_DEVICES = os.getenv("DOCKER_DEVICES")


def pytest_addoption(parser):
    """Register the ``--release`` command-line flag with pytest.

    The flag is a boolean switch (``store_true``) that defaults to ``False``.

    Args:
        parser: The pytest option parser handed to this hook.
    """
    release_flag_settings = {
        "action": "store_true",
        "default": False,
        "help": "run release tests",
    }
    parser.addoption("--release", **release_flag_settings)


def pytest_configure(config):
    """Declare the ``release`` marker in pytest's ``markers`` ini setting.

    Args:
        config: The pytest config object handed to this hook.
    """
    marker_declaration = "release: mark test as a release-only test"
    config.addinivalue_line("markers", marker_declaration)


def pytest_collection_modifyitems(config, items):
    """Skip tests carrying the ``release`` keyword unless ``--release`` is set.

    When the ``--release`` option is truthy the collected items are left
    untouched. Otherwise a single skip marker is created and added to every
    item whose keywords contain ``"release"``.

    Args:
        config: The pytest config object; queried for the ``--release`` option.
        items: The collected test items; matching ones are marked in place.
    """
    if not config.getoption("--release"):
        # One marker instance is shared by every release-only item.
        release_only_skip = pytest.mark.skip(reason="need --release option to run")
        for collected in items:
            if "release" in collected.keywords:
                collected.add_marker(release_only_skip)


class ResponseComparator(JSONSnapshotExtension):
rtol = 0.2
ignore_logprob = False
Expand Down
3 changes: 3 additions & 0 deletions integration-tests/models/test_bloom_560m.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ async def bloom_560(bloom_560_handle):
return bloom_560_handle.client


@pytest.mark.release
@pytest.mark.asyncio
async def test_bloom_560m(bloom_560, response_snapshot):
response = await bloom_560.generate(
Expand All @@ -27,6 +28,7 @@ async def test_bloom_560m(bloom_560, response_snapshot):
assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
async def test_bloom_560m_all_params(bloom_560, response_snapshot):
response = await bloom_560.generate(
Expand All @@ -49,6 +51,7 @@ async def test_bloom_560m_all_params(bloom_560, response_snapshot):
assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
async def test_bloom_560m_load(bloom_560, generate_load, response_snapshot):
responses = await generate_load(
Expand Down
2 changes: 2 additions & 0 deletions integration-tests/models/test_bloom_560m_sharded.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ async def bloom_560m_sharded(bloom_560m_sharded_handle):
return bloom_560m_sharded_handle.client


@pytest.mark.release
@pytest.mark.asyncio
async def test_bloom_560m_sharded(bloom_560m_sharded, response_snapshot):
response = await bloom_560m_sharded.generate(
Expand All @@ -27,6 +28,7 @@ async def test_bloom_560m_sharded(bloom_560m_sharded, response_snapshot):
assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
async def test_bloom_560m_sharded_load(
bloom_560m_sharded, generate_load, response_snapshot
Expand Down
3 changes: 3 additions & 0 deletions integration-tests/models/test_completion_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ async def flash_llama_completion(flash_llama_completion_handle):
# method for it. Instead, we use the `requests` library to make the HTTP request directly.


@pytest.mark.release
def test_flash_llama_completion_single_prompt(
flash_llama_completion, response_snapshot
):
Expand All @@ -46,6 +47,7 @@ def test_flash_llama_completion_single_prompt(
assert response == response_snapshot


@pytest.mark.release
def test_flash_llama_completion_many_prompts(flash_llama_completion, response_snapshot):
response = requests.post(
f"{flash_llama_completion.base_url}/v1/completions",
Expand All @@ -68,6 +70,7 @@ def test_flash_llama_completion_many_prompts(flash_llama_completion, response_sn
assert response == response_snapshot


@pytest.mark.release
async def test_flash_llama_completion_many_prompts_stream(
flash_llama_completion, response_snapshot
):
Expand Down
3 changes: 3 additions & 0 deletions integration-tests/models/test_flash_awq.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ async def flash_llama_awq(flash_llama_awq_handle):
return flash_llama_awq_handle.client


@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_llama_awq(flash_llama_awq, response_snapshot):
response = await flash_llama_awq.generate(
Expand All @@ -31,6 +32,7 @@ async def test_flash_llama_awq(flash_llama_awq, response_snapshot):
assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot):
response = await flash_llama_awq.generate(
Expand All @@ -52,6 +54,7 @@ async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot):
assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_llama_awq_load(flash_llama_awq, generate_load, response_snapshot):
responses = await generate_load(
Expand Down
2 changes: 2 additions & 0 deletions integration-tests/models/test_flash_awq_sharded.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ async def flash_llama_awq_sharded(flash_llama_awq_handle_sharded):
return flash_llama_awq_handle_sharded.client


@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapshot):
response = await flash_llama_awq_sharded.generate(
Expand All @@ -31,6 +32,7 @@ async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapsho
assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_llama_awq_load_sharded(
flash_llama_awq_sharded, generate_load, response_snapshot
Expand Down
3 changes: 3 additions & 0 deletions integration-tests/models/test_flash_falcon.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ async def flash_falcon(flash_falcon_handle):
return flash_falcon_handle.client


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_falcon(flash_falcon, response_snapshot):
Expand All @@ -26,6 +27,7 @@ async def test_flash_falcon(flash_falcon, response_snapshot):
assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_falcon_all_params(flash_falcon, response_snapshot):
Expand All @@ -49,6 +51,7 @@ async def test_flash_falcon_all_params(flash_falcon, response_snapshot):
assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_falcon_load(flash_falcon, generate_load, response_snapshot):
Expand Down
3 changes: 3 additions & 0 deletions integration-tests/models/test_flash_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ async def flash_gemma(flash_gemma_handle):
return flash_gemma_handle.client


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_gemma(flash_gemma, response_snapshot):
Expand All @@ -24,6 +25,7 @@ async def test_flash_gemma(flash_gemma, response_snapshot):
assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_gemma_all_params(flash_gemma, response_snapshot):
Expand All @@ -47,6 +49,7 @@ async def test_flash_gemma_all_params(flash_gemma, response_snapshot):
assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_gemma_load(flash_gemma, generate_load, response_snapshot):
Expand Down
3 changes: 3 additions & 0 deletions integration-tests/models/test_flash_gemma_gptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ async def flash_gemma_gptq(flash_gemma_gptq_handle):
return flash_gemma_gptq_handle.client


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_gemma_gptq(flash_gemma_gptq, ignore_logprob_response_snapshot):
Expand All @@ -24,6 +25,7 @@ async def test_flash_gemma_gptq(flash_gemma_gptq, ignore_logprob_response_snapsh
assert response == ignore_logprob_response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_gemma_gptq_all_params(
Expand All @@ -49,6 +51,7 @@ async def test_flash_gemma_gptq_all_params(
assert response == ignore_logprob_response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_gemma_gptq_load(
Expand Down
2 changes: 2 additions & 0 deletions integration-tests/models/test_flash_gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ async def flash_gpt2(flash_gpt2_handle):
return flash_gpt2_handle.client


@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_gpt2(flash_gpt2, response_snapshot):
response = await flash_gpt2.generate(
Expand All @@ -25,6 +26,7 @@ async def test_flash_gpt2(flash_gpt2, response_snapshot):
assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_gpt2_load(flash_gpt2, generate_load, response_snapshot):
responses = await generate_load(
Expand Down
3 changes: 3 additions & 0 deletions integration-tests/models/test_flash_llama_exl2.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ async def flash_llama_exl2(flash_llama_exl2_handle):
return flash_llama_exl2_handle.client


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_exl2(flash_llama_exl2, ignore_logprob_response_snapshot):
Expand All @@ -32,6 +33,7 @@ async def test_flash_llama_exl2(flash_llama_exl2, ignore_logprob_response_snapsh
assert response == ignore_logprob_response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_exl2_all_params(
Expand All @@ -58,6 +60,7 @@ async def test_flash_llama_exl2_all_params(
assert response == ignore_logprob_response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_exl2_load(
Expand Down
3 changes: 3 additions & 0 deletions integration-tests/models/test_flash_llama_gptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ async def flash_llama_gptq(flash_llama_gptq_handle):
return flash_llama_gptq_handle.client


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_gptq(flash_llama_gptq, response_snapshot):
Expand All @@ -24,6 +25,7 @@ async def test_flash_llama_gptq(flash_llama_gptq, response_snapshot):
assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_gptq_all_params(flash_llama_gptq, response_snapshot):
Expand All @@ -46,6 +48,7 @@ async def test_flash_llama_gptq_all_params(flash_llama_gptq, response_snapshot):
assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_gptq_load(
Expand Down
3 changes: 3 additions & 0 deletions integration-tests/models/test_flash_llama_gptq_marlin.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ async def flash_llama_gptq_marlin(flash_llama_gptq_marlin_handle):
return flash_llama_gptq_marlin_handle.client


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_gptq_marlin(flash_llama_gptq_marlin, response_snapshot):
Expand All @@ -26,6 +27,7 @@ async def test_flash_llama_gptq_marlin(flash_llama_gptq_marlin, response_snapsho
assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_gptq_marlin_all_params(
Expand All @@ -50,6 +52,7 @@ async def test_flash_llama_gptq_marlin_all_params(
assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_gptq_marlin_load(
Expand Down
3 changes: 3 additions & 0 deletions integration-tests/models/test_flash_llama_marlin.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ async def flash_llama_marlin(flash_llama_marlin_handle):
return flash_llama_marlin_handle.client


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_marlin(flash_llama_marlin, response_snapshot):
Expand All @@ -26,6 +27,7 @@ async def test_flash_llama_marlin(flash_llama_marlin, response_snapshot):
assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_marlin_all_params(flash_llama_marlin, response_snapshot):
Expand All @@ -48,6 +50,7 @@ async def test_flash_llama_marlin_all_params(flash_llama_marlin, response_snapsh
assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_marlin_load(
Expand Down
2 changes: 2 additions & 0 deletions integration-tests/models/test_flash_neox.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ async def flash_neox(flash_neox_handle):
return flash_neox_handle.client


@pytest.mark.release
@pytest.mark.skip
@pytest.mark.asyncio
async def test_flash_neox(flash_neox, response_snapshot):
Expand All @@ -26,6 +27,7 @@ async def test_flash_neox(flash_neox, response_snapshot):
assert response == response_snapshot


@pytest.mark.release
@pytest.mark.skip
@pytest.mark.asyncio
async def test_flash_neox_load(flash_neox, generate_load, response_snapshot):
Expand Down
2 changes: 2 additions & 0 deletions integration-tests/models/test_flash_neox_sharded.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ async def flash_neox_sharded(flash_neox_sharded_handle):
return flash_neox_sharded_handle.client


@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_neox(flash_neox_sharded, response_snapshot):
response = await flash_neox_sharded.generate(
Expand All @@ -25,6 +26,7 @@ async def test_flash_neox(flash_neox_sharded, response_snapshot):
assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
async def test_flash_neox_load(flash_neox_sharded, generate_load, response_snapshot):
responses = await generate_load(
Expand Down
2 changes: 2 additions & 0 deletions integration-tests/models/test_flash_pali_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def get_cow_beach():
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot):
Expand All @@ -45,6 +46,7 @@ async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot):
assert response == response_snapshot


@pytest.mark.release
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_pali_gemma_two_images(flash_pali_gemma, response_snapshot):
Expand Down
Loading
Loading