Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix RDKit SMILES parsing #23

Merged
merged 1 commit into from
Jan 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions safe/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def __init__(
self,
slicer: Optional[Union[str, List[str], Callable]] = "brics",
require_hs: Optional[bool] = None,
use_original_opener_for_attach: bool = True,
):
"""Constructor for the SAFE converter

Expand All @@ -72,6 +73,8 @@ def __init__(
or a custom callable that returns the bond ids that can be sliced.
require_hs: whether the slicing algorithm require the molecule to have hydrogen explictly added.
`attach` slicer requires adding hydrogens.
use_original_opener_for_attach: whether to use the original branch opener digit when adding back
mapping number to attachment points, or use simple enumeration.

"""
self.slicer = slicer
Expand All @@ -82,6 +85,7 @@ def __init__(
if isinstance(self.slicer, (list, tuple)):
self.slicer = [dm.from_smarts(x) for x in self.slicer]
self.require_hs = require_hs or (slicer == "attach")
self.use_original_opener_for_attach = use_original_opener_for_attach

@staticmethod
def randomize(mol: dm.Mol, rng: Optional[int] = None):
Expand Down Expand Up @@ -124,6 +128,7 @@ def _ensure_valid(self, inp: str):

Args:
inp: input SAFE string

"""
missing_tokens = [inp]
branch_numbers = self._find_branch_number(inp)
Expand All @@ -133,8 +138,11 @@ def _ensure_valid(self, inp: str):
for i, (bnum, bcount) in enumerate(branch_numbers.items()):
if bcount % 2 != 0:
bnum_str = str(bnum) if bnum < 10 else f"%{bnum}"
missing_tokens.append(f"[*:{i+1}]{bnum_str}")

_tk = f"[*:{i+1}]{bnum_str}"
if self.use_original_opener_for_attach:
bnum_digit = bnum_str.strip("%") # strip out the % sign
_tk = f"[*:{bnum_digit}]{bnum_str}"
missing_tokens.append(_tk)
return ".".join(missing_tokens)

def decoder(
Expand Down Expand Up @@ -221,6 +229,7 @@ def encoder(
seed: Optional[int] = None,
constraints: Optional[List[dm.Mol]] = None,
allow_empty: bool = False,
rdkit_safe: bool = True,
):
"""Convert input smiles to SAFE representation

Expand All @@ -235,6 +244,7 @@ def encoder(
constraints: List of molecules or pattern to preserve during the SAFE construction. Any bond slicing would
happen outside of a substructure matching one of the patterns.
allow_empty: whether to allow the slicing algorithm to return empty bonds
rdkit_safe: whether to apply rdkit-safe digit standardization to the output SAFE string.
"""
rng = None
if randomize:
Expand Down Expand Up @@ -327,9 +337,15 @@ def encoder(
attach_regexp = re.compile(r"(" + re.escape(attach) + r")")
scaffold_str = attach_regexp.sub(val, scaffold_str)
starting_num += 1
# now we need to remove all the parenthesis around difig only number
# now we need to remove all the parenthesis around digit only number
wrong_attach = re.compile(r"\(([\%\d]*)\)")
return wrong_attach.sub(r"\g<1>", scaffold_str)
scaffold_str = wrong_attach.sub(r"\g<1>", scaffold_str)
maclandrol marked this conversation as resolved.
Show resolved Hide resolved
# furthermore, we autoapply rdkit-compatible digit standardization.
if rdkit_safe:
pattern = r"\(([=-@#]?)(%?\d{1,2})\)"
replacement = r"\g<1>\g<2>"
scaffold_str = re.sub(pattern, replacement, scaffold_str)
return scaffold_str


def encode(
Expand Down
26 changes: 26 additions & 0 deletions tests/test_safe.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,29 @@ def test_safe_decoder():
decoded_fragments = [safe.decode(x, fix=True) for x in fragments]
assert [dm.to_mol(x) for x in fragments] == [None] * len(fragments)
assert all(x is not None for x in decoded_fragments)


def test_rdkit_smiles_parser_issues():
# see https://github.com/datamol-io/safe/issues/22
input_sm = r"C(=C/c1ccccc1)\CCc1ccccc1"
slicer = "brics"
safe_obj = safe.SAFEConverter(slicer=slicer, require_hs=False)
with dm.without_rdkit_log():
failing_encoded = safe_obj.encoder(
input_sm,
canonical=True,
randomize=False,
rdkit_safe=False,
)
working_encoded = safe_obj.encoder(
input_sm,
canonical=True,
randomize=False,
rdkit_safe=True,
)
working_decoded = safe.decode(working_encoded)
working_no_stero = dm.remove_stereochemistry(dm.to_mol(input_sm))
input_mol = dm.remove_stereochemistry(dm.to_mol(working_decoded))
assert safe.decode(failing_encoded) is None
assert working_decoded is not None
assert dm.same_mol(working_no_stero, input_mol)