# Source code for schrodinger.structutils.shuffle

import contextlib
import random
import uuid

from schrodinger import structure
from schrodinger.infra import structure as infrastructure
from schrodinger.utils import fileutils
from schrodinger.utils.fileutils import MAESTRO
from schrodinger.utils.fileutils import SD
from schrodinger.utils.fileutils import get_structure_file_format


class ShuffleException(Exception):
    """
    Error raised by the shuffle functions in this module, e.g. for an
    unsupported file format or a wrapped I/O failure.
    """
def _check_files(files):
    """
    Verify that every file name in `files` is a Maestro or SD file.

    :param files: file names to validate
    :type files: iterable(str)

    :raise ValueError: for the first file that is neither Maestro nor SD
    """
    for path in files:
        # Accept as soon as either format test passes; anything else is
        # rejected immediately.
        if fileutils.is_maestro_file(path) or fileutils.is_sd_file(path):
            continue
        raise ValueError("File: %s must be in Maestro or SD format" % path)
def split_and_shuffle(infile, batch_size=1e9, random_seed=None):
    """
    Split structures in infile to batches. Shuffle structures in each batch.
    Write each batch to a temporary file.

    The temporary files are named ``batch-<index>-<uuid>.<mae|sdf>`` (a
    relative name, no directory component). If anything fails while
    splitting, the batch files created so far are removed before the error
    propagates.

    :param infile: input structure file.
    :type infile: MAESTRO or SD file

    :param batch_size: size of temporary split sub files in bytes.
    :type batch_size: int

    :param random_seed: random seed number for shuffling the ligands
    :type random_seed: int or None

    :return: list of file names, list of number of structures in each file.
    :rtype: list(str), list(int)

    :raise ShuffleException: if infile is neither Maestro nor SD format
    """
    rand = random.Random(random_seed)
    fileformat = fileutils.get_structure_file_format(infile)
    if fileformat not in (MAESTRO, SD):
        msg = f"File: {infile} must be in Maestro or SD format"
        raise ShuffleException(msg)

    file_suffix = "mae" if fileformat == MAESTRO else "sdf"
    # One id per call keeps batch names from different calls apart.
    unique_id = uuid.uuid4()

    batch_files = []
    num_st_in_batch = []
    st_batch = []
    current_batch_size = 0

    def write_batch():
        """Shuffle the pending structures and flush them to a new file."""
        nonlocal current_batch_size
        rand.shuffle(st_batch)
        fname = f'batch-{len(batch_files)}-{unique_id}.{file_suffix}'
        # Register the file *before* writing so that a partially written
        # batch is also removed by the cleanup below.
        batch_files.append(fname)
        num_st_in_batch.append(len(st_batch))
        with infrastructure.TextBlockWriter(fname) as writer:
            for st in st_batch:
                writer.append(st)
        st_batch.clear()
        current_batch_size = 0

    try:
        with infrastructure.TextBlockReader(infile) as reader:
            for st_text in reader:
                # Batch size is measured as the length of the text blocks.
                current_batch_size += len(st_text)
                st_batch.append(st_text)
                if current_batch_size > batch_size:
                    write_batch()
        # Flush whatever is left after the last full batch.
        if st_batch:
            write_batch()
    except BaseException:
        # The caller never receives the file list on failure, so the
        # temporary files must be removed here or they would be leaked.
        # NOTE(review): relies on force_remove tolerating a file that was
        # registered but never created - confirm.
        fileutils.force_remove(*batch_files)
        raise

    return batch_files, num_st_in_batch
def shuffle_merge(batch_files,
                  num_st_in_batch,
                  outfile,
                  max_structs,
                  random_seed=None):
    """
    Merge structures in temporary files by picking a structure from each
    temporary file at random (weighed by the number of structures in the
    file).

    When the batch files and the output file have the same format the
    structures are copied as text blocks; otherwise they are read and
    written as structure objects. All readers and the writer are closed
    when the function returns or raises. `num_st_in_batch` is not modified.

    :param batch_files: list of file names
    :type batch_files: list(str)

    :param num_st_in_batch: list of number of structures in each file.
    :type num_st_in_batch: list(int)

    :param outfile: output file name
    :type outfile: str

    :param max_structs: max. number of structures to write (negative means
        all)
    :type max_structs: int

    :param random_seed: random seed number for shuffling the ligands.
    :type random_seed: int or None
    """
    rand = random.Random(random_seed)

    # Work on a copy: the counts double as sampling weights and are
    # decremented as structures are consumed.
    remaining = list(num_st_in_batch)
    limit = sum(remaining)
    if max_structs >= 0:
        # Decide the cap up front so that max_structs == 0 writes nothing.
        limit = min(limit, max_structs)

    batch_indexes = range(len(batch_files))

    # With no batch files there is no input format to compare against.
    use_text = bool(batch_files) and (
        get_structure_file_format(batch_files[0]) ==
        get_structure_file_format(outfile))
    if use_text:
        open_reader = infrastructure.TextBlockReader
        open_writer = infrastructure.TextBlockWriter
    else:
        open_reader = structure.StructureReader
        open_writer = structure.StructureWriter

    with contextlib.ExitStack() as stack:
        # Keep using the constructed objects for next()/append(); the stack
        # is only responsible for closing them.
        readers = []
        for filename in batch_files:
            reader = open_reader(filename)
            stack.enter_context(reader)
            readers.append(reader)
        writer = open_writer(outfile)
        stack.enter_context(writer)

        num_dumped_st = 0
        while num_dumped_st < limit:
            # A batch whose count has reached zero has weight zero and is
            # never selected again.
            idx = rand.choices(batch_indexes, remaining)[0]
            st = next(readers[idx])
            remaining[idx] -= 1
            writer.append(st)
            num_dumped_st += 1
def shuffle_structs(input_file,
                    output_file,
                    max_structs=-1,
                    batch_size=1e9,
                    random_seed=None):
    """
    Structures in input_file are shuffled and written to the output file.

    The input is first split into shuffled temporary batch files, which are
    then merged at random into the output file. The temporary files returned
    by the split step are removed afterwards, whether or not the merge
    succeeded.

    :param input_file: input filename
    :type input_file: str

    :param output_file: output filename
    :type output_file: str

    :param max_structs: max. number of structures to write (negative means
        all)
    :type max_structs: int

    :param batch_size: size of temporary split sub files in bytes.
    :type batch_size: int

    :param random_seed: random seed number for shuffling the ligands
    :type random_seed: int or None

    :raise ShuffleException: if either file is not in Maestro or SD format,
        or if an I/O error occurs while splitting or merging
    """
    try:
        _check_files([input_file, output_file])
    except ValueError as err:
        raise ShuffleException(err) from err

    # Stays empty unless the split step returns, so the cleanup in
    # ``finally`` is then a call with no arguments.
    batch_files = []
    try:
        batch_files, num_st_in_batch = split_and_shuffle(
            input_file, batch_size, random_seed)
        shuffle_merge(batch_files, num_st_in_batch, output_file, max_structs,
                      random_seed)
    except OSError as err:
        # IOError is an alias of OSError in Python 3.
        raise ShuffleException(err) from err
    finally:
        fileutils.force_remove(*batch_files)