|
1 | 1 |
|
2 | 2 | import numpy as np |
| 3 | +import itertools |
3 | 4 | from .pdb2sqlcore import pdb2sql |
4 | 5 |
|
5 | 6 | #from pdb2sqlAlchemy import pdb2sql_alchemy as pdb2sql |
|
18 | 19 |
|
19 | 20 | class interface(pdb2sql): |
20 | 21 |
|
21 | | - def __init__(self,pdb): |
| 22 | + def __init__(self,pdb): |
22 | 23 |
|
23 | | - super().__init__(pdb,no_extra=True) |
24 | | - super()._create_sql() |
25 | | - self.backbone_type = ['CA','C','N','O'] |
| 24 | + super().__init__(pdb,no_extra=True) |
| 25 | + super()._create_sql() |
| 26 | + self.backbone_type = ['CA','C','N','O'] |
26 | 27 |
|
27 | | - ############################################################################ |
28 | | - # |
29 | | - # get the contact atoms |
30 | | - # |
31 | | - ############################################################################# |
| 28 | + ############################################################################ |
| 29 | + # |
| 30 | + # get the contact atoms |
| 31 | + # |
| 32 | + ############################################################################# |
32 | 33 |
|
33 | | - def get_contact_atoms(self,cutoff=8.5,chain1='A',chain2='B', |
34 | | - extend_to_residue=False,only_backbone_atoms=False, |
35 | | - excludeH=False,return_only_backbone_atoms=False,return_contact_pairs=False): |
| 34 | + def get_contact_atoms(self,cutoff=8.5,allchains=False,chain1='A',chain2='B', |
| 35 | + extend_to_residue=False,only_backbone_atoms=False, |
| 36 | + excludeH=False,return_only_backbone_atoms=False,return_contact_pairs=False): |
36 | 37 |
|
37 | | - # xyz of the chains |
38 | | - xyz1 = np.array(super().get('x,y,z',chainID=chain1)) |
39 | | - xyz2 = np.array(super().get('x,y,z',chainID=chain2)) |
| 38 | + if allchains: |
| 39 | + chainIDs = super().get_chains() |
| 40 | + else: |
| 41 | + chainIDs = [chain1,chain2] |
| 42 | + nchains = len(chainIDs) |
40 | 43 |
|
41 | | - # index of b |
42 | | - index1 = super().get('rowID',chainID=chain1) |
43 | | - index2 = super().get('rowID',chainID=chain2) |
| 44 | + xyz = dict() |
| 45 | + index = dict() |
| 46 | + resName = dict() |
| 47 | + atName = dict() |
44 | 48 |
|
45 | | - # resName of the chains |
46 | | - resName1 = np.array(super().get('resName',chainID=chain1)) |
47 | | - resName2 = np.array(super().get('resName',chainID=chain2)) |
| 49 | + for chain in chainIDs: |
48 | 50 |
|
49 | | - # atomnames of the chains |
50 | | - atName1 = np.array(super().get('name',chainID=chain1)) |
51 | | - atName2 = np.array(super().get('name',chainID=chain2)) |
| 51 | + data = np.array(super().get('x,y,z,rowID,resName,name',chainID=chain)) |
| 52 | + xyz[chain] = data[:,:3].astype(float) |
| 53 | + index[chain] = data[:,3].astype(int) |
| 54 | + resName[chain] = data[:,-2] |
| 55 | + atName[chain] = data[:,-1] |
52 | 56 |
|
53 | 57 |
|
54 | | - # loop through the first chain |
55 | | - # TO DO : loop through the smallest chain instead ... |
56 | | - index_contact_1,index_contact_2 = [],[] |
57 | | - index_contact_pairs = {} |
| 58 | + # loop through the first chain |
| 59 | + # TO DO : loop through the smallest chain instead ... |
| 60 | + index_contact_1,index_contact_2 = [],[] |
| 61 | + index_contact_pairs = {} |
58 | 62 |
|
59 | | - for i,x0 in enumerate(xyz1): |
| 63 | + index_contact = dict() |
| 64 | + index_contact_pairs = {} |
60 | 65 |
|
61 | | - # compute the contact atoms |
62 | | - contacts = np.where(np.sqrt(np.sum((xyz2-x0)**2,1)) <= cutoff )[0] |
| 66 | + for chain1,chain2 in itertools.combinations(chainIDs,2): |
63 | 67 |
|
64 | | - # exclude the H if required |
65 | | - if excludeH and atName1[i][0] == 'H': |
66 | | - continue |
| 68 | + xyz1 = xyz[chain1] |
| 69 | + xyz2 = xyz[chain2] |
67 | 70 |
|
68 | | - if len(contacts)>0 and any([not only_backbone_atoms, atName1[i] in self.backbone_type]): |
| 71 | + atName1 = atName[chain1] |
| 72 | + atName2 = atName[chain2] |
69 | 73 |
|
70 | | - # the contact atoms |
71 | | - index_contact_1 += [index1[i]] |
72 | | - index_contact_2 += [index2[k] for k in contacts if ( any( [atName2[k] in self.backbone_type, not only_backbone_atoms]) and not (excludeH and atName2[k][0]=='H') ) ] |
73 | 74 |
|
74 | | - # the pairs |
75 | | - pairs = [index2[k] for k in contacts if any( [atName2[k] in self.backbone_type, not only_backbone_atoms] ) and not (excludeH and atName2[k][0]=='H') ] |
76 | | - if len(pairs) > 0: |
77 | | - index_contact_pairs[index1[i]] = pairs |
| 75 | + if chain1 not in index_contact: |
| 76 | + index_contact[chain1] = [] |
78 | 77 |
|
79 | | - # get uniques |
80 | | - index_contact_1 = sorted(set(index_contact_1)) |
81 | | - index_contact_2 = sorted(set(index_contact_2)) |
| 78 | + if chain2 not in index_contact: |
| 79 | + index_contact[chain2] = [] |
82 | 80 |
|
83 | | - # if no atoms were found |
84 | | - if len(index_contact_1)==0: |
85 | | - print('Warning : No contact atoms detected in pdb2sql') |
| 81 | + for i,x0 in enumerate(xyz1): |
86 | 82 |
|
87 | | - # extend the list to entire residue |
88 | | - if extend_to_residue: |
89 | | - index_contact_1,index_contact_2 = self._extend_contact_to_residue(index_contact_1,index_contact_2,only_backbone_atoms) |
| 83 | + # compute the contact atoms |
| 84 | + contacts = np.where(np.sqrt(np.sum((xyz2-x0)**2,1)) <= cutoff )[0] |
90 | 85 |
|
| 86 | + # exclude the H if required |
| 87 | + if excludeH and atName1[i][0] == 'H': |
| 88 | + continue |
91 | 89 |
|
92 | | - # filter only the backbone atoms |
93 | | - if return_only_backbone_atoms and not only_backbone_atoms: |
| 90 | + if len(contacts)>0 and any([not only_backbone_atoms, atName1[i] in self.backbone_type]): |
94 | 91 |
|
95 | | - # get all the names |
96 | | - # there are better ways to do that ! |
97 | | - atNames = np.array(super().get('name')) |
| 92 | + # the contact atoms |
| 93 | + index_contact[chain1] += [index[chain1][i]] |
| 94 | + index_contact[chain2] += [index[chain2][k] for k in contacts if ( any( [atName2[k] in self.backbone_type, not only_backbone_atoms]) and not (excludeH and atName2[k][0]=='H') ) ] |
98 | 95 |
|
99 | | - # change the index_contacts |
100 | | - index_contact_1 = [ ind for ind in index_contact_1 if atNames[ind] in self.backbone_type ] |
101 | | - index_contact_2 = [ ind for ind in index_contact_2 if atNames[ind] in self.backbone_type ] |
| 96 | + # the pairs |
| 97 | + pairs = [index[chain2][k] for k in contacts if any( [atName2[k] in self.backbone_type, not only_backbone_atoms] ) and not (excludeH and atName2[k][0]=='H') ] |
| 98 | + if len(pairs) > 0: |
| 99 | + index_contact_pairs[index[chain1][i]] = pairs |
102 | 100 |
|
103 | | - # change the contact pairs |
104 | | - tmp_dict = {} |
105 | | - for ind1,ind2_list in index_contact_pairs.items(): |
| 101 | + # get uniques |
| 102 | + for chain in chainIDs: |
| 103 | + index_contact[chain] = sorted(set(index_contact[chain])) |
106 | 104 |
|
107 | | - if atNames[ind1] in self.backbone_type: |
108 | | - tmp_dict[index1[ind1]] = [ind2 for ind2 in ind2_list if atNames[ind2] in self.backbone_type] |
109 | 105 |
|
110 | | - index_contact_pairs = tmp_dict |
| 106 | + # if no atoms were found |
| 107 | + if len(index_contact_pairs)==0: |
| 108 | + print('Warning : No contact atoms detected in pdb2sql') |
111 | 109 |
|
112 | | - # not sure that's the best way of dealing with that |
113 | | - if return_contact_pairs: |
114 | | - return index_contact_pairs |
115 | | - else: |
116 | | - return index_contact_1,index_contact_2 |
| 110 | + # extend the list to entire residue |
| 111 | + if extend_to_residue: |
| 112 | + for chain in chainIDs: |
| 113 | + index_contact[chain] = self._extend_contact_to_residue(index_contact_1,only_backbone_atoms) |
| 114 | + #index_contact_1,index_contact_2 = self._extend_contact_to_residue(index_contact_1,index_contact_2,only_backbone_atoms) |
117 | 115 |
|
118 | | - # extend the contact atoms to the residue |
119 | | - def _extend_contact_to_residue(self,index1,index2,only_backbone_atoms): |
120 | 116 |
|
121 | | - # extract the data |
122 | | - dataA = super().get('chainID,resName,resSeq',rowID=index1) |
123 | | - dataB = super().get('chainID,resName,resSeq',rowID=index2) |
| 117 | + # filter only the backbone atoms |
| 118 | + if return_only_backbone_atoms and not only_backbone_atoms: |
124 | 119 |
|
125 | | - # create tuple cause we want to hash through it |
126 | | - dataA = list(map(lambda x: tuple(x),dataA)) |
127 | | - dataB = list(map(lambda x: tuple(x),dataB)) |
| 120 | + # get all the names |
| 121 | + # there are better ways to do that ! |
| 122 | + atNames = np.array(super().get('name')) |
128 | 123 |
|
129 | | - # extract uniques |
130 | | - resA = list(set(dataA)) |
131 | | - resB = list(set(dataB)) |
| 124 | + # change the index_contacts |
| 125 | + for chain in chainIDs: |
| 126 | + index_contact[chain] = [ ind for ind in index_contact[chain] if atNames[ind] in self.backbone_type ] |
132 | 127 |
|
133 | | - # init the list |
134 | | - index_contact_A,index_contact_B = [],[] |
135 | 128 |
|
136 | | - # contact of chain A |
137 | | - for resdata in resA: |
138 | | - chainID,resName,resSeq = resdata |
| 129 | + # change the contact pairs |
| 130 | + tmp_dict = {} |
| 131 | + for ind1,ind2_list in index_contact_pairs.items(): |
139 | 132 |
|
140 | | - if only_backbone_atoms: |
141 | | - index = super().get('rowID',chainID=chainID,resName=resName,resSeq=resSeq) |
142 | | - name = super().get('name',chainID=chainID,resName=resName,resSeq=resSeq) |
143 | | - index_contact_A += [ ind for ind,n in zip(index,name) if n in self.backbone_type ] |
144 | | - else: |
145 | | - index_contact_A += super().get('rowID',chainID=chainID,resName=resName,resSeq=resSeq) |
| 133 | + if atNames[ind1] in self.backbone_type: |
| 134 | + tmp_dict[ind1] = [ind2 for ind2 in ind2_list if atNames[ind2] in self.backbone_type] |
146 | 135 |
|
147 | | - # contact of chain B |
148 | | - for resdata in resB: |
149 | | - chainID,resName,resSeq = resdata |
150 | | - if only_backbone_atoms: |
151 | | - index = self.get('rowID',chainID=chainID,resName=resName,resSeq=resSeq) |
152 | | - name = self.get('name',chainID=chainID,resName=resName,resSeq=resSeq) |
153 | | - index_contact_B += [ ind for ind,n in zip(index,name) if n in self.backbone_type ] |
154 | | - else: |
155 | | - index_contact_B += super().get('rowID',chainID=chainID,resName=resName,resSeq=resSeq) |
| 136 | + index_contact_pairs = tmp_dict |
156 | 137 |
|
157 | | - # make sure that we don't have double (maybe optional) |
158 | | - index_contact_A = sorted(set(index_contact_A)) |
159 | | - index_contact_B = sorted(set(index_contact_B)) |
| 138 | + # not sure that's the best way of dealing with that |
| 139 | + if return_contact_pairs: |
| 140 | + return index_contact_pairs |
| 141 | + else: |
| 142 | + return index_contact |
160 | 143 |
|
161 | | - return index_contact_A,index_contact_B |
| 144 | + # extend the contact atoms to the residue |
| 145 | + def _extend_contact_to_residue(self,index1,only_backbone_atoms): |
162 | 146 |
|
| 147 | + # extract the data |
| 148 | + dataA = super().get('chainID,resName,resSeq',rowID=index1) |
| 149 | + #dataB = super().get('chainID,resName,resSeq',rowID=index2) |
163 | 150 |
|
164 | | - # get the contact residue |
165 | | - def get_contact_residues(self,cutoff=8.5,chain1='A',chain2='B',excludeH=False, |
166 | | - only_backbone_atoms=False,return_contact_pairs=False): |
| 151 | + # create tuple cause we want to hash through it |
| 152 | + dataA = list(map(lambda x: tuple(x),dataA)) |
| 153 | + #dataB = list(map(lambda x: tuple(x),dataB)) |
167 | 154 |
|
168 | | - # get the contact atoms |
169 | | - if return_contact_pairs: |
| 155 | + # extract uniques |
| 156 | + resA = list(set(dataA)) |
| 157 | + #resB = list(set(dataB)) |
170 | 158 |
|
171 | | - # declare the dict |
172 | | - residue_contact_pairs = {} |
| 159 | + # init the list |
| 160 | + index_contact_A = [] |
173 | 161 |
|
174 | | - # get the contact atom pairs |
175 | | - atom_pairs = self.get_contact_atoms(cutoff=cutoff,chain1=chain1,chain2=chain2, |
176 | | - only_backbone_atoms=only_backbone_atoms, |
177 | | - excludeH=excludeH, |
178 | | - return_contact_pairs=True) |
| 162 | + # contact of chain A |
| 163 | + for resdata in resA: |
| 164 | + chainID,resName,resSeq = resdata |
179 | 165 |
|
180 | | - # loop over the atom pair dict |
181 | | - for iat1,atoms2 in atom_pairs.items(): |
| 166 | + if only_backbone_atoms: |
| 167 | + index = super().get('rowID',chainID=chainID,resName=resName,resSeq=resSeq) |
| 168 | + name = super().get('name',chainID=chainID,resName=resName,resSeq=resSeq) |
| 169 | + index_contact_A += [ ind for ind,n in zip(index,name) if n in self.backbone_type ] |
| 170 | + else: |
| 171 | + index_contact_A += super().get('rowID',chainID=chainID,resName=resName,resSeq=resSeq) |
182 | 172 |
|
183 | | - # get the res info of the current atom |
184 | | - data1 = tuple(super().get('chainID,resSeq,resName',rowID=[iat1])[0]) |
| 173 | + # # contact of chain B |
| 174 | + # for resdata in resB: |
| 175 | + # chainID,resName,resSeq = resdata |
| 176 | + # if only_backbone_atoms: |
| 177 | + # index = self.get('rowID',chainID=chainID,resName=resName,resSeq=resSeq) |
| 178 | + # name = self.get('name',chainID=chainID,resName=resName,resSeq=resSeq) |
| 179 | + # index_contact_B += [ ind for ind,n in zip(index,name) if n in self.backbone_type ] |
| 180 | + # else: |
| 181 | + # index_contact_B += super().get('rowID',chainID=chainID,resName=resName,resSeq=resSeq) |
185 | 182 |
|
186 | | - # create a new entry in the dict if necessary |
187 | | - if data1 not in residue_contact_pairs: |
188 | | - residue_contact_pairs[data1] = set() |
| 183 | + # make sure that we don't have double (maybe optional) |
| 184 | + index_contact_A = sorted(set(index_contact_A)) |
| 185 | + #index_contact_B = sorted(set(index_contact_B)) |
189 | 186 |
|
190 | | - # get the res info of the atom in the other chain |
191 | | - data2 = super().get('chainID,resSeq,resName',rowID=atoms2) |
| 187 | + return index_contact_A #,index_contact_B |
192 | 188 |
|
193 | | - # store that in the dict without double |
194 | | - for resData in data2: |
195 | | - residue_contact_pairs[data1].add(tuple(resData)) |
196 | 189 |
|
197 | | - for resData in residue_contact_pairs.keys(): |
198 | | - residue_contact_pairs[resData] = sorted(residue_contact_pairs[resData]) |
| 190 | + # get the contact residue |
| 191 | + def get_contact_residues(self,cutoff=8.5,allchains=False,chain1='A',chain2='B',excludeH=False, |
| 192 | + only_backbone_atoms=False,return_contact_pairs=False): |
199 | 193 |
|
200 | | - return residue_contact_pairs |
| 194 | + # get the contact atoms |
| 195 | + if return_contact_pairs: |
201 | 196 |
|
202 | | - else: |
| 197 | + # declare the dict |
| 198 | + residue_contact_pairs = {} |
203 | 199 |
|
204 | | - # get the contact atoms |
205 | | - contact_atoms = self.get_contact_atoms(cutoff=cutoff,chain1=chain1,chain2=chain2,return_contact_pairs=False) |
| 200 | + # get the contact atom pairs |
| 201 | + atom_pairs = self.get_contact_atoms(cutoff=cutoff,allchains=allchains,chain1=chain1,chain2=chain2, |
| 202 | + only_backbone_atoms=only_backbone_atoms, |
| 203 | + excludeH=excludeH, |
| 204 | + return_contact_pairs=True) |
206 | 205 |
|
207 | | - # get the residue info |
208 | | - data1 = super().get('chainID,resSeq,resName',rowID=contact_atoms[0]) |
209 | | - data2 = super().get('chainID,resSeq,resName',rowID=contact_atoms[1]) |
| 206 | + # loop over the atom pair dict |
| 207 | + for iat1,atoms2 in atom_pairs.items(): |
210 | 208 |
|
211 | | - # take only unique |
212 | | - residue_contact_A = sorted(set([tuple(resData) for resData in data1])) |
213 | | - residue_contact_B = sorted(set([tuple(resData) for resData in data2])) |
| 209 | + # get the res info of the current atom |
| 210 | + data1 = tuple(super().get('chainID,resSeq,resName',rowID=[iat1])[0]) |
214 | 211 |
|
215 | | - return residue_contact_A,residue_contact_B |
| 212 | + # create a new entry in the dict if necessary |
| 213 | + if data1 not in residue_contact_pairs: |
| 214 | + residue_contact_pairs[data1] = set() |
| 215 | + |
| 216 | + # get the res info of the atom in the other chain |
| 217 | + data2 = super().get('chainID,resSeq,resName',rowID=atoms2) |
| 218 | + |
| 219 | + # store that in the dict without double |
| 220 | + for resData in data2: |
| 221 | + residue_contact_pairs[data1].add(tuple(resData)) |
| 222 | + |
| 223 | + for resData in residue_contact_pairs.keys(): |
| 224 | + residue_contact_pairs[resData] = sorted(residue_contact_pairs[resData]) |
| 225 | + |
| 226 | + return residue_contact_pairs |
| 227 | + |
| 228 | + else: |
| 229 | + |
| 230 | + # get the contact atoms |
| 231 | + contact_atoms = self.get_contact_atoms(cutoff=cutoff,allchains=allchains, |
| 232 | + chain1=chain1,chain2=chain2, |
| 233 | + return_contact_pairs=False) |
| 234 | + |
| 235 | + # get the residue info |
| 236 | + data = dict() |
| 237 | + residue_contact = dict() |
| 238 | + |
| 239 | + for chain in contact_atoms.keys(): |
| 240 | + data[chain] = super().get('chainID,resSeq,resName',rowID=contact_atoms[chain]) |
| 241 | + residue_contact[chain] = sorted(set([tuple(resData) for resData in data[chain]])) |
| 242 | + |
| 243 | + |
| 244 | + return residue_contact |
216 | 245 |
|
217 | 246 |
|
0 commit comments