Skip to content

Commit

Permalink
fix tram bug (#209)
Browse files Browse the repository at this point in the history
  • Loading branch information
MaaikeG authored Feb 16, 2022
1 parent dc320db commit de8768d
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 5 deletions.
2 changes: 1 addition & 1 deletion deeptime/markov/msm/tram/_tram.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def _make_tram_estimator(self, model, dataset):
# copy free energies along the markoc state axis to get initial biased_conf_energies
biased_conf_energies = np.repeat(free_energies[:, None], dataset.n_markov_states, axis=1)
else:
biased_conf_energies = np.zeros((dataset.n_markov_states, dataset.n_therm_states))
biased_conf_energies = np.zeros((dataset.n_therm_states, dataset.n_markov_states))

lagrangian_mult_log = tram.initialize_lagrangians(dataset.transition_counts)
modified_state_counts = np.zeros_like(lagrangian_mult_log) # intialize this as the dataset state counts???
Expand Down
2 changes: 1 addition & 1 deletion deeptime/src/include/deeptime/markov/msm/tram/mbar.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ initialize_MBAR(BiasMatrix <dtype> biasMatrix, CountsMatrix stateCounts, std::si
double maxErr = 1e-6, std::size_t callbackInterval = 1, const py::object *callback = nullptr) {
// get dimensions...
auto nThermStates = stateCounts.shape(0);
auto nSamples = biasMatrix.shape(1);
auto nSamples = biasMatrix.shape(0);

// work in log space so compute the log of the statecounts beforehand
auto stateCountsLog = std::vector<dtype>(nThermStates);
Expand Down
1 change: 1 addition & 0 deletions devtools/conda-setup+build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ steps:
conda config --set quiet true
displayName: Configure conda
- bash: |
conda clean --all
conda install mamba
mamba update --all
mamba install boa conda-build conda-verify pip
Expand Down
2 changes: 1 addition & 1 deletion examples/methods/plot_tram.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def sample_trajectories(bias_functions):
dtrajs = clustering.transform(trajectories.flatten()).reshape((len(bias_matrices), n_samples))

from tqdm import tqdm
tram = TRAM(lagtime=1, maxiter=1000, maxerr=1e-3, progress=tqdm)
tram = TRAM(lagtime=1, maxiter=1000, maxerr=1e-3, progress=tqdm, init_strategy="MBAR")

# For every simulation frame seen in trajectory i and time step t, btrajs[i][t,k] is the
# bias energy of that frame evaluated in the k'th thermodynamic state (i.e. at the k'th
Expand Down
17 changes: 15 additions & 2 deletions tests/markov/msm/test_tram.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,12 +265,24 @@ def test_tqdm_progress_bar():
tram.fit(make_random_input_data(5, 5))


def test_fit_with_dataset():
@pytest.mark.parametrize(
"init_strategy", ["MBAR", None]
)
def test_fit_with_dataset(init_strategy):
dataset = TRAMDataset(dtrajs=[np.asarray([0, 1, 2])], bias_matrices=[np.asarray([[1.], [2.], [3.]])])
tram = TRAM()
tram = TRAM(init_strategy=init_strategy)
tram.fit(dataset)


@pytest.mark.parametrize(
"init_strategy", ["MBAR", None]
)
def test_fit_with_dataset(init_strategy):
input_data = make_random_input_data(20, 2)
tram = TRAM(init_strategy=init_strategy)
tram.fit(input_data)


def test_mbar_initalization():
(dtrajs, bias_matrices) = make_random_input_data(5, 5, make_ttrajs=False)
tram = TRAM(callback_interval=2, maxiter=0, progress=tqdm, init_maxiter=100)
Expand All @@ -296,3 +308,4 @@ def test_mbar_initialization_zero_iterations():
model1 = tram1.fit_fetch(input_data)
model2 = tram2.fit_fetch(input_data)
np.testing.assert_equal(model1.biased_conf_energies, model2.biased_conf_energies)

0 comments on commit de8768d

Please sign in to comment.