
Need the ability to blacklist/whitelist characters #988

Open · Tracked by #1752
SlappyAUS opened this issue Jul 16, 2022 · Discussed in #888 · 6 comments

Labels: duplicate · framework: pytorch · framework: tensorflow · module: models · topic: text recognition · type: new feature
Milestone: 1.0.0

Comments

@SlappyAUS

Discussed in #888

Originally posted by Xargonus April 7, 2022
Hello, is there currently a way to blacklist or whitelist characters used by the text recognition model?

frgfm added the duplicate label Jul 20, 2022
frgfm (Collaborator) commented Jul 20, 2022

Thanks for opening the issue! Let's keep the discussion focused in #888 to avoid duplicates :)

@dchaplinsky

That's a great one too. Again, the callback idea might shine here, since you could boost/deboost some characters instead of completely blacklisting them. For example, I often see 1 recognized as l, Q as 0/O, or : as 0. I'd like to avoid a complete blacklist, but if I could, for example, boost and prioritize digits over letters, or : over ., that would solve the problem.
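
For illustration, a rough sketch of that kind of boosting, assuming a per-character bias added to the logits before decoding (the toy vocab and bias values below are made up):

import torch

vocab = "l1.:"  # toy vocab: easily-confused letters/punctuation vs. the characters we want to favour
boost = torch.tensor([0.0, 1.0, 0.0, 1.0])  # add +1 to the logits of "1" and ":"

logits = torch.tensor([[2.0, 1.8, 0.5, 0.2]])  # near-tie between "l" and "1"
boosted = logits + boost  # soft preference instead of a hard blacklist
print(vocab[boosted.argmax(-1).item()])  # "1" now wins the near-tie; a plain argmax would pick "l"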

felixdittrich92 added this to the 1.0.0 milestone Feb 9, 2024
felixdittrich92 self-assigned this Feb 9, 2024
felixdittrich92 added the module: models, framework: pytorch, framework: tensorflow, topic: text recognition, and type: new feature labels Feb 9, 2024
felixdittrich92 (Contributor) commented Feb 13, 2024

@odulcy-mindee @dchaplinsky

A short example (with master) of how we could achieve this:

from typing import List, Tuple

import torch


class MASTERPostProcessor(_MASTERPostProcessor):
    """Post processor for MASTER architectures"""

    def __call__(
        self,
        logits: torch.Tensor,
        blacklist: List[str] = ["a", "r", "g"],  # DUMMY :)
    ) -> List[Tuple[str, float]]:
        batch_size, max_seq_length, vocab_size = logits.size()

        # Find the vocab indices of the blacklisted characters
        blacklist_indices = [self._embedding.index(char) for char in blacklist if char in self._embedding]

        # Adjust logits for blacklisted characters: -inf means they can never win the argmax
        if blacklist_indices:
            logits[..., blacklist_indices] = float("-inf")

        # Compute pred with argmax for attention models
        out_idxs = logits.argmax(-1)
        # N x L
        probs = torch.gather(torch.softmax(logits, -1), -1, out_idxs.unsqueeze(-1)).squeeze(-1)
        # Take the minimum confidence of the sequence
        probs = probs.min(dim=1).values.detach().cpu()

        # Manual decoding
        word_values = [
            "".join(self._embedding[idx] for idx in encoded_seq).split("<eos>")[0]
            for encoded_seq in out_idxs.cpu().numpy()
        ]

        return list(zip(word_values, probs.numpy().clip(0, 1).tolist()))

This would take the next "char" with the highest probability. @SlappyAUS is this what you had in mind?
Or full removal?

For example:

blacklist = ["e", "l"]
normal_out = "hello"
blacklisted_out = "ho"
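
For comparison, a tiny sketch of the full-removal interpretation, which simply strips blacklisted characters out of the decoded word rather than falling back to the next most likely character (strip_blacklisted is a hypothetical helper, not doctr API):

from typing import Set

def strip_blacklisted(word: str, blacklist: Set[str]) -> str:
    # Drop blacklisted characters from the decoded word entirely
    return "".join(char for char in word if char not in blacklist)

print(strip_blacklisted("hello", {"e", "l"}))  # "ho"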

@dchaplinsky

What if we make it a bit broader and instead provide an option to multiply the logits by some weights?
For example, l, 1, and I can often be wrongly identified, and if I expect mostly digits, I can boost digits to break the almost-tie situations without killing all the rest of the characters.
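
A rough sketch of that weighting idea, applied here to the softmax probabilities rather than the raw logits (logits can be negative, so multiplying them directly can behave oddly); the vocab and weights below are made up:

import torch

vocab = "l1I2"
weights = torch.tensor([1.0, 2.0, 1.0, 2.0])  # boost digits over the easily-confused letters

logits = torch.tensor([[1.0, 0.9, 0.8, 0.1]])  # "l" narrowly beats "1" on raw logits
probs = torch.softmax(logits, dim=-1) * weights  # re-weight instead of blacklisting
print(vocab[probs.argmax(-1).item()])  # "1" wins once digits are upweighted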

@felixdittrich92 (Contributor)

@dchaplinsky Hmm, this would need some mapping to "close" characters? 🤔

frgfm (Collaborator) commented Feb 23, 2024

Hey everyone 👋

The use case here is to help out the text recognition part when you have more info about a sub-vocab. So we now need to assess whether that's worth addressing (I think it would be useful), what the API would be, and at which step this should take place.

My two cents:

  1. This will be quite useful to better leverage wide-vocab model checkpoints without retraining them. We should work on this.
  2. For the API, the issue with passing a whitelist/blacklist of characters is that you need both, and then you'd need an arbitrary order of priority. With a minimal snippet, we can easily convert that to a weight vector which will be used by the model.
  3. For simplicity, I think this should be passed to the model (or post-processor) and stored as an instance attribute (a Tensor). At inference time, it should be used in the call of the post-processor: either here https://github.com/mindee/doctr/blob/main/doctr/models/recognition/crnn/pytorch.py#L224 if at model level, or here https://github.com/mindee/doctr/blob/main/doctr/models/recognition/crnn/pytorch.py#L75 if at post-processor level.

Here is my suggested design for blacklist:

import torch
from doctr.models import crnn_vgg16_bn
from doctr.datasets import VOCABS

vocab = VOCABS["french"]  # the character set of the checkpoint (adjust to the model used)
blacklisted_chars = {str(num) for num in range(10)}
# Set the mask (1 = keep the character, 0 = exclude it)
vocab_mask = torch.tensor([0 if char in blacklisted_chars else 1 for char in vocab], dtype=torch.float32)
model = crnn_vgg16_bn(pretrained=True, vocab_mask=vocab_mask)

input_tensor = torch.rand(1, 3, 32, 128)
out = model(input_tensor)

and whitelist:

whitelisted_chars = {str(num) for num in range(10)}
vocab_mask = torch.tensor([1 if char in whitelisted_chars else 0 for char in vocab], dtype=torch.float32)
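
For completeness, a minimal sketch of how the post-processor could consume such a mask before decoding, assuming it is stored on the instance as a tensor (this is not existing doctr code):

import torch

def masked_decode(logits: torch.Tensor, vocab_mask: torch.Tensor) -> torch.Tensor:
    # Characters with mask == 0 are set to -inf so the argmax can never pick them;
    # whitelisted characters keep their original logits
    masked_logits = logits.masked_fill(vocab_mask == 0, float("-inf"))
    return masked_logits.argmax(-1)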

What do you think?
