# Source code for schrodinger.application.scaffold_enumeration.linknode

'''
Implements "link node"/"repeating unit" enumeration (see ENUM-253).
'''

import collections

import rdkit.Chem

from schrodinger.utils import log

from . import common

# Module-level logger (from schrodinger.utils.log); used below to report
# malformed linknode S-groups and repeat specifications as warnings.
logger = log.get_output_logger(__name__)

#------------------------------------------------------------------------------#
#
# from https://chemaxon.com/marvin-archive/5.8.2/marvin/help/formats/mrv-doc-old.html
#
# <molecule
#    id="sg1"
#    role="SruSgroup"
#    title="name"
#    molID="m1"
#    atomRefs="a1 a2 ... "
#    correspondence="b1 b2 ... "
#    bondList="b1 b2 ... "
#    connect="hh|ht|eu ">
#
# is a "SRU S-group"; `name` is something like "1-2,5,100500"
# for `correspondence` see MDL M CRS, for `bondlist` see MDL M SBL,
# for `connect` see MDL M SCN
#
# * this module uses only role/title/atomRefs and connect (I could not
#   get the remaining attributes populated by MarvinSketch anyway)
#
# * assume that connect=="hh" implies "head-to-head" connections; any other
#   value ("ht", or "eu" meaning either/unknown) implies "head-to-tail"
#
# * require exactly two "crossing" bonds

# Record describing a single linknode ("SRU") S-group:
#   atoms   -- set of atom indices that make up the repeating unit
#   repeats -- list of possible repeat counts
#   connect -- the S-group's `connect` attribute, which selects
#              head-to-tail vs head-to-head linking of the copies
LinknodeSgroup = collections.namedtuple('LinknodeSgroup',
                                        ('atoms', 'repeats', 'connect'))

#------------------------------------------------------------------------------#


def _is_linknode_sgroup(g):
    '''
    Tells whether the S-group `g` is a linknode one, i.e. whether its
    'role' is "SruSgroup".

    :param g: MRV S-group attributes as a dictionary.
    :type g: dict

    :return: True for "SruSgroup" S-groups; False otherwise, including
        when `g` has no 'role' entry.
    :rtype: bool
    '''

    role = g.get('role')
    return role == 'SruSgroup'


#------------------------------------------------------------------------------#


def _drop_linknode_sgroups(mol):
    '''
    Strips every linknode S-group from the S-groups attached to `mol`,
    keeping the remaining S-groups in their original order.

    :param mol: Molecule whose S-groups are rewritten through
        `common.set_sgroups`.
    :type mol: rdkit.Chem.Mol
    '''

    remaining = [
        sgroup for sgroup in common.get_sgroups(mol)
        if not _is_linknode_sgroup(sgroup)
    ]
    common.set_sgroups(mol, remaining)


#------------------------------------------------------------------------------#


def _parse_repeats(text, max_variations=1 << 10):
    '''
    Parses strings like '1-4,10,100-500' into a sorted list of the
    corresponding distinct positive integers.

    Chunks that are not an integer or an 'lo-hi' integer range are logged
    and skipped; non-positive values are dropped. At most `max_variations`
    distinct values are collected (the first ones in textual order). As
    soon as a new value would exceed that limit, a single warning is logged
    and the rest of `text` is ignored.

    :param text: String to parse.
    :type text: str

    :param max_variations: Greatest acceptable number of distinct integers.
    :type max_variations: int

    :return: Sorted list of distinct positive integers.
    :rtype: list(int)
    '''

    repeats = set()

    for chunk in text.split(','):
        try:
            toks = [int(tok) for tok in chunk.split('-')]
        except ValueError:
            logger.warning('unexpected number of repeats: %s', text)
            continue

        if len(toks) == 1:
            lo = hi = toks[0]
        elif len(toks) == 2:
            lo, hi = toks
        else:
            logger.warning('unexpected format of repeats: %s', text)
            continue

        # non-positive repeat counts are meaningless, start at 1
        for r in range(max(lo, 1), hi + 1):
            if r in repeats:
                continue
            if len(repeats) >= max_variations:
                # the set can only grow, so nothing later could be added
                # either: warn once and stop scanning the rest of `text`
                logger.warning('ignoring all but %d repeat variations',
                               max_variations)
                return sorted(repeats)
            repeats.add(r)

    return sorted(repeats)


#------------------------------------------------------------------------------#


def _collect_linknodes(mol):
    '''
    Builds `LinknodeSgroup` records from the "SruSgroup" S-groups of `mol`.

    S-groups lacking any of 'atomRefs', 'connect' or 'title', or that refer
    to atom ids missing from the id map, are logged and skipped. S-groups
    whose atom set or parsed repeat list ends up empty are dropped without
    a message here.

    :param mol: Molecule to be considered.
    :type mol: rdkit.Chem.Mol

    :return: List of "linknodes".
    :rtype: list(LinknodeSgroup)
    '''

    # presumably maps CML atom ids to atom indices -- see common.get_atom_id_map
    id_to_index = common.get_atom_id_map(mol)

    linknodes = []

    for sgroup in common.get_sgroups(mol):
        if not _is_linknode_sgroup(sgroup):
            continue

        try:
            refs = set(sgroup['atomRefs'].split())
            connect = sgroup['connect']
            title = sgroup['title']
        except KeyError:
            logger.warning('incomplete linknode S-group: %s', sgroup)
            continue

        try:
            atoms = {id_to_index[ref] for ref in refs}
        except KeyError:
            logger.warning('undefined atom id(s) in linknode S-group: %s',
                           sgroup)
            continue

        repeats = _parse_repeats(title)
        if not (atoms and repeats):
            continue

        linknodes.append(
            LinknodeSgroup(atoms=atoms, repeats=repeats, connect=connect))

    return linknodes


#------------------------------------------------------------------------------#


def _is_connected_subgraph(mol, atom_indices_set):
    '''
    Determines whether the atoms in `atom_indices_set` form a connected
    subgraph of `mol`, i.e. whether every atom in the set can be reached
    from any other one through bonds whose both ends lie in the set.

    A single atom is connected. An empty set is reported as connected as
    well; callers check for emptiness separately.

    :param mol: Molecule.
    :type mol: rdkit.Chem.Mol

    :param atom_indices_set: Set of atom indices in `mol` (assumed valid).
    :type atom_indices_set: set(int)

    :return: Whether the induced subgraph is connected.
    :rtype: bool
    '''

    if not atom_indices_set:
        return True

    # Walk the induced subgraph from an arbitrary atom. It is connected
    # iff the walk reaches every atom of the set. Merely checking that
    # each atom has *some* neighbour in the set is not enough: two
    # disjoint bonded pairs would pass such a test.
    start = next(iter(atom_indices_set))
    seen = {start}
    pending = [start]
    while pending:
        atom = mol.GetAtomWithIdx(pending.pop())
        for neighbor in atom.GetNeighbors():
            j = neighbor.GetIdx()
            if j in atom_indices_set and j not in seen:
                seen.add(j)
                pending.append(j)

    return len(seen) == len(atom_indices_set)


#------------------------------------------------------------------------------#


def _get_crossing_bonds(mol, atom_indices_set):
    '''
    Collects the bonds of `mol` that have exactly one endpoint in
    `atom_indices_set`, in the order `mol.GetBonds()` yields them.

    :param mol: Molecule.
    :type mol: rdkit.Chem.Mol

    :param atom_indices_set: Set of atom indices in `mol`.
    :type atom_indices_set: set(int)

    :return: List of bonds.
    :rtype: list(rdkit.Chem.Bond)
    '''

    def _crosses(bond):
        # True when one end is inside the set and the other is outside
        return ((bond.GetBeginAtomIdx() in atom_indices_set) !=
                (bond.GetEndAtomIdx() in atom_indices_set))

    return [bond for bond in mol.GetBonds() if _crosses(bond)]


#------------------------------------------------------------------------------#


def _validate_linknodes(mol, linknodes):
    '''
    Checks `linknodes` against `mol`, stopping at the first problem found.

    Each linknode, in list order, must have:
    - a non-empty atom set;
    - atom indices within ``[0, mol.GetNumAtoms())``;
    - atoms that form a connected subgraph;
    - exactly two crossing bonds;
    - only positive repeat counts.

    Only after every linknode passes those checks are they compared
    pairwise for shared atoms.

    :param mol: Molecule.
    :type mol: rdkit.Chem.Mol

    :param linknodes: List of "linknodes" to validate.
    :type linknodes: list(LinknodeSgroup)

    :return: ``(True, '')`` when valid, otherwise ``(False, <reason>)``.
    :rtype: (bool, str)
    '''

    atom_count = mol.GetNumAtoms()

    for linknode in linknodes:
        unit = linknode.atoms
        if not unit:
            return False, 'no atoms'
        if any(not 0 <= atom_idx < atom_count for atom_idx in unit):
            return False, 'invalid atom indices'
        if not _is_connected_subgraph(mol, unit):
            return False, 'atoms do not form connected subgraph'
        if len(_get_crossing_bonds(mol, unit)) != 2:
            return False, 'unexpected number of crossing bonds'
        if not all(count > 0 for count in linknode.repeats):
            return False, 'non-positive repeat count(s)'

    for pos, first in enumerate(linknodes):
        for second in linknodes[pos + 1:]:
            if first.atoms & second.atoms:
                return False, 'link nodes must not overlap'

    return True, ''


#------------------------------------------------------------------------------#


def _clone_atom(src):
    '''
    Makes a new atom object from `src` and removes the CML atom-id
    property (`common.CML_ID_PROP`) from the copy. Presumably this keeps
    the copy from sharing the original's CML id; confirm with the CML
    writer.

    :param src: Atom to be copied.
    :type src: rdkit.Chem.Atom

    :return: Copy of `src` without the CML id property.
    :rtype: rdkit.Chem.Atom
    '''

    duplicate = rdkit.Chem.Atom(src)
    duplicate.ClearProp(common.CML_ID_PROP)
    return duplicate


#------------------------------------------------------------------------------#


def _clone_bond(rwmol, bond_idx, beg_atom_idx, end_atom_idx):
    '''
    Adds a bond between `beg_atom_idx` and `end_atom_idx` to `rwmol` and
    turns it into a copy of the existing bond with index `bond_idx`.
    The intent is that the new bond is otherwise identical to the original.
    This relies on `ReplaceBond` keeping the endpoints of the bond it
    overwrites.

    :param rwmol: Molecule to be modified.
    :type rwmol: rdkit.Chem.RWMol

    :param bond_idx: Index of the bond to be copied.
    :type bond_idx: int

    :param beg_atom_idx: Index of the "begin" atom for the new bond.
    :type beg_atom_idx: int

    :param end_atom_idx: Index of the "end" atom for the new bond.
    :type end_atom_idx: int
    '''

    template = rwmol.GetBondWithIdx(bond_idx)
    # AddBond() returns the updated bond count, so the bond it just created
    # lives at index (count - 1); overwrite it with the template bond
    new_bond_idx = rwmol.AddBond(beg_atom_idx, end_atom_idx) - 1
    rwmol.ReplaceBond(new_bond_idx, template)


#------------------------------------------------------------------------------#


def _replicate(rwmol, atom_indices, head_to_head=False):
    '''
    Replicate atoms with indices from `atom_indices` once. Assumes that
    the "repeating unit" defined by the `atom_indices` forms connected
    subgraph in `mol` and had exactly two "crossing" bonds.

    The copy is spliced in right after the original unit. New
    [tail]-[copy] and [copy]-Y bonds are added, and the original [tail]-Y
    crossing bond is removed (see the diagram below). Atoms are only
    added via AddAtom and only a bond is removed. Presumably this leaves
    the indices of pre-existing atoms unchanged (RDKit appends new atoms),
    which repeated calls with the same `atom_indices` rely on.

    :param rwmol: Molecule to be modified (in place).
    :type rwmol: rdkit.Chem.RWMol

    :param atom_indices: Indices of the atoms to be "replicated".
    :type atom_indices: set(int)

    :param head_to_head: Replicate "head-to-head" instead of "head-to-tail",
        i.e. attach the copy in reversed orientation.
    :type head_to_head: bool
    '''

    # atoms

    o2n = dict()  # "orig" -> "new" atom index map
    for idx in atom_indices:
        # AddAtom() returns the index of the freshly added clone
        o2n[idx] = rwmol.AddAtom(_clone_atom(rwmol.GetAtomWithIdx(idx)))

    # bonds

    BondInfo = collections.namedtuple('BondInfo',
                                      ['idx', 'beg_atom_idx', 'end_atom_idx'])

    bonds_to_copy = []  # beg_atom_idx and end_atom_idx in atom_indices
    crossing_bonds = []  # beg_atom_idx in atom_indices, end_atom_idx is not

    # Classify existing bonds. Crossing bonds are normalized so that
    # beg_atom_idx is always the endpoint inside the unit. The order of
    # crossing_bonds (and therefore which end counts as "head" and which
    # as "tail") follows rwmol's bond iteration order.
    for bond in rwmol.GetBonds():
        beg_atom_idx = bond.GetBeginAtomIdx()
        beg_atom_cloned = beg_atom_idx in o2n
        end_atom_idx = bond.GetEndAtomIdx()
        end_atom_cloned = end_atom_idx in o2n
        if beg_atom_cloned and end_atom_cloned:
            bonds_to_copy.append(
                BondInfo(bond.GetIdx(), beg_atom_idx, end_atom_idx))
        elif beg_atom_cloned and not end_atom_cloned:
            crossing_bonds.append(
                BondInfo(bond.GetIdx(), beg_atom_idx, end_atom_idx))
        elif end_atom_cloned and not beg_atom_cloned:
            crossing_bonds.append(
                BondInfo(bond.GetIdx(), end_atom_idx, beg_atom_idx))

    # reproduce the unit's internal bonds between the cloned atoms
    for bond_info in bonds_to_copy:
        _clone_bond(rwmol, bond_info.idx, o2n[bond_info.beg_atom_idx],
                    o2n[bond_info.end_atom_idx])

    # head/tail atoms of the "replica"

    new_ht_atoms = [o2n[bond_info.beg_atom_idx] for bond_info in crossing_bonds]
    if head_to_head:
        # attach the replica reversed: its tail ends up next to the
        # original tail and its head takes over the bond to Y
        new_ht_atoms.reverse()

    #
    # [new_head]...[new_tail] \
    #                          => X-[head]...[tail]-[new_head]...[new_tail]-Y
    #   X-[head]...[tail]-Y   /
    #
    # here: X-[head] is crossing_bonds[0] (with beg_atom_idx == head),
    #       [tail]-Y is crossing_bonds[1] (with beg_atom_idx == tail)
    #
    # (with head_to_head the new_head/new_tail labels are swapped)
    #
    # NOTE(review): only crossing_bonds[0] and [1] are used. With fewer than
    # two crossing bonds the indexing below raises IndexError; callers are
    # expected to have validated exactly two.
    #

    # create [tail]-[new_head]

    _clone_bond(rwmol, crossing_bonds[1].idx, crossing_bonds[1].beg_atom_idx,
                new_ht_atoms[0])

    # create [new_tail]-Y
    # NOTE(review): this copies the properties of crossing_bonds[0]
    # (X-[head]), not of the [tail]-Y bond it replaces, while [tail]-[new_head]
    # above copies [tail]-Y. The results differ only when the two crossing
    # bonds differ (e.g. in bond order) -- confirm this is intended.

    _clone_bond(rwmol, crossing_bonds[0].idx, new_ht_atoms[1],
                crossing_bonds[1].end_atom_idx)

    # drop [tail]-Y

    rwmol.RemoveBond(crossing_bonds[1].beg_atom_idx,
                     crossing_bonds[1].end_atom_idx)


#------------------------------------------------------------------------------#


def _apply_linknode(rwmol, atom_indices, num_repeats, head_to_head=False):
    '''
    Expands the link node made of `atom_indices` so that the unit occurs
    `num_repeats` times, i.e. inserts `num_repeats - 1` copies into `rwmol`.
    Nothing happens when `num_repeats` is 1 or less.

    :param rwmol: R/W molecule (modified in place).
    :type rwmol: rdkit.Chem.RWMol

    :param atom_indices: Set of atom indices in `rwmol` that define the
        "linknode" (assumed to be validated: form connected subgraph in
        `rwmol` and have exactly two crossing bonds).
    :type atom_indices: set(int)

    :param num_repeats: Number of repeats (positive integer).
    :type num_repeats: int

    :param head_to_head: Connect "head-to-head" vs "head-to-tail" otherwise.
    :type head_to_head: bool
    '''

    # Every pass splices one more copy in right after the original unit.
    # For head-to-head linking, copies are flipped on alternating passes.
    # The parity is chosen so that the copy inserted last (the one that
    # ends up adjacent to the original) is flipped.
    for remaining in range(num_repeats - 1, 0, -1):
        flip = head_to_head and remaining % 2 == 1
        _replicate(rwmol, atom_indices, flip)


#------------------------------------------------------------------------------#


class LinknodeEnumerable(common.EnumerableMixin):
    '''
    Enumerable view of a molecule with "link node" (repeating unit)
    S-groups. Each link node offers a list of repeat counts. A realization
    picks one repeat count per link node and expands the units
    accordingly.
    '''

    def __init__(self, mol, linknodes=None):
        '''
        :param mol: RDKit molecule.
        :type mol: rdkit.Chem.Mol

        :param linknodes: List of link nodes; collected from the S-groups
            of `mol` when not given.
        :type linknodes: list(LinknodeSgroup)

        :raise ValueError: If the link nodes fail validation against `mol`.
        '''

        if linknodes is None:
            linknodes = _collect_linknodes(mol)

        valid, msg = _validate_linknodes(mol, linknodes)
        if not valid:
            raise ValueError('LinknodeEnumerable: ' + msg)

        self.linknodes = linknodes
        self.mol = mol

    def getExtents(self):
        '''
        :return: Number of repeat-count choices for each link node.
        :rtype: list(int)
        '''

        return [len(n.repeats) for n in self.linknodes]

    def getRealization(self, idx):
        '''
        :param idx: "Index" of a realization: one position into
            `repeats` for each link node.
        :type idx: iterable over int

        :return: RDKit molecule without "link nodes". `self.mol` itself
            is returned when there are no link nodes.
        :rtype: rdkit.Chem.Mol
        '''

        if not self.linknodes:
            return self.mol

        rwmol = rdkit.Chem.RWMol(self.mol)
        _drop_linknode_sgroups(rwmol)
        for (i, node) in zip(idx, self.linknodes):
            # Only an explicit "hh" requests head-to-head linking. "ht",
            # "eu" (either/unknown) and any other value mean head-to-tail,
            # as documented at the top of this module. The previous
            # `connect != '(ht)'` test made "ht" and "eu" head-to-head.
            _apply_linknode(rwmol,
                            node.atoms,
                            node.repeats[i],
                            head_to_head=(node.connect == 'hh'))
        return rwmol.GetMol()
#------------------------------------------------------------------------------#