Skip to content

Commit 0e41f11

Browse files
authored
Merge pull request #3 from DeepRank/multiple_interface
Multiple interface
2 parents e74673c + ccee66d commit 0e41f11

File tree

4 files changed

+189
-153
lines changed

4 files changed

+189
-153
lines changed

pdb2sql/StructureSimilarity.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -185,14 +185,12 @@ def compute_lzone(self,save_file=True,filename=None):
185185
return resData
186186

187187

188-
189-
190188
################################################################################################
191189
#
192190
# FAST ROUTINE TO COMPUTE THE I-RMSD
193-
# Require the precalculation of the izone
191+
# Require the precalculation of the izone
194192
# A dedicated routine is implemented to compute the izone
195-
# if izone is not given in argument the routine will compute them autimatcally
193+
# if izone is not given as an argument, the routine will compute it automatically
196194
#
197195
#################################################################################################
198196

pdb2sql/interface.py

Lines changed: 171 additions & 142 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11

22
import numpy as np
3+
import itertools
34
from .pdb2sqlcore import pdb2sql
45

56
#from pdb2sqlAlchemy import pdb2sql_alchemy as pdb2sql
@@ -18,200 +19,228 @@
1819

1920
class interface(pdb2sql):
2021

21-
def __init__(self,pdb):
22+
def __init__(self,pdb):
2223

23-
super().__init__(pdb,no_extra=True)
24-
super()._create_sql()
25-
self.backbone_type = ['CA','C','N','O']
24+
super().__init__(pdb,no_extra=True)
25+
super()._create_sql()
26+
self.backbone_type = ['CA','C','N','O']
2627

27-
############################################################################
28-
#
29-
# get the contact atoms
30-
#
31-
#############################################################################
28+
############################################################################
29+
#
30+
# get the contact atoms
31+
#
32+
#############################################################################
3233

33-
def get_contact_atoms(self,cutoff=8.5,chain1='A',chain2='B',
34-
extend_to_residue=False,only_backbone_atoms=False,
35-
excludeH=False,return_only_backbone_atoms=False,return_contact_pairs=False):
34+
def get_contact_atoms(self,cutoff=8.5,allchains=False,chain1='A',chain2='B',
35+
extend_to_residue=False,only_backbone_atoms=False,
36+
excludeH=False,return_only_backbone_atoms=False,return_contact_pairs=False):
3637

37-
# xyz of the chains
38-
xyz1 = np.array(super().get('x,y,z',chainID=chain1))
39-
xyz2 = np.array(super().get('x,y,z',chainID=chain2))
38+
if allchains:
39+
chainIDs = super().get_chains()
40+
else:
41+
chainIDs = [chain1,chain2]
42+
nchains = len(chainIDs)
4043

41-
# index of b
42-
index1 = super().get('rowID',chainID=chain1)
43-
index2 = super().get('rowID',chainID=chain2)
44+
xyz = dict()
45+
index = dict()
46+
resName = dict()
47+
atName = dict()
4448

45-
# resName of the chains
46-
resName1 = np.array(super().get('resName',chainID=chain1))
47-
resName2 = np.array(super().get('resName',chainID=chain2))
49+
for chain in chainIDs:
4850

49-
# atomnames of the chains
50-
atName1 = np.array(super().get('name',chainID=chain1))
51-
atName2 = np.array(super().get('name',chainID=chain2))
51+
data = np.array(super().get('x,y,z,rowID,resName,name',chainID=chain))
52+
xyz[chain] = data[:,:3].astype(float)
53+
index[chain] = data[:,3].astype(int)
54+
resName[chain] = data[:,-2]
55+
atName[chain] = data[:,-1]
5256

5357

54-
# loop through the first chain
55-
# TO DO : loop through the smallest chain instead ...
56-
index_contact_1,index_contact_2 = [],[]
57-
index_contact_pairs = {}
58+
# loop through the first chain
59+
# TO DO : loop through the smallest chain instead ...
60+
index_contact_1,index_contact_2 = [],[]
61+
index_contact_pairs = {}
5862

59-
for i,x0 in enumerate(xyz1):
63+
index_contact = dict()
64+
index_contact_pairs = {}
6065

61-
# compute the contact atoms
62-
contacts = np.where(np.sqrt(np.sum((xyz2-x0)**2,1)) <= cutoff )[0]
66+
for chain1,chain2 in itertools.combinations(chainIDs,2):
6367

64-
# exclude the H if required
65-
if excludeH and atName1[i][0] == 'H':
66-
continue
68+
xyz1 = xyz[chain1]
69+
xyz2 = xyz[chain2]
6770

68-
if len(contacts)>0 and any([not only_backbone_atoms, atName1[i] in self.backbone_type]):
71+
atName1 = atName[chain1]
72+
atName2 = atName[chain2]
6973

70-
# the contact atoms
71-
index_contact_1 += [index1[i]]
72-
index_contact_2 += [index2[k] for k in contacts if ( any( [atName2[k] in self.backbone_type, not only_backbone_atoms]) and not (excludeH and atName2[k][0]=='H') ) ]
7374

74-
# the pairs
75-
pairs = [index2[k] for k in contacts if any( [atName2[k] in self.backbone_type, not only_backbone_atoms] ) and not (excludeH and atName2[k][0]=='H') ]
76-
if len(pairs) > 0:
77-
index_contact_pairs[index1[i]] = pairs
75+
if chain1 not in index_contact:
76+
index_contact[chain1] = []
7877

79-
# get uniques
80-
index_contact_1 = sorted(set(index_contact_1))
81-
index_contact_2 = sorted(set(index_contact_2))
78+
if chain2 not in index_contact:
79+
index_contact[chain2] = []
8280

83-
# if no atoms were found
84-
if len(index_contact_1)==0:
85-
print('Warning : No contact atoms detected in pdb2sql')
81+
for i,x0 in enumerate(xyz1):
8682

87-
# extend the list to entire residue
88-
if extend_to_residue:
89-
index_contact_1,index_contact_2 = self._extend_contact_to_residue(index_contact_1,index_contact_2,only_backbone_atoms)
83+
# compute the contact atoms
84+
contacts = np.where(np.sqrt(np.sum((xyz2-x0)**2,1)) <= cutoff )[0]
9085

86+
# exclude the H if required
87+
if excludeH and atName1[i][0] == 'H':
88+
continue
9189

92-
# filter only the backbone atoms
93-
if return_only_backbone_atoms and not only_backbone_atoms:
90+
if len(contacts)>0 and any([not only_backbone_atoms, atName1[i] in self.backbone_type]):
9491

95-
# get all the names
96-
# there are better ways to do that !
97-
atNames = np.array(super().get('name'))
92+
# the contact atoms
93+
index_contact[chain1] += [index[chain1][i]]
94+
index_contact[chain2] += [index[chain2][k] for k in contacts if ( any( [atName2[k] in self.backbone_type, not only_backbone_atoms]) and not (excludeH and atName2[k][0]=='H') ) ]
9895

99-
# change the index_contacts
100-
index_contact_1 = [ ind for ind in index_contact_1 if atNames[ind] in self.backbone_type ]
101-
index_contact_2 = [ ind for ind in index_contact_2 if atNames[ind] in self.backbone_type ]
96+
# the pairs
97+
pairs = [index[chain2][k] for k in contacts if any( [atName2[k] in self.backbone_type, not only_backbone_atoms] ) and not (excludeH and atName2[k][0]=='H') ]
98+
if len(pairs) > 0:
99+
index_contact_pairs[index[chain1][i]] = pairs
102100

103-
# change the contact pairs
104-
tmp_dict = {}
105-
for ind1,ind2_list in index_contact_pairs.items():
101+
# get uniques
102+
for chain in chainIDs:
103+
index_contact[chain] = sorted(set(index_contact[chain]))
106104

107-
if atNames[ind1] in self.backbone_type:
108-
tmp_dict[index1[ind1]] = [ind2 for ind2 in ind2_list if atNames[ind2] in self.backbone_type]
109105

110-
index_contact_pairs = tmp_dict
106+
# if no atoms were found
107+
if len(index_contact_pairs)==0:
108+
print('Warning : No contact atoms detected in pdb2sql')
111109

112-
# not sure that's the best way of dealing with that
113-
if return_contact_pairs:
114-
return index_contact_pairs
115-
else:
116-
return index_contact_1,index_contact_2
110+
# extend the list to entire residue
111+
if extend_to_residue:
112+
for chain in chainIDs:
113+
index_contact[chain] = self._extend_contact_to_residue(index_contact_1,only_backbone_atoms)
114+
#index_contact_1,index_contact_2 = self._extend_contact_to_residue(index_contact_1,index_contact_2,only_backbone_atoms)
117115

118-
# extend the contact atoms to the residue
119-
def _extend_contact_to_residue(self,index1,index2,only_backbone_atoms):
120116

121-
# extract the data
122-
dataA = super().get('chainID,resName,resSeq',rowID=index1)
123-
dataB = super().get('chainID,resName,resSeq',rowID=index2)
117+
# filter only the backbone atoms
118+
if return_only_backbone_atoms and not only_backbone_atoms:
124119

125-
# create tuple cause we want to hash through it
126-
dataA = list(map(lambda x: tuple(x),dataA))
127-
dataB = list(map(lambda x: tuple(x),dataB))
120+
# get all the names
121+
# there are better ways to do that !
122+
atNames = np.array(super().get('name'))
128123

129-
# extract uniques
130-
resA = list(set(dataA))
131-
resB = list(set(dataB))
124+
# change the index_contacts
125+
for chain in chainIDs:
126+
index_contact[chain] = [ ind for ind in index_contact[chain] if atNames[ind] in self.backbone_type ]
132127

133-
# init the list
134-
index_contact_A,index_contact_B = [],[]
135128

136-
# contact of chain A
137-
for resdata in resA:
138-
chainID,resName,resSeq = resdata
129+
# change the contact pairs
130+
tmp_dict = {}
131+
for ind1,ind2_list in index_contact_pairs.items():
139132

140-
if only_backbone_atoms:
141-
index = super().get('rowID',chainID=chainID,resName=resName,resSeq=resSeq)
142-
name = super().get('name',chainID=chainID,resName=resName,resSeq=resSeq)
143-
index_contact_A += [ ind for ind,n in zip(index,name) if n in self.backbone_type ]
144-
else:
145-
index_contact_A += super().get('rowID',chainID=chainID,resName=resName,resSeq=resSeq)
133+
if atNames[ind1] in self.backbone_type:
134+
tmp_dict[ind1] = [ind2 for ind2 in ind2_list if atNames[ind2] in self.backbone_type]
146135

147-
# contact of chain B
148-
for resdata in resB:
149-
chainID,resName,resSeq = resdata
150-
if only_backbone_atoms:
151-
index = self.get('rowID',chainID=chainID,resName=resName,resSeq=resSeq)
152-
name = self.get('name',chainID=chainID,resName=resName,resSeq=resSeq)
153-
index_contact_B += [ ind for ind,n in zip(index,name) if n in self.backbone_type ]
154-
else:
155-
index_contact_B += super().get('rowID',chainID=chainID,resName=resName,resSeq=resSeq)
136+
index_contact_pairs = tmp_dict
156137

157-
# make sure that we don't have double (maybe optional)
158-
index_contact_A = sorted(set(index_contact_A))
159-
index_contact_B = sorted(set(index_contact_B))
138+
# not sure that's the best way of dealing with that
139+
if return_contact_pairs:
140+
return index_contact_pairs
141+
else:
142+
return index_contact
160143

161-
return index_contact_A,index_contact_B
144+
# extend the contact atoms to the residue
145+
def _extend_contact_to_residue(self,index1,only_backbone_atoms):
162146

147+
# extract the data
148+
dataA = super().get('chainID,resName,resSeq',rowID=index1)
149+
#dataB = super().get('chainID,resName,resSeq',rowID=index2)
163150

164-
# get the contact residue
165-
def get_contact_residues(self,cutoff=8.5,chain1='A',chain2='B',excludeH=False,
166-
only_backbone_atoms=False,return_contact_pairs=False):
151+
# create tuple cause we want to hash through it
152+
dataA = list(map(lambda x: tuple(x),dataA))
153+
#dataB = list(map(lambda x: tuple(x),dataB))
167154

168-
# get the contact atoms
169-
if return_contact_pairs:
155+
# extract uniques
156+
resA = list(set(dataA))
157+
#resB = list(set(dataB))
170158

171-
# declare the dict
172-
residue_contact_pairs = {}
159+
# init the list
160+
index_contact_A = []
173161

174-
# get the contact atom pairs
175-
atom_pairs = self.get_contact_atoms(cutoff=cutoff,chain1=chain1,chain2=chain2,
176-
only_backbone_atoms=only_backbone_atoms,
177-
excludeH=excludeH,
178-
return_contact_pairs=True)
162+
# contact of chain A
163+
for resdata in resA:
164+
chainID,resName,resSeq = resdata
179165

180-
# loop over the atom pair dict
181-
for iat1,atoms2 in atom_pairs.items():
166+
if only_backbone_atoms:
167+
index = super().get('rowID',chainID=chainID,resName=resName,resSeq=resSeq)
168+
name = super().get('name',chainID=chainID,resName=resName,resSeq=resSeq)
169+
index_contact_A += [ ind for ind,n in zip(index,name) if n in self.backbone_type ]
170+
else:
171+
index_contact_A += super().get('rowID',chainID=chainID,resName=resName,resSeq=resSeq)
182172

183-
# get the res info of the current atom
184-
data1 = tuple(super().get('chainID,resSeq,resName',rowID=[iat1])[0])
173+
# # contact of chain B
174+
# for resdata in resB:
175+
# chainID,resName,resSeq = resdata
176+
# if only_backbone_atoms:
177+
# index = self.get('rowID',chainID=chainID,resName=resName,resSeq=resSeq)
178+
# name = self.get('name',chainID=chainID,resName=resName,resSeq=resSeq)
179+
# index_contact_B += [ ind for ind,n in zip(index,name) if n in self.backbone_type ]
180+
# else:
181+
# index_contact_B += super().get('rowID',chainID=chainID,resName=resName,resSeq=resSeq)
185182

186-
# create a new entry in the dict if necessary
187-
if data1 not in residue_contact_pairs:
188-
residue_contact_pairs[data1] = set()
183+
# make sure that we don't have duplicates (maybe optional)
184+
index_contact_A = sorted(set(index_contact_A))
185+
#index_contact_B = sorted(set(index_contact_B))
189186

190-
# get the res info of the atom in the other chain
191-
data2 = super().get('chainID,resSeq,resName',rowID=atoms2)
187+
return index_contact_A #,index_contact_B
192188

193-
# store that in the dict without double
194-
for resData in data2:
195-
residue_contact_pairs[data1].add(tuple(resData))
196189

197-
for resData in residue_contact_pairs.keys():
198-
residue_contact_pairs[resData] = sorted(residue_contact_pairs[resData])
190+
# get the contact residue
191+
def get_contact_residues(self,cutoff=8.5,allchains=False,chain1='A',chain2='B',excludeH=False,
192+
only_backbone_atoms=False,return_contact_pairs=False):
199193

200-
return residue_contact_pairs
194+
# get the contact atoms
195+
if return_contact_pairs:
201196

202-
else:
197+
# declare the dict
198+
residue_contact_pairs = {}
203199

204-
# get the contact atoms
205-
contact_atoms = self.get_contact_atoms(cutoff=cutoff,chain1=chain1,chain2=chain2,return_contact_pairs=False)
200+
# get the contact atom pairs
201+
atom_pairs = self.get_contact_atoms(cutoff=cutoff,allchains=allchains,chain1=chain1,chain2=chain2,
202+
only_backbone_atoms=only_backbone_atoms,
203+
excludeH=excludeH,
204+
return_contact_pairs=True)
206205

207-
# get the residue info
208-
data1 = super().get('chainID,resSeq,resName',rowID=contact_atoms[0])
209-
data2 = super().get('chainID,resSeq,resName',rowID=contact_atoms[1])
206+
# loop over the atom pair dict
207+
for iat1,atoms2 in atom_pairs.items():
210208

211-
# take only unique
212-
residue_contact_A = sorted(set([tuple(resData) for resData in data1]))
213-
residue_contact_B = sorted(set([tuple(resData) for resData in data2]))
209+
# get the res info of the current atom
210+
data1 = tuple(super().get('chainID,resSeq,resName',rowID=[iat1])[0])
214211

215-
return residue_contact_A,residue_contact_B
212+
# create a new entry in the dict if necessary
213+
if data1 not in residue_contact_pairs:
214+
residue_contact_pairs[data1] = set()
215+
216+
# get the res info of the atom in the other chain
217+
data2 = super().get('chainID,resSeq,resName',rowID=atoms2)
218+
219+
# store that in the dict without duplicates
220+
for resData in data2:
221+
residue_contact_pairs[data1].add(tuple(resData))
222+
223+
for resData in residue_contact_pairs.keys():
224+
residue_contact_pairs[resData] = sorted(residue_contact_pairs[resData])
225+
226+
return residue_contact_pairs
227+
228+
else:
229+
230+
# get the contact atoms
231+
contact_atoms = self.get_contact_atoms(cutoff=cutoff,allchains=allchains,
232+
chain1=chain1,chain2=chain2,
233+
return_contact_pairs=False)
234+
235+
# get the residue info
236+
data = dict()
237+
residue_contact = dict()
238+
239+
for chain in contact_atoms.keys():
240+
data[chain] = super().get('chainID,resSeq,resName',rowID=contact_atoms[chain])
241+
residue_contact[chain] = sorted(set([tuple(resData) for resData in data[chain]]))
242+
243+
244+
return residue_contact
216245

217246

0 commit comments

Comments
 (0)