"""
Create and filter 3D histogram with gaussian function.
Copyright Schrodinger, LLC. All rights reserved.
"""
# Contributors: Byungchan Kim
import math
import optparse
import sys
from past.utils import old_div
import numpy
from schrodinger import structure
class GaussianFilter:
    """
    Accumulate 3D points into a histogram and smooth it with a Gaussian.

    Points are binned on a regular grid whose spacing equals `sigma`; the
    smoothing kernel is evaluated at integer grid offsets, so its standard
    deviation is one grid cell. The histogram axes are ordered Z, Y, X.

    Typical use is one or more calls to `apply` to accumulate coordinates,
    followed by `write` to smooth the histogram and save it to a text file.

    :ivar sigma: grid spacing.
    :ivar span: kernel half-width, in grid cells.
    :ivar ndim: kernel edge length in grid cells (``2 * span + 1``).
    :ivar filter: normalized kernel, shape ``(ndim, ndim, ndim)``.
    :ivar histogram: accumulated (and, after `write`, smoothed) histogram, or
        None until `apply` has been called.
    :ivar edges: per-axis bin edges from the most recent `apply` call.
    :ivar filtered_histogram: initialized to None; not assigned anywhere in
        this class.
    """

    def __init__(self, s=0.5, sp=3):
        """
        :param s: grid spacing (must be positive).
        :type s: float

        :param sp: kernel half-width in grid cells (must not be negative).
        :type sp: int

        :raises ValueError: if `s` is not positive or `sp` is negative.
        """
        if s <= 0:
            raise ValueError('Grid spacing must be positive, got %r.' % (s,))
        if sp < 0:
            raise ValueError('Span must not be negative, got %r.' % (sp,))
        self.sigma = s
        self.span = sp
        self.ndim = 2 * sp + 1
        self.filter = numpy.zeros((self.ndim, self.ndim, self.ndim))
        self.histogram = None
        self.edges = None
        self.filtered_histogram = None
        self._prepare_filter()

    def _prepare_filter(self):
        """
        Build the normalized Gaussian kernel in `self.filter`.

        The kernel is sampled at integer grid offsets in ``[-span, span]``
        along each axis and scaled so that its elements sum to 1. The
        analytic prefactor of the Gaussian is omitted because it cancels in
        that normalization.
        """
        offsets = numpy.arange(self.ndim, dtype=float) - self.span
        sq = offsets * offsets
        # Squared distance from the kernel centre, in grid cells.
        d2 = sq[:, None, None] + sq[None, :, None] + sq[None, None, :]
        kernel = numpy.exp(-0.5 * d2)
        self.filter = kernel / kernel.sum()

    def _apply_filter(self):
        """
        Replace `self.histogram` in place with its smoothed version.

        Every non-zero cell is spread over its neighbours using
        `self.filter`; contributions that would land outside the grid are
        dropped. Each call smooths whatever is currently stored, so calling
        this repeatedly smooths the data repeatedly.
        """
        span = self.span
        raw = numpy.copy(self.histogram)
        shape = raw.shape
        self.histogram.fill(0.0)
        # argwhere yields indices in row-major order, i.e. the same order a
        # nested i/j/k loop would visit them.
        for centre in numpy.argwhere(raw != 0.0):
            first = centre - span
            # Clip the kernel footprint to the grid...
            lo = numpy.maximum(first, 0)
            hi = numpy.minimum(centre + span + 1, shape)
            target = tuple(slice(a, b) for a, b in zip(lo, hi))
            # ...and take the matching part of the kernel.
            source = tuple(
                slice(a, b) for a, b in zip(lo - first, hi - first))
            self.histogram[target] += raw[tuple(centre)] * self.filter[source]

    def apply(self, atom_array, weights=None, max=None, min=None):
        """
        Bin coordinates and add the counts to `self.histogram`.

        :param atom_array: N x 3 array of X, Y, Z coordinates.
        :type atom_array: numpy.ndarray

        :param weights: optional per-point weights, passed to
            `numpy.histogramdd`.
        :type weights: numpy.ndarray or None

        :param max: upper bounds per axis, in Z, Y, X order. If either `max`
            or `min` is None, both are derived from the coordinates with a
            margin of ``(span + 1) * sigma``.
        :type max: numpy.ndarray or None

        :param min: lower bounds per axis, in Z, Y, X order.
        :type min: numpy.ndarray or None

        :raises ValueError: if `atom_array` is not N x 3, if it is empty and
            bounds have to be derived from it, or if the resulting grid does
            not have the same shape as the histogram accumulated so far.
        """
        coords = numpy.asarray(atom_array)
        if coords.ndim != 2 or coords.shape[1] != 3:
            raise ValueError('atom_array must be an N x 3 array, got shape '
                             '%s.' % (coords.shape,))
        # Transform X, Y, Z to Z, Y, X.
        coords = coords[:, ::-1]

        # `max` and `min` shadow the builtins; use unambiguous local names.
        lower, upper = min, max
        if upper is None or lower is None:
            if coords.shape[0] == 0:
                raise ValueError('Cannot derive grid bounds from an empty '
                                 'coordinate array.')
            margin = (self.span + 1) * self.sigma
            lower = coords.min(axis=0) - margin
            upper = coords.max(axis=0) + margin

        # Snap the bounds outward so that bin centres sit on multiples of
        # sigma (write() relies on this to derive integer grid indices).
        # NOTE(review): the original indentation was lost; snapping is
        # assumed to apply to caller-supplied bounds as well - confirm.
        half = 0.5 * self.sigma
        lower = numpy.floor(
            numpy.asarray(lower, dtype=float) / self.sigma) * self.sigma - half
        upper = numpy.ceil(
            numpy.asarray(upper, dtype=float) / self.sigma) * self.sigma + half

        # The extent is a whole number of cells up to rounding error; round
        # rather than truncate so e.g. 2.9999999999999996 becomes 3 bins.
        bins = numpy.rint((upper - lower) / self.sigma).astype(int).tolist()

        counts, self.edges = numpy.histogramdd(coords,
                                               bins=bins,
                                               range=numpy.column_stack(
                                                   (lower, upper)),
                                               weights=weights)
        if self.histogram is None:
            self.histogram = counts
            return
        if self.histogram.shape != counts.shape:
            raise ValueError('Grid shape %s does not match the accumulated '
                             'histogram shape %s.' %
                             (counts.shape, self.histogram.shape))
        self.histogram = numpy.add(self.histogram, counts)

    def write(self, cns_fname, remark=''):
        """
        Smooth the histogram and write it to a text file.

        The file contains, in order: an empty line, the number of remark
        lines, the remark, one line with the bin count and first/last grid
        index for X, Y and Z, one line with the X, Y, Z extents followed by
        three 90-degree angles, the line ``ZYX``, then one section per Z
        index (values six per line), ``-9999``, and finally the mean and
        standard deviation of the smoothed histogram.

        Note that the smoothing is applied to `self.histogram` in place on
        every call.

        :param cns_fname: name of the file to write.
        :type cns_fname: str

        :param remark: optional title text written after the header count.
        :type remark: str

        :raises ValueError: if `apply` has not been called yet.
        """
        if self.histogram is None:
            raise ValueError('No histogram to write; call apply() first.')
        self._apply_filter()

        sigma = self.sigma
        half = 0.5 * sigma
        out = ['\n', '%d\n' % len(remark.splitlines())]
        if remark:
            out.append(remark)
            # Keep the grid line from being glued onto the last remark line.
            if not remark.endswith('\n'):
                out.append('\n')

        # Edges are stored Z, Y, X; the header lists X, Y, Z.
        xyz_edges = [self.edges[axis] for axis in (2, 1, 0)]
        for edge in xyz_edges:
            out.append('%8d' % (edge.size - 1))
            # First/last bin centres as grid indices; rounded because the
            # quotient is only an integer up to floating-point error.
            out.append('%8d' % int(round((edge[0] + half) / sigma)))
            out.append('%8d' % int(round((edge[-1] - half) / sigma)))
        out.append('\n')

        for edge in xyz_edges:
            out.append('%12.5E' % (edge[-1] - edge[0]))
        out.append('%12.5E' % 0.90000E+02 * 3)
        out.append('\n')
        out.append('ZYX\n')

        for index, plane in enumerate(self.histogram):
            out.append('%8d\n' % index)
            values = plane.ravel()
            for start in range(0, values.size, 6):
                out.extend('%12.5E' % v for v in values[start:start + 6])
                out.append('\n')

        out.append('%8d\n' % -9999)
        out.append('%12.5E%12.5E\n' %
                   (numpy.mean(self.histogram), numpy.std(self.histogram)))

        with open(cns_fname, 'w') as fh:
            fh.write(''.join(out))
if __name__ == '__main__':
    # Command-line driver: read structures from a file, collect the
    # (translated) coordinates of every atom whose atomic number is greater
    # than 1, bin them with GaussianFilter and write the smoothed result.
    parser = optparse.OptionParser()
    parser.add_option('-i',
                      '--input',
                      type='str',
                      default='',
                      help='input mae file name')
    parser.add_option('-t',
                      '--trans_matrix',
                      type='str',
                      default='0 0 0',
                      help='transformation matrix')
    parser.add_option('-g',
                      '--grid',
                      type='float',
                      default=0.5,
                      help='grid spacing size')
    parser.add_option('-s', '--span', type='int', default=3, help='grid span')
    parser.add_option('-o',
                      '--output',
                      type='str',
                      default='',
                      help='output cns file name')
    opts, args = parser.parse_args()

    # Both file names default to '' - reject that up front instead of
    # failing later inside the reader or after all the work is done.
    if not opts.input:
        parser.error('an input file name is required (-i/--input)')
    if not opts.output:
        parser.error('an output file name is required (-o/--output)')

    # Despite the option name, the value is three whitespace-separated
    # numbers that are added to x, y and z respectively.
    trans_m = opts.trans_matrix.split()
    if len(trans_m) != 3:
        print('Transformation matrix must have three elements.')
        sys.exit(1)
    try:
        trans_x, trans_y, trans_z = (float(value) for value in trans_m)
    except ValueError:
        print('Transformation matrix elements must be numbers.')
        sys.exit(1)

    locations = []
    for st in structure.StructureReader(opts.input):
        for a in st.atom:
            if a.atomic_number > 1:
                locations.append(
                    (a.x + trans_x, a.y + trans_y, a.z + trans_z))

    # An empty coordinate list cannot be binned; stop with a clear message.
    if not locations:
        print('No atoms with atomic number greater than 1 found in %s.' %
              opts.input)
        sys.exit(1)

    loc_array = numpy.array(locations)
    gf = GaussianFilter(opts.grid, opts.span)
    gf.apply(loc_array)
    gf.write(opts.output)