Skip to content

Commit

Permalink
minor refactor
Browse files Browse the repository at this point in the history
1. better way to deal with zero template scenario
2. removed some deprecated files
  • Loading branch information
FreshAirTonight committed Mar 3, 2022
1 parent 77779a5 commit cc9e1a2
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 1,508 deletions.
121 changes: 57 additions & 64 deletions src/alphafold/data/complex.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ def initialize_template_feats(num_templates_, num_res_, is_multimer=False):
'template_aatype': np.zeros([num_templates_, num_res_, 22], np.float32),
'template_all_atom_masks': np.zeros([num_templates_, num_res_, 37], np.float32),
'template_all_atom_positions': np.zeros([num_templates_, num_res_, 37, 3], np.float32),
'template_domain_names': np.empty([num_templates_], dtype=object),
'template_sequence': np.empty([num_templates_], dtype=object),
'template_domain_names': np.empty([num_templates_], dtype=str),
'template_sequence': np.empty([num_templates_], dtype=str),
'template_sum_probs': np.zeros([num_templates_,1], np.float32),
}

Expand Down Expand Up @@ -100,6 +100,10 @@ def load_monomer_feature(target, flags):
mono_feature_dict["template_sequence"] = mono_feature_dict["template_sequence"][:flags.max_template_hits]
mono_feature_dict["template_sum_probs"] = mono_feature_dict["template_sum_probs"][:flags.max_template_hits,:]

if T == 0 or flags.no_template: # deal with senario no template found, or set it to a null template if requested
mono_template_features = initialize_template_feats(1, L, is_multimer=False)
mono_feature_dict.update(mono_template_features)

if is_multimer_np:
for i in range(monomer["copy_number"]):
f_dict = pipeline_multimer.convert_monomer_features_af2complex(
Expand Down Expand Up @@ -418,9 +422,9 @@ def extract_domain_mono(mono_entry):
mono_msa = np.concatenate((mono_msa, msa[:,sta:end]), axis=1)
mono_mtx = np.concatenate((mono_mtx, mtx[:,sta:end]), axis=1)
mono_res = np.concatenate((mono_res, resid[sta:end]), axis=0)
all_gap_rows = ~np.all(mono_msa == 21, axis=1) ## remove ones with gaps only
features['msa']= mono_msa[all_gap_rows]
features['deletion_matrix_int'] = mono_mtx[all_gap_rows]
not_all_gap_rows = ~np.all(mono_msa == 21, axis=1) ## remove ones with gaps only
features['msa']= mono_msa[not_all_gap_rows]
features['deletion_matrix_int'] = mono_mtx[not_all_gap_rows]
mono_entry['feature_dict'] = features
mono_sequence = mono_seq
features['residue_index'] = mono_res
Expand Down Expand Up @@ -473,6 +477,7 @@ def extract_template_domain_mult(mono_entry):
copy_num = mono_entry['copy_number']
dom_range = mono_entry['domain_range']
mono_sequence = features['sequence'][0].decode()

if dom_range is not None:
mono_seq = ''
for idx, boundary in enumerate(dom_range):
Expand Down Expand Up @@ -649,34 +654,28 @@ def template_cropping_and_joining_mono(curr_input):
monomers = curr_input['monomers']
new_feature_dict = curr_input['new_feature_dict']

new_tem = initialize_template_feats(full_num_tem, full_num_res, False)
col = 0; row = 0
if not flags.no_template:
new_tem = initialize_template_feats(full_num_tem, full_num_res, False)
for mono_entry in monomers:
features = extract_template_domain_mono(mono_entry)

copy_num = mono_entry['copy_number']
#num_res = features['template_aatype'].shape[1]
num_res = features['msa'].shape[1]
num_tem = len(features['template_domain_names'])

if num_tem == 0:
dom_fea = initialize_template_feats(1, num_res, is_multimer=False)
else:
dom_fea = features

for i in range(copy_num):
col_ = col + num_res
row_ = row + num_tem

new_tem['template_all_atom_positions'][row:row_,col:col_,...] = dom_fea['template_all_atom_positions']
new_tem['template_domain_names'][row:row_] = dom_fea['template_domain_names']
new_tem['template_sequence'][row:row_] = dom_fea['template_sequence']
new_tem['template_sum_probs'][row:row_] = dom_fea['template_sum_probs']
new_tem['template_aatype'][row:row_,col:col_,:] = dom_fea['template_aatype']
new_tem['template_all_atom_masks'][row:row_,col:col_,:] = dom_fea['template_all_atom_masks']
col = col_; row = row_
new_feature_dict.update(new_tem)
for mono_entry in monomers:
features = extract_template_domain_mono(mono_entry)

copy_num = mono_entry['copy_number']
#num_res = features['template_aatype'].shape[1]
num_res = features['msa'].shape[1]
num_tem = len(features['template_domain_names'])

for i in range(copy_num):
col_ = col + num_res
row_ = row + num_tem

new_tem['template_all_atom_positions'][row:row_,col:col_,...] = features['template_all_atom_positions']
new_tem['template_domain_names'][row:row_] = features['template_domain_names']
new_tem['template_sequence'][row:row_] = features['template_sequence']
new_tem['template_sum_probs'][row:row_] = features['template_sum_probs']
new_tem['template_aatype'][row:row_,col:col_,:] = features['template_aatype']
new_tem['template_all_atom_masks'][row:row_,col:col_,:] = features['template_all_atom_masks']
col = col_; row = row_
new_feature_dict.update(new_tem)

curr_input.update({'new_feature_dict': new_feature_dict})
return curr_input
Expand All @@ -702,39 +701,33 @@ def template_cropping_and_joining_mult(curr_input):
new_feature_dict = curr_input['new_feature_dict']

col = 0; row = 0
if not flags.no_template:
new_tem = initialize_template_feats(full_num_tem, full_num_res, is_multimer=True)
for mono_entry in monomers:
features = extract_template_domain_mult(mono_entry)
copy_num = mono_entry['copy_number']

#num_res = features['template_aatype'].shape[1]
num_res = features['msa'].shape[1]
num_tem = len(features['template_domain_names'])
if num_tem == 0:
dom_fea = initialize_template_feats(1, num_res, is_multimer=True)
dom_fea['asym_id'] = features['asym_id']
dom_fea['sym_id'] = features['sym_id']
dom_fea['entity_id'] = features['entity_id']
else:
dom_fea = features

for i in range(copy_num):
col_ = col + num_res
row_ = row + num_tem

new_tem['template_all_atom_positions'][row:row_,col:col_,...] = dom_fea['template_all_atom_positions']
new_tem['template_domain_names'][row:row_] = dom_fea['template_domain_names']
new_tem['template_sequence'][row:row_] = dom_fea['template_sequence']
new_tem['template_sum_probs'][row:row_] = dom_fea['template_sum_probs']
new_tem['template_all_atom_mask'][row:row_,col:col_,:] = dom_fea['template_all_atom_mask']
new_tem['template_aatype'][row:row_,col:col_] = dom_fea['template_aatype']
new_tem['asym_id'][col:col_] = dom_fea['asym_id']
new_tem['sym_id'][col:col_] = dom_fea['sym_id']
new_tem['entity_id'][col:col_] = dom_fea['entity_id']

col = col_; row = row_
new_feature_dict.update(new_tem)
new_tem = initialize_template_feats(full_num_tem, full_num_res, is_multimer=True)
for mono_entry in monomers:
features = extract_template_domain_mult(mono_entry)
copy_num = mono_entry['copy_number']

#num_res = features['template_aatype'].shape[1]
num_res = features['msa'].shape[1]
num_tem = len(features['template_domain_names'])

for i in range(copy_num):
col_ = col + num_res
row_ = row + num_tem

new_tem['template_all_atom_positions'][row:row_,col:col_,...] = features['template_all_atom_positions']
new_tem['template_domain_names'][row:row_] = features['template_domain_names']
new_tem['template_sequence'][row:row_] = features['template_sequence']
new_tem['template_sum_probs'][row:row_] = features['template_sum_probs']
new_tem['template_all_atom_mask'][row:row_,col:col_,:] = features['template_all_atom_mask']
new_tem['template_aatype'][row:row_,col:col_] = features['template_aatype']
new_tem['asym_id'][col:col_] = features['asym_id']
new_tem['sym_id'][col:col_] = features['sym_id']
new_tem['entity_id'][col:col_] = features['entity_id']

col = col_; row = row_

new_feature_dict.update(new_tem)

curr_input.update({'new_feature_dict': new_feature_dict})
return curr_input
Loading

0 comments on commit cc9e1a2

Please sign in to comment.