Skip to content

Commit 2bebbfb

Browse files
authored
Merge pull request #10 from DeepRank/unit_test
Add unit test for transform
2 parents fdfde48 + c96fdd9 commit 2bebbfb

File tree

6 files changed

+280
-4921
lines changed

6 files changed

+280
-4921
lines changed

pdb2sql/5hvd.pdb

Lines changed: 0 additions & 4834 deletions
This file was deleted.

pdb2sql/StructureSimilarity.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def __init__(self,decoy,ref,verbose=False):
5151
self.decoy = decoy
5252
self.ref = ref
5353
self.verbose = verbose
54+
self.origin = [0., 0., 0.]
5455

5556

5657
################################################################################################
@@ -155,7 +156,7 @@ def compute_lrmsd_fast(self,lzone=None,method='svd',check=True):
155156
U = self.get_rotation_matrix(xyz_decoy_long,xyz_ref_long,method=method)
156157

157158
# rotate the entire fragment
158-
xyz_decoy_short = transform.rotation_matrix(xyz_decoy_short,U,center=False)
159+
xyz_decoy_short = transform.rotate(xyz_decoy_short,U, center=self.origin)
159160

160161
# compute the RMSD
161162
return self.get_rmsd(xyz_decoy_short,xyz_ref_short)
@@ -293,7 +294,7 @@ def compute_irmsd_fast(self,izone=None,method='svd',cutoff=10,check=True):
293294
U = self.get_rotation_matrix(xyz_contact_decoy,xyz_contact_ref,method=method)
294295

295296
# rotate the entire fragment
296-
xyz_contact_decoy = transform.rotation_matrix(xyz_contact_decoy,U,center=False)
297+
xyz_contact_decoy = transform.rotate(xyz_contact_decoy,U,center=self.origin)
297298

298299
# return the RMSD
299300
return self.get_rmsd(xyz_contact_decoy,xyz_contact_ref)
@@ -551,7 +552,7 @@ def compute_lrmsd_pdb2sql(self,exportpath=None,method='svd'):
551552
U = self.get_rotation_matrix(xyz_decoy_long,xyz_ref_long,method=method)
552553

553554
# rotate the entire fragment
554-
xyz_decoy_short = transform.rotation_matrix(xyz_decoy_short,U,center=False)
555+
xyz_decoy_short = transform.rotate(xyz_decoy_short, U, center=self.origin)
555556

556557

557558
# compute the RMSD
@@ -569,7 +570,7 @@ def compute_lrmsd_pdb2sql(self,exportpath=None,method='svd'):
569570
xyz_decoy += tr_decoy
570571

571572
# rotate decoy
572-
xyz_decoy = transform.rotation_matrix(xyz_decoy,U,center=False)
573+
xyz_decoy = transform.rotate(xyz_decoy, U, center=self.origin)
573574

574575
# update the sql database
575576
sql_decoy.update_column('x',xyz_decoy[:,0])
@@ -726,7 +727,7 @@ def compute_irmsd_pdb2sql(self,cutoff=10,method='svd',izone=None,exportpath=None
726727
U = self.get_rotation_matrix(xyz_contact_decoy,xyz_contact_ref,method=method)
727728

728729
# rotate the entire fragment
729-
xyz_contact_decoy = transform.rotation_matrix(xyz_contact_decoy,U,center=False)
730+
xyz_contact_decoy = transform.rotate(xyz_contact_decoy, U, center=self.origin)
730731

731732
# compute the RMSD
732733
irmsd = self.get_rmsd(xyz_contact_decoy,xyz_contact_ref)

pdb2sql/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '0.2.0'
1+
__version__ = '0.2.1'

pdb2sql/transform.py

Lines changed: 93 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -8,73 +8,51 @@
88
definition of the data set.
99
'''
1010

11-
def get_rot_axis_angle(seed=None):
12-
"""Get the rotation angle/axis.
13-
14-
Args:
15-
seed(int): random seed for numpy
16-
17-
Returns:
18-
list(float): axis of rotation
19-
float: angle of rotation
20-
"""
21-
# define the axis
22-
# uniform distribution on a sphere
23-
# http://mathworld.wolfram.com/SpherePointPicking.html
24-
if seed != None:
25-
np.random.seed(seed)
26-
27-
u1, u2 = np.random.rand(), np.random.rand()
28-
teta, phi = np.arccos(2 * u1 - 1), 2 * np.pi * u2
29-
axis = [np.sin(teta) * np.cos(phi),
30-
np.sin(teta) * np.sin(phi),
31-
np.cos(teta)]
32-
33-
# and the rotation angle
34-
angle = -np.pi + np.pi * np.random.rand()
35-
36-
return axis, angle
37-
38-
11+
########################################################################
12+
# Translation
13+
########################################################################
3914
def translation(db, vect, **kwargs):
4015
xyz = _get_xyz(db, **kwargs)
4116
xyz += vect
4217
_update(db, xyz, **kwargs)
4318

44-
19+
########################################################################
20+
# Rotation using axis–angle presentation
21+
# see https://en.wikipedia.org/wiki/Rotation_matrix#Rotation_matrix_from_axis_and_angle
22+
########################################################################
4523
def rot_axis(db, axis, angle, **kwargs):
4624
xyz = _get_xyz(db, **kwargs)
4725
xyz = rot_xyz_around_axis(xyz, axis, angle)
4826
_update(db, xyz, **kwargs)
4927

50-
51-
def rot_euler(db, alpha, beta, gamma, **kwargs):
52-
"""Rotate molecule from Euler rotation axis.
28+
def get_rot_axis_angle(seed=None):
29+
"""Get the rotation angle and axis.
5330
5431
Args:
55-
alpha (float): angle of rotation around the x axis
56-
beta (float): angle of rotation around the y axis
57-
gamma (float): angle of rotation around the z axis
58-
**kwargs: keyword argument to select the atoms.
59-
See pdb2sql.get()
60-
"""
61-
xyz = _get_xyz(db, **kwargs)
62-
xyz = _rotation_euler(xyz, alpha, beta, gamma)
63-
_update(db, xyz, **kwargs)
32+
seed(int): random seed for numpy
6433
34+
Returns:
35+
list(float): axis of rotation
36+
float: angle of rotation
37+
"""
38+
if seed is not None:
39+
np.random.seed(seed)
6540

66-
def rot_mat(db, mat, **kwargs):
67-
"""Rotate molecule from a rotation matrix.
41+
# define the rotation axis
42+
# uniform distribution on a sphere
43+
# eq1,2 in http://mathworld.wolfram.com/SpherePointPicking.html
44+
u1, u2 = np.random.rand(), np.random.rand()
45+
theta = 2 * np.pi * u1 # [0, 2*pi)
46+
phi = np.arccos(2 * u2 - 1) # [0, pi]
47+
# eq19 in http://mathworld.wolfram.com/SphericalCoordinates.html
48+
axis = [np.sin(phi) * np.cos(theta),
49+
np.sin(phi) * np.sin(theta),
50+
np.cos(phi)]
6851

69-
Args:
70-
mat (np.array): 3x3 rotation matrix
71-
**kwargs: keyword argument to select the atoms.
72-
See pdb2sql.get()
73-
"""
74-
xyz = _get_xyz(db, **kwargs)
75-
xyz = _rotation_matrix(xyz, mat)
76-
_update(db, xyz, **kwargs)
52+
# define the rotation angle
53+
angle = 2 * np.pi * np.random.rand()
7754

55+
return axis, angle
7856

7957
def rot_xyz_around_axis(xyz, axis, angle, center=None):
8058
"""Get the rotated xyz.
@@ -89,17 +67,11 @@ def rot_xyz_around_axis(xyz, axis, angle, center=None):
8967
Returns:
9068
np.array: rotated xyz coordinates
9169
"""
92-
93-
# check center
94-
if center is None:
95-
center = np.mean(xyz, 0)
96-
9770
# get the data
9871
ct, st = np.cos(angle), np.sin(angle)
9972
ux, uy, uz = axis
10073

10174
# definition of the rotation matrix
102-
# see https://en.wikipedia.org/wiki/Rotation_matrix
10375
rot_mat = np.array([[ct + ux**2 * (1 - ct),
10476
ux * uy * (1 - ct) - uz * st,
10577
ux * uz * (1 - ct) + uy * st],
@@ -111,51 +83,91 @@ def rot_xyz_around_axis(xyz, axis, angle, center=None):
11183
ct + uz**2 * (1 - ct)]])
11284

11385
# apply the rotation
114-
return np.dot(rot_mat, (xyz - center).T).T + center
86+
return rotate(xyz, rot_mat, center)
11587

88+
########################################################################
89+
# Rotation using Euler anlges
90+
# see https://en.wikipedia.org/wiki/Rotation_matrix#General_rotations
91+
########################################################################
11692

117-
def _rotation_euler(xyz, alpha, beta, gamma):
93+
def rot_euler(db, alpha, beta, gamma, **kwargs):
94+
"""Rotate molecule from Euler rotation axis.
95+
96+
Args:
97+
alpha (float): angle of rotation around the x axis
98+
beta (float): angle of rotation around the y axis
99+
gamma (float): angle of rotation around the z axis
100+
**kwargs: keyword argument to select the atoms.
101+
See pdb2sql.get()
102+
"""
103+
xyz = _get_xyz(db, **kwargs)
104+
xyz = rotation_euler(xyz, alpha, beta, gamma)
105+
_update(db, xyz, **kwargs)
106+
107+
def rotation_euler(xyz, alpha, beta, gamma, center=None):
118108

119109
# precomte the trig
120110
ca, sa = np.cos(alpha), np.sin(alpha)
121111
cb, sb = np.cos(beta), np.sin(beta)
122112
cg, sg = np.cos(gamma), np.sin(gamma)
123113

124-
# get the center of the molecule
125-
xyz0 = np.mean(xyz, 0)
126-
127114
# rotation matrices
128115
rx = np.array([[1, 0, 0], [0, ca, -sa], [0, sa, ca]])
129116
ry = np.array([[cb, 0, sb], [0, 1, 0], [-sb, 0, cb]])
130-
rz = np.array([[cg, -sg, 0], [sg, cs, 0], [0, 0, 1]])
117+
rz = np.array([[cg, -sg, 0], [sg, cg, 0], [0, 0, 1]])
131118

132-
rot_mat = np.dot(rz, np.dot(ry, rz))
119+
# get rotation matrix
120+
rot_mat = np.dot(rz, np.dot(ry, rx))
133121

134122
# apply the rotation
135-
return np.dot(rot_mat, (xyz - xyz0).T).T + xyz0
123+
return rotate(xyz, rot_mat, center)
136124

125+
########################################################################
126+
# Rotation using provided rotation matrix
127+
########################################################################
137128

138-
def rotation_matrix(xyz, rot_mat, center=True):
139-
if center:
140-
xyz0 = np.mean(xyz)
141-
return np.dot(rot_mat, (xyz - xyz0).T).T + xyz0
142-
else:
143-
return np.dot(rot_mat, (xyz).T).T
129+
def rot_mat(db, mat, **kwargs):
130+
"""Rotate molecule from a rotation matrix.
144131
132+
Args:
133+
mat (np.array): 3x3 rotation matrix
134+
**kwargs: keyword argument to select the atoms.
135+
See pdb2sql.get()
136+
"""
137+
xyz = _get_xyz(db, **kwargs)
138+
xyz = rotate(xyz, mat)
139+
_update(db, xyz, **kwargs)
145140

146-
def _get_xyz(db, **kwargs):
147-
return np.array(db.get('x,y,z', **kwargs))
141+
def rotate(xyz, rot_mat, center=None):
142+
"""[summary]
148143
144+
Args:
145+
xyz(np.ndarray): x,y,z coordinates
146+
rot_mat(np.ndarray): rotation matrix
147+
center (list or np.ndarray, optional): rotation center.
148+
Defaults to None, i.e. using molecule center as rotation
149+
center.
149150
150-
def _update(db, xyz, **kwargs):
151-
db.update('x,y,z', xyz, **kwargs)
151+
Raises:
152+
TypeError: Rotation center must be list or 1D np.ndarray.
152153
154+
Returns:
155+
np.ndarray: x,y,z coordinates after rotation
156+
"""
157+
# the default rotation center is the center of molecule itself.
158+
if center is None:
159+
center = np.mean(xyz, 0)
153160

154-
if __name__ == "__main__":
161+
if not isinstance(center, (list, np.ndarray)):
162+
raise TypeError("Rotation center must be list or 1D np.ndarray")
155163

156-
t0 = time()
157-
db = pdb2sql('5hvd.pdb')
158-
print('SQL %f' % (time() - t0))
164+
return np.dot(rot_mat, (xyz - center).T).T + center
159165

160-
tr = np.array([1, 2, 3])
161-
translation(db, tr, chainID='A')
166+
########################################################################
167+
# helper functions
168+
########################################################################
169+
def _get_xyz(db, **kwargs):
170+
return np.array(db.get('x,y,z', **kwargs))
171+
172+
def _update(db, xyz, **kwargs):
173+
db.update('x,y,z', xyz, **kwargs)

test/pdb/dummy_transform.pdb

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
ATOM 1 N THR A 1 1.000 0.000 0.000 1.00 0.69 N
2+
ATOM 2 CA THR A 1 -1.000 0.000 0.000 1.00 0.50 C
3+
ATOM 3 C THR A 1 0.000 1.000 0.000 1.00 0.45 C
4+
ATOM 4 O THR A 1 0.000 -1.000 0.000 1.00 0.69 O
5+
ATOM 5 CB THR A 1 0.000 0.000 1.000 1.00 0.50 C
6+
ATOM 6 H1 THR A 1 0.000 0.000 -1.000 1.00 0.45 H

0 commit comments

Comments
 (0)