@@ -90,7 +90,7 @@ def get(self, s, df=None, invert=False):
9090
9191 Parameters
9292 ----------
93- s : str in {'main chain', 'hydrogen', 'c-alpha'}
93+ s : str in {'main chain', 'hydrogen', 'c-alpha', 'heavy' }
9494 String to specify which entries to return
9595
9696 df : pandas.DataFrame, default: None
@@ -115,8 +115,41 @@ def get(self, s, df=None, invert=False):
115115 df = self ._df ['ATOM' ]
116116 return self ._get_dict [s ](df , invert = invert )
117117
def impute_element(self, sections=('ATOM', 'HETATM'), inplace=False):
    """Impute element_symbol from the atom_name column.

    For every row of each requested section the element symbol is set
    to the first character of ``atom_name``, unless the row's current
    ``element_symbol`` entry is exactly three characters long, in which
    case the second character of ``atom_name`` is used.

    Parameters
    ----------
    sections : iterable (default: ('ATOM', 'HETATM'))
        Coordinate sections for which the element symbols should be
        imputed.

    inplace : bool (default: False)
        Performs the operation in-place if True and returns a copy of the
        PDB DataFrame otherwise.

    Returns
    -------
    dict-like
        The section-name -> DataFrame mapping with imputed element
        symbols: ``self.df`` itself if `inplace` is True, otherwise a
        copy whose DataFrames are copies as well.

    """
    if inplace:
        t = self.df
    else:
        # Copy the container and every DataFrame in it so that the
        # caller's data is left untouched.
        t = self.df.copy()
        for d in self.df:
            t[d] = self.df[d].copy()

    for sec in sections:
        frame = t[sec]
        # Build the column from the two source columns directly instead
        # of a row-wise ``apply``: this avoids positional indexing on a
        # label-indexed row and also works for sections without rows.
        # NOTE(review): the length test is on the current element_symbol
        # value, not on atom_name -- kept as is; confirm it is intended.
        frame['element_symbol'] = [
            name[1] if len(symbol) == 3 else name[0]
            for name, symbol in zip(frame['atom_name'],
                                    frame['element_symbol'])
        ]
    return t
150+
118151 @staticmethod
119- def rmsd (df1 , df2 , s = 'main chain' , invert = False ):
152+ def rmsd (df1 , df2 , s = None , invert = False ):
120153 """Compute the Root Mean Square Deviation between molecules.
121154
122155 Parameters
@@ -128,8 +161,10 @@ def rmsd(df1, df2, s='main chain', invert=False):
128161 Second DataFrame for RMSD computation against df1. Must have the
129162 same number of entries as df1
130163
131- s : str in {'main chain', 'hydrogen', 'c-alpha'}, default: 'main chain'
132- String to specify which entries to consider.
164+ s : {'main chain', 'hydrogen', 'c-alpha', 'heavy', 'carbon'} or None,
165+ default: None
166+ String to specify which entries to consider. If None, considers
167+ all atoms for comparison.
133168
134169 invert : bool, default: False
135170 Inverts the string query if true. For example, the setting
@@ -163,7 +198,9 @@ def _init_get_dict():
163198 """Initialize dictionary for filter operations."""
164199 get_dict = {'main chain' : PandasPDB ._get_mainchain ,
165200 'hydrogen' : PandasPDB ._get_hydrogen ,
166- 'c-alpha' : PandasPDB ._get_calpha }
201+ 'c-alpha' : PandasPDB ._get_calpha ,
202+ 'carbon' : PandasPDB ._get_carbon ,
203+ 'heavy' : PandasPDB ._get_heavy }
167204 return get_dict
168205
169206 @staticmethod
@@ -234,9 +271,17 @@ def _get_mainchain(df, invert):
234271 def _get_hydrogen (df , invert ):
235272 """Return only hydrogen atom entries from a DataFrame"""
236273 if invert :
237- return df [(df ['atom_name' ] != 'H' )]
274+ return df [(df ['element_symbol' ] != 'H' )]
275+ else :
276+ return df [(df ['element_symbol' ] == 'H' )]
277+
278+ @staticmethod
279+ def _get_heavy (df , invert ):
280+ """Return only heavy atom entries from a DataFrame"""
281+ if invert :
282+ return df [df ['element_symbol' ] == 'H' ]
238283 else :
239- return df [( df ['atom_name ' ] == 'H' ) ]
284+ return df [df ['element_symbol ' ] != 'H' ]
240285
241286 @staticmethod
242287 def _get_calpha (df , invert ):
@@ -246,6 +291,14 @@ def _get_calpha(df, invert):
246291 else :
247292 return df [df ['atom_name' ] == 'CA' ]
248293
294+ @staticmethod
295+ def _get_carbon (df , invert ):
296+ """Return c-alpha atom entries from a DataFrame"""
297+ if invert :
298+ return df [df ['element_symbol' ] == 'C' ]
299+ else :
300+ return df [df ['element_symbol' ] != 'C' ]
301+
249302 @staticmethod
250303 def _construct_df (pdb_lines ):
251304 """Construct DataFrames from list of PDB lines."""
@@ -256,7 +309,8 @@ def _construct_df(pdb_lines):
256309 if line .strip ():
257310 if line .startswith (valids ):
258311 record = line [:6 ].rstrip ()
259- line_ele = ['' for _ in range (len (pdb_records [record ])+ 1 )]
312+ line_ele = ['' for _ in range (len (
313+ pdb_records [record ]) + 1 )]
260314 for idx , ele in enumerate (pdb_records [record ]):
261315 line_ele [idx ] = (line [ele ['line' ][0 ]:ele ['line' ][1 ]]
262316 .strip ())
@@ -269,7 +323,7 @@ def _construct_df(pdb_lines):
269323 dfs = {}
270324 for r in line_lists .items ():
271325 df = pd .DataFrame (r [1 ], columns = [c ['id' ] for c in
272- pdb_records [r [0 ]]]+ ['line_idx' ])
326+ pdb_records [r [0 ]]] + ['line_idx' ])
273327 for c in pdb_records [r [0 ]]:
274328 try :
275329 df [c ['id' ]] = df [c ['id' ]].astype (c ['type' ])
0 commit comments