diff --git a/anesthetic/plot.py b/anesthetic/plot.py index 561fa2da..05bedacf 100644 --- a/anesthetic/plot.py +++ b/anesthetic/plot.py @@ -394,8 +394,12 @@ def _set_logticks(self): if ax is not None: if x in self._logx: ax.xaxis.set_major_locator(LogLocator(numticks=3)) + if x != y: + ax.set_xlim(ax.dataLim.intervalx) if y in self._logy: ax.yaxis.set_major_locator(LogLocator(numticks=3)) + if y != x: + ax.set_ylim(ax.dataLim.intervaly) @staticmethod def _set_labels(axes, labels, **kwargs): diff --git a/anesthetic/samples.py b/anesthetic/samples.py index 605d5fbe..d0dee2b1 100644 --- a/anesthetic/samples.py +++ b/anesthetic/samples.py @@ -356,7 +356,7 @@ def plot_2d(self, axes=None, *args, **kwargs): if np.isinf(self[x]).any(): warnings.warn(f"column {y} has inf values.") selfxy = self[[x, y]] - selfxy = self.replace([-np.inf, np.inf], np.nan) + selfxy = selfxy.replace([-np.inf, np.inf], np.nan) selfxy = selfxy.dropna(axis=0) selfxy.plot(x, y, ax=ax, xlabel=xlabel, logx=x in logx, logy=y in logy, diff --git a/tests/test_samples.py b/tests/test_samples.py index bea15c4f..7addabe8 100644 --- a/tests/test_samples.py +++ b/tests/test_samples.py @@ -526,13 +526,16 @@ def test_plot_logscale_2d(kind): def test_logscale_ticks(): np.random.seed(42) ndim = 5 - data = np.exp(10 * np.random.randn(200, ndim)) + data1 = np.exp(10 * np.random.randn(200, ndim)) + data2 = np.exp(10 * np.random.randn(200, ndim) - 50) params = [f'a{i}' for i in range(ndim)] fig, axes = make_2d_axes(params, logx=params, logy=params, upper=False) - samples = Samples(data, columns=params) - samples.plot_2d(axes) - for _, col in axes.iterrows(): - for _, ax in col.items(): + samples1 = Samples(data1, columns=params) + samples2 = Samples(data2, columns=params) + samples1.plot_2d(axes) + samples2.plot_2d(axes) + for y, col in axes.iterrows(): + for x, ax in col.items(): if ax is not None: xlims = ax.get_xlim() xticks = ax.get_xticks() @@ -540,6 +543,14 @@ def test_logscale_ticks(): ylims = ax.get_ylim() yticks = ax.get_yticks() assert np.sum((yticks > ylims[0]) & (yticks < ylims[1])) > 1 + if x == y: + data_min = ax.twin.dataLim.intervalx[0] + data_max = ax.twin.dataLim.intervalx[1] + assert xlims[0] == pytest.approx(data_min, rel=1e-14) + assert xlims[1] == pytest.approx(data_max, rel=1e-14) + else: + assert_array_equal(xlims, ax.dataLim.intervalx) + assert_array_equal(ylims, ax.dataLim.intervaly) @pytest.mark.parametrize('k', ['hist_1d', 'hist'])