Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
maclandrol committed Jan 2, 2024
1 parent 87c9c22 commit d0892ed
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 4 deletions.
24 changes: 20 additions & 4 deletions safe/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def __init__(
self,
slicer: Optional[Union[str, List[str], Callable]] = "brics",
require_hs: Optional[bool] = None,
use_original_opener_for_attach: bool = True,
):
"""Constructor for the SAFE converter
Expand All @@ -72,6 +73,8 @@ def __init__(
or a custom callable that returns the bond ids that can be sliced.
require_hs: whether the slicing algorithm require the molecule to have hydrogen explictly added.
`attach` slicer requires adding hydrogens.
use_original_opener_for_attach: whether to use the original branch opener digit when adding back
mapping number to attachment points, or use simple enumeration.
"""
self.slicer = slicer
Expand All @@ -82,6 +85,7 @@ def __init__(
if isinstance(self.slicer, (list, tuple)):
self.slicer = [dm.from_smarts(x) for x in self.slicer]
self.require_hs = require_hs or (slicer == "attach")
self.use_original_opener_for_attach = use_original_opener_for_attach

@staticmethod
def randomize(mol: dm.Mol, rng: Optional[int] = None):
Expand Down Expand Up @@ -124,6 +128,7 @@ def _ensure_valid(self, inp: str):
Args:
inp: input SAFE string
"""
missing_tokens = [inp]
branch_numbers = self._find_branch_number(inp)
Expand All @@ -133,8 +138,11 @@ def _ensure_valid(self, inp: str):
for i, (bnum, bcount) in enumerate(branch_numbers.items()):
if bcount % 2 != 0:
bnum_str = str(bnum) if bnum < 10 else f"%{bnum}"
missing_tokens.append(f"[*:{i+1}]{bnum_str}")

_tk = f"[*:{i+1}]{bnum_str}"
if self.use_original_opener_for_attach:
bnum_digit = bnum_str.strip("%") # strip out the % sign
_tk = f"[*:{bnum_digit}]{bnum_str}"
missing_tokens.append(_tk)
return ".".join(missing_tokens)

def decoder(
Expand Down Expand Up @@ -221,6 +229,7 @@ def encoder(
seed: Optional[int] = None,
constraints: Optional[List[dm.Mol]] = None,
allow_empty: bool = False,
rdkit_safe: bool = True,
):
"""Convert input smiles to SAFE representation
Expand All @@ -235,6 +244,7 @@ def encoder(
constraints: List of molecules or pattern to preserve during the SAFE construction. Any bond slicing would
happen outside of a substructure matching one of the patterns.
allow_empty: whether to allow the slicing algorithm to return empty bonds
rdkit_safe: whether to apply rdkit-safe digit standardization to the output SAFE string.
"""
rng = None
if randomize:
Expand Down Expand Up @@ -327,9 +337,15 @@ def encoder(
attach_regexp = re.compile(r"(" + re.escape(attach) + r")")
scaffold_str = attach_regexp.sub(val, scaffold_str)
starting_num += 1
# now we need to remove all the parenthesis around difig only number
# now we need to remove all the parenthesis around digit only number
wrong_attach = re.compile(r"\(([\%\d]*)\)")
return wrong_attach.sub(r"\g<1>", scaffold_str)
scaffold_str = wrong_attach.sub(r"\g<1>", scaffold_str)
# furthermore, we autoapply rdkit-compatible digit standardization.
if rdkit_safe:
pattern = r"\(([=-@#]?)(%?\d{1,2})\)"
replacement = r"\g<1>\g<2>"
scaffold_str = re.sub(pattern, replacement, scaffold_str)
return scaffold_str


def encode(
Expand Down
26 changes: 26 additions & 0 deletions tests/test_safe.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,29 @@ def test_safe_decoder():
decoded_fragments = [safe.decode(x, fix=True) for x in fragments]
assert [dm.to_mol(x) for x in fragments] == [None] * len(fragments)
assert all(x is not None for x in decoded_fragments)


def test_rdkit_smiles_parser_issues():
# see https://github.com/datamol-io/safe/issues/22
input_sm = r"C(=C/c1ccccc1)\CCc1ccccc1"
slicer = "brics"
safe_obj = safe.SAFEConverter(slicer=slicer, require_hs=False)
with dm.without_rdkit_log():
failing_encoded = safe_obj.encoder(
input_sm,
canonical=True,
randomize=False,
rdkit_safe=False,
)
working_encoded = safe_obj.encoder(
input_sm,
canonical=True,
randomize=False,
rdkit_safe=True,
)
working_decoded = safe.decode(working_encoded)
working_no_stero = dm.remove_stereochemistry(dm.to_mol(input_sm))
input_mol = dm.remove_stereochemistry(dm.to_mol(working_decoded))
assert safe.decode(failing_encoded) is None
assert working_decoded is not None
assert dm.same_mol(working_no_stero, input_mol)

0 comments on commit d0892ed

Please sign in to comment.