Skip to content

Commit

Permalink
Add task to create other illustrations
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens committed Jan 4, 2024
1 parent 8cbac95 commit dac98cf
Show file tree
Hide file tree
Showing 3 changed files with 272 additions and 18 deletions.
245 changes: 239 additions & 6 deletions src/tranquilo_dev/plotting/illustrations.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from copy import deepcopy

import estimagic as em
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from tranquilo.visualize import _clean_legend_duplicates
from tranquilo.visualize import _get_sample_points


def create_noise_plots():
out = {}
out = []

def func(x):
return x**4 + x**2 - x
Expand Down Expand Up @@ -57,7 +59,7 @@ def model(x, beta):
"orientation": "h",
}
)
out[0] = fig # fig.write_image("noise_plot_0.svg")
out.append(fig) # fig.write_image("noise_plot_0.svg")

fig1 = deepcopy(fig)

Expand All @@ -71,10 +73,10 @@ def model(x, beta):
showlegend=False,
)
)
out[1] = fig1 # fig1.write_image("noise_plot_1.svg")
out.append(fig1) # fig1.write_image("noise_plot_1.svg")

plotting_data = []
for i, trustregion in enumerate([(-1.75, -0.5), (-0.75, 0.5)]):
for _, trustregion in enumerate([(-1.75, -0.5), (-0.75, 0.5)]):
in_region = (trustregion[0] <= noise_x_grid) & (noise_x_grid <= trustregion[1])
xs = noise_x_grid.copy()
ys = noise_y.copy()
Expand Down Expand Up @@ -126,7 +128,7 @@ def model(x, beta):
fig2.update_layout(width=600)

plotting_data.append(fig2.data)
out[2 + i] = fig2 # fig2.write_image(f'noise_plot_{i+2}.svg')
out.append(fig2) # fig2.write_image(f'noise_plot_{i+2}.svg')

len(plotting_data)
fig3 = make_subplots(cols=2, rows=1)
Expand All @@ -151,5 +153,236 @@ def model(x, beta):
linecolor="black",
zeroline=False,
)
out[4] = fig3
out.append(fig3)
return out


def create_other_illustration_plots():
out = {}

def sphere(x):
return {"root_contributions": x}

res = em.minimize(
sphere,
params=np.array([1, 1]),
algorithm="tranquilo_ls",
)
params_history = np.array(res.history["params"])
states = res.algorithm_output["states"]
color_dict = {
"existing": "rgb(0,0,255)",
"new": "rgb(230,0,0)",
"discarded": "rgb(148, 148, 148)",
}

for k in range(6):
out[f"animation_{k}.svg"] = _plot_sample_points(
params_history, states, k, color_dict
)

base_fig = _plot_sample_points(params_history, states, 1, color_dict)
fig = deepcopy(base_fig)
center = np.array([fig.data[0].x[0], fig.data[0].y[0]])
candidate = np.array([fig.data[-1].x[0], fig.data[-1].y[0]])
direction = candidate - center
line_points = []
for i in [2, 4, 8]:
line_points.append(center + i * direction)
line_points = np.array(line_points)
fig.data[1].marker.color = fig.data[0].marker.color
fig.data[1].showlegend = False
fig.add_trace(
go.Scatter(
x=line_points[:, 0],
y=line_points[:, 1],
mode="markers",
marker_color="#805B87",
marker_size=9,
name="line search",
)
)
fig.update_xaxes(range=[0, 1.2])
fig.update_yaxes(range=[0, 1.2])
out["line_points_3.svg"] = deepcopy(fig) # fig.write_image("line_points_3.svg")

fig.data[3].x = fig.data[3].x[:-1]
fig.data[3].y = fig.data[3].y[:-1]
out["line_points_2.svg"] = deepcopy(fig) # fig.write_image("line_points_2.svg")

fig.data[3].x = fig.data[3].x[:-1]
fig.data[3].y = fig.data[3].y[:-1]
out["line_points_1.svg"] = deepcopy(fig) # fig.write_image("line_points_1.svg")

fig.data = fig.data[:-1]
out["origin_plot.svg"] = deepcopy(fig) # fig.write_image("origin_plot.svg")

center = candidate
radius = states[1].trustregion.radius
fig.add_shape(
type="circle",
xref="x",
yref="y",
x0=center[0] - radius,
y0=center[1] - radius,
x1=center[0] + radius,
y1=center[1] + radius,
line_width=0.5,
line_color="black",
line_dash="dash",
)
out["empty_speculative_trustregion_small_scale.svg"] = deepcopy(fig)

fig.update_yaxes(range=[0.5, 1.2])
fig.update_xaxes(range=[0.5, 1.2])
out["empty_speculative_trustregion_large_scale.svg"] = deepcopy(fig)

fig.update_yaxes(range=[0, 1.2])
fig.update_xaxes(range=[0, 1.2])

angle1 = 30
angle2 = 55
speculative1 = np.zeros((2, 2))
speculative1[0, 0] = candidate[0] - np.cos(angle1 * np.pi / 180) * radius
speculative1[0, 1] = candidate[1] - np.sin(angle1 * np.pi / 180) * radius
speculative1[1, 0] = candidate[0] + np.sin(angle2 * np.pi / 180) * radius
speculative1[1, 1] = candidate[1] - np.cos(angle2 * np.pi / 180) * radius
fig.data = fig.data[:3]
fig.add_trace(
go.Scatter(
x=speculative1[:, 0],
y=speculative1[:, 1],
mode="markers",
marker_color="#F98900",
marker_size=9,
name="speculative",
)
)
out["sampled_speculative_trustregion_small_scale.svg"] = deepcopy(fig)

fig.update_yaxes(range=[0.5, 1.2])
fig.update_xaxes(range=[0.5, 1.2])
out["sampled_speculative_trustregion_large_scale.svg"] = deepcopy(fig)

fig.update_yaxes(range=[0, 1.2])
fig.update_xaxes(range=[0, 1.2])
angle1 = 15
angle2 = 78
speculative1[0, 0] = candidate[0] + np.sin(angle1 * np.pi / 180) * radius
speculative1[0, 1] = candidate[1] - np.cos(angle1 * np.pi / 180) * radius
speculative1[1, 0] = candidate[0] + np.sin(angle2 * np.pi / 180) * radius
speculative1[1, 1] = candidate[1] - np.cos(angle2 * np.pi / 180) * radius

fig.data = fig.data[:3]
fig.add_trace(
go.Scatter(
x=speculative1[:, 0],
y=speculative1[:, 1],
mode="markers",
marker_color="#F98900",
marker_size=9,
name="speculative",
)
)
fig.add_trace(
go.Scatter(
x=line_points[:, 0],
y=line_points[:, 1],
mode="markers",
marker_color="#805B87",
marker_size=9,
name="line search",
)
)

out["line_and_speculative_points.svg"] = deepcopy(fig)
fig.update_yaxes(scaleanchor="x", scaleratio=1)
fig.update_xaxes(scaleanchor="y", scaleratio=1)

fig.update_layout(height=700, width=800, template="plotly_white")
out["checklayout.svg"] = deepcopy(fig)

return out


# ======================================================================================
# Shared functions
# ======================================================================================


def _plot_sample_points(history, states, iteration, color_dict):

state = states[iteration]
sample_points = _get_sample_points(state, history)
trustregion = state.trustregion
radius = trustregion.radius
center = trustregion.center
fig = go.Figure()
fig.add_shape(
type="circle",
xref="x",
yref="y",
x0=center[0] - radius,
y0=center[1] - radius,
x1=center[0] + radius,
y1=center[1] + radius,
line_width=0.5,
line_color="grey",
)

fig.add_traces(
px.scatter(
sample_points,
x=0,
y=1,
color="case",
color_discrete_map=color_dict,
opacity=0.7,
).data,
)
if iteration >= 1:
fig.add_trace(
go.Scatter(
x=[state.candidate_x[0]],
y=[state.candidate_x[1]],
mode="markers",
marker_symbol="star",
marker_color="#228b22",
name="candidate",
),
)
fig.update_traces(
marker_size=9,
)

fig = _clean_legend_duplicates(fig)
fig.update_layout(height=700, width=800, template="plotly_white")
fig.update_yaxes(
showgrid=False,
showline=True,
linewidth=1,
linecolor="black",
zeroline=False,
range=[-1.6, 1.6],
)
fig.update_xaxes(
showgrid=False,
showline=True,
linewidth=1,
linecolor="black",
zeroline=False,
range=[-1.6, 1.6],
)
fig.update_yaxes(scaleanchor="x", scaleratio=1)
fig.update_xaxes(scaleanchor="y", scaleratio=1)

layout = go.Layout(
margin=go.layout.Margin(
l=10, # left margin
r=10, # right margin
b=10, # bottom margin
t=10, # top margin
)
)
fig = fig.update_layout(layout)
return fig
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import pytask
from tranquilo_dev.config import BLD
from tranquilo_dev.plotting.illustrations import create_noise_plots
from tranquilo_dev.plotting.illustrations import create_other_illustration_plots

BLD_SLIDEV = BLD.joinpath("bld_slidev")


@pytask.mark.produces({k: BLD_SLIDEV.joinpath(f"noise_plot_{k}.svg") for k in range(5)})
def task_create_noise_plots(produces):
figures = create_noise_plots()
for path, fig in zip(produces.values(), figures):
fig.write_image(path)


OTHER_ILLUSTRATION_NAMES = [
*[f"animation_{k}.svg" for k in range(6)],
*[f"line_points_{k}.svg" for k in range(1, 4)],
"origin_plot.svg",
"empty_speculative_trustregion_small_scale.svg",
"empty_speculative_trustregion_large_scale.svg",
"sampled_speculative_trustregion_small_scale.svg",
"sampled_speculative_trustregion_large_scale.svg",
"line_and_speculative_points.svg",
"checklayout.svg",
]


@pytask.mark.produces([BLD_SLIDEV.joinpath(name) for name in OTHER_ILLUSTRATION_NAMES])
def task_create_other_illustration_plots(produces):
figures = create_other_illustration_plots()
for path in produces.values():
figures[path.name].write_image(path)

This file was deleted.

0 comments on commit dac98cf

Please sign in to comment.