Merge branch 'mainline' into aditya/fix-folder-to-ignore
adityabharadwaj198 committed Nov 21, 2024
2 parents de641ac + b3da5d4 commit ac410ae
Showing 8 changed files with 73 additions and 47 deletions.
7 changes: 2 additions & 5 deletions .github/workflows/cpu_local_marqo.yml
@@ -33,7 +33,7 @@ on:
- releases/*
paths-ignore:
- '**.md'
pull_request_target:
pull_request:
branches:
- mainline
- releases/*
@@ -44,7 +44,7 @@ permissions:
contents: read

concurrency:
group: cpu-local-api-tests-${{ github.head_ref || github.ref }}
group: cpu-local-api-tests-${{ github.ref }}
cancel-in-progress: true

jobs:
@@ -94,9 +94,6 @@ jobs:
uses: actions/checkout@v3
with:
fetch-depth: 0
# if triggered by a pull_request_target event, we should use the merge ref of the PR
# if triggered by a push event, github.ref points to the head of the source branch
ref: ${{ github.event_name == 'pull_request_target' && format('refs/pull/{0}/merge', github.event.pull_request.number) || github.ref }}

- name: Set up Python 3.8
uses: actions/setup-python@v3
11 changes: 6 additions & 5 deletions .github/workflows/largemodel_unit_test_CI.yml
@@ -10,15 +10,15 @@ on:
- releases/*
paths-ignore:
- '**.md'
pull_request_target:
pull_request:
branches:
- mainline
- releases/*
paths-ignore:
- '**.md'

concurrency:
group: large-model-unit-tests-${{ github.head_ref || github.ref }}
group: large-model-unit-tests-${{ github.ref }}
cancel-in-progress: true

permissions:
@@ -69,9 +69,6 @@ jobs:
with:
fetch-depth: 0
path: marqo
# if triggered by a pull_request_target event, we should use the merge ref of the PR
# if triggered by a push event, github.ref points to the head of the source branch
ref: ${{ github.event_name == 'pull_request_target' && format('refs/pull/{0}/merge', github.event.pull_request.number) || github.ref }}

- name: Set up Python 3.8
uses: actions/setup-python@v3
@@ -92,6 +89,10 @@ jobs:
pip install -r marqo/requirements.dev.txt
pip install pytest==7.4.0
- name: Download nltk data
run: |
python -m nltk.downloader punkt_tab
- name: Build Vespa
run: |
systemctl stop unattended-upgrades
9 changes: 8 additions & 1 deletion .github/workflows/test_documentation.yml
@@ -20,6 +20,13 @@ jobs:
fetch-depth: 0
path: marqo

- name: Checkout marqo-base for requirements
uses: actions/checkout@v3
with:
repository: marqo-ai/marqo-base
path: marqo-base
fetch-depth: 0

- name: Set up Python 3.8
uses: actions/setup-python@v3
with:
@@ -28,8 +35,8 @@

- name: Install dependencies
run: |
pip install -r marqo-base/requirements/amd64-gpu-requirements.txt
pip install -r marqo/requirements.dev.txt
pip install pytest==7.4.0
- name: Run Documentation Tests
run: |
11 changes: 6 additions & 5 deletions .github/workflows/unit_test_200gb_CI.yml
@@ -8,13 +8,13 @@ on:
branches:
- mainline
- releases/*
pull_request_target:
pull_request:
branches:
- mainline
- releases/*

concurrency:
group: unit-tests-${{ github.head_ref || github.ref }}
group: unit-tests-${{ github.ref }}
cancel-in-progress: true

permissions:
@@ -65,9 +65,6 @@ jobs:
with:
fetch-depth: 0
path: marqo
# if triggered by a pull_request_target event, we should use the merge ref of the PR
# if triggered by a push event, github.ref points to the head of the source branch
ref: ${{ github.event_name == 'pull_request_target' && format('refs/pull/{0}/merge', github.event.pull_request.number) || github.ref }}

- name: Set up Python 3.8
uses: actions/setup-python@v3
@@ -87,6 +84,10 @@
# override base requirements with marqo requirements, if needed:
pip install -r marqo/requirements.dev.txt
- name: Download nltk data
run: |
python -m nltk.downloader punkt_tab
- name: Build Vespa
run: |
systemctl stop unattended-upgrades
7 changes: 7 additions & 0 deletions src/marqo/api/exceptions.py
@@ -28,6 +28,13 @@ def __init__(self, message: str):
self.message = message


class StartupSanityCheckError(MarqoError):
code = "startup_sanity_check_error"

def __init__(self, message: str):
self.message = message


# TODO: DELETE
class MarqoApiError(MarqoError):
"""Error sent by Marqo API"""
18 changes: 3 additions & 15 deletions src/marqo/s2_inference/processing/text.py
@@ -1,11 +1,10 @@
from typing import Any, Dict, List, Optional, Union
from functools import partial
from types import FunctionType
from typing import List

from functools import partial
from more_itertools import windowed

# sent_tokenize and word_tokenize requires the punkt_tab dataset
from nltk.tokenize import sent_tokenize, word_tokenize
import nltk


def _splitting_functions(split_by: str, language: str='english') -> FunctionType:
@@ -25,17 +24,6 @@ def _splitting_functions(split_by: str, language: str='english') -> FunctionType
if not isinstance(split_by, str):
raise TypeError(f"expected str received {type(split_by)}")

try:
nltk.data.find("tokenizers/punkt")
except LookupError:
nltk.download("punkt")

# Punkt_tab needs to be downloaded after NLTK 3.8 and later
try:
nltk.data.find("tokenizers/punkt_tab")
except LookupError:
nltk.download("punkt_tab")

MAPPING = {
'character':list,
'word': partial(word_tokenize, language=language),
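With this change, text.py no longer downloads punkt/punkt_tab at call time and instead assumes the tokenizer data is already present (installed in the base image or by the new startup check). A minimal sketch, not part of the diff, of the sentence and word splitting this mapping relies on, assuming nltk and the punkt_tab data are installed:

```python
# Sketch only: the NLTK calls that _splitting_functions maps 'sentence' and 'word' to,
# assuming punkt_tab has already been downloaded (e.g. `python -m nltk.downloader punkt_tab`).
from nltk.tokenize import sent_tokenize, word_tokenize

text = "Marqo chunks documents before vectorising. Splitting uses NLTK tokenizers."
print(sent_tokenize(text, language="english"))  # -> two sentences
print(word_tokenize(text, language="english"))  # -> list of word tokens
```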
37 changes: 28 additions & 9 deletions src/marqo/tensor_search/on_start_script.py
@@ -1,26 +1,24 @@
import json
import os
import time
from threading import Lock

import nltk
import torch

from threading import Lock
from PIL import Image

from marqo import config, marqo_docs, version
from marqo import config, version
from marqo import marqo_docs
from marqo.api import exceptions
from marqo.connections import redis_driver
from marqo.s2_inference.s2_inference import vectorise
from marqo.s2_inference.processing.image import chunk_image
from marqo.s2_inference.constants import PATCH_MODELS
from marqo.s2_inference.processing.image import chunk_image
from marqo.s2_inference.s2_inference import vectorise
# we need to import backend before index_meta_cache to prevent circular import error:
from marqo.tensor_search import constants
from marqo.tensor_search import index_meta_cache, utils
from marqo.tensor_search.enums import EnvVars
from marqo.tensor_search.tensor_search_logging import get_logger
from marqo import marqo_docs



logger = get_logger(__name__)

@@ -33,6 +31,7 @@ def on_start(config: config.Config):
CUDAAvailable(),
SetBestAvailableDevice(),
CacheModels(),
CheckNLTKTokenizers(),
InitializeRedis("localhost", 6379),
CachePatchModels(),
DownloadFinishText(),
@@ -259,7 +258,27 @@ def run(self):
for message in messages:
self.logger.info(message)
self.logger.info("completed prewarming patch models")



class CheckNLTKTokenizers:
"""Check if NLTK tokenizers are available, if not, download them.
NLTK tokenizers are included in the base-image, we do a sanity check to ensure they are available.
"""
def run(self):
try:
nltk.data.find("tokenizers/punkt_tab")
except LookupError:
logger.info("NLTK punkt_tab tokenizer not found. Downloading...")
nltk.download("punkt_tab")

try:
nltk.data.find("tokenizers/punkt_tab")
except LookupError as e:
raise exceptions.StartupSanityCheckError(
f"Marqo failed to download and download NLTK tokenizers. Original error: {e}"
) from e


def _preload_model(model, content, device):
"""
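For illustration only, a hedged sketch of how the new startup check could be exercised on its own; module and class names follow the diff above:

```python
# Sketch, assuming the module layout shown in this commit.
from marqo.api import exceptions
from marqo.tensor_search import on_start_script

try:
    # Downloads punkt_tab if it is missing; raises if it still cannot be found afterwards.
    on_start_script.CheckNLTKTokenizers().run()
except exceptions.StartupSanityCheckError as e:
    print(f"Startup sanity check failed: {e.message}")
```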
20 changes: 13 additions & 7 deletions tests/tensor_search/test_on_start_script.py
@@ -3,6 +3,7 @@
from unittest import mock

from marqo.api import exceptions, configs
from marqo.api.exceptions import StartupSanityCheckError
from marqo.tensor_search import enums
from marqo.tensor_search import on_start_script
from tests.marqo_test import MarqoTestCase
@@ -277,10 +278,15 @@ def test_boostrap_failure_should_raise_error(self, mock_config):

self.assertTrue('some error' in str(context.exception))








def test_missing_punkt_downloaded(self):
"""A test to ensure that the script will attempt to download the punkt_tab
tokenizer if it is not found"""
with mock.patch("marqo.tensor_search.on_start_script.nltk.data.find") as mock_find, \
mock.patch("marqo.tensor_search.on_start_script.nltk.download") as mock_nltk_download:
# Mock find to always raise LookupError (tokenizer never found)
mock_find.side_effect = LookupError()

checker = on_start_script.CheckNLTKTokenizers()
with self.assertRaises(StartupSanityCheckError):
checker.run()
mock_nltk_download.assert_any_call("punkt_tab")
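A hypothetical companion test (not part of this commit) sketching the success path, where the tokenizer is already present and no download should be attempted:

```python
# Hypothetical sketch following the same mocking pattern as the test above.
def test_punkt_already_present_skips_download(self):
    with mock.patch("marqo.tensor_search.on_start_script.nltk.data.find") as mock_find, \
            mock.patch("marqo.tensor_search.on_start_script.nltk.download") as mock_nltk_download:
        mock_find.return_value = "tokenizers/punkt_tab"  # find succeeds immediately

        on_start_script.CheckNLTKTokenizers().run()  # should not raise
        mock_nltk_download.assert_not_called()
```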
