Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions config/arena/gulf_of_mexico_LongTermAverageSource.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
casadi_cache_dict:
deg_around_x_t: 0.5
time_around_x_t: 172800. # 3600 * 24 * 5

platform_dict:
battery_cap_in_wh: 400.0
u_max_in_mps: 0.1
motor_efficiency: 1.0
solar_panel_size: 0.5
solar_efficiency: 0.2
drag_factor: 675.0
dt_in_s: 600.0

use_geographic_coordinate_system: True

spatial_boundary: null

ocean_dict:
hindcast:
field: 'OceanCurrents'
source: 'hindcast_files'
source_settings:
folder: "data/hindcast_test"
source: "HYCOM"
type: "hindcast"

forecast:
field: 'OceanCurrents'
source: 'longterm_average'
source_settings:
forecast:
field: 'OceanCurrents'
source: 'forecast_files'
source_settings:
folder: "data/forecast_test"
average:
field: 'OceanCurrents'
source: 'hindcast_files'
source_settings:
folder: "data/monthly_average"

solar_dict:
hindcast: null
forecast: null

seaweed_dict:
hindcast: null
forecast: null
8 changes: 3 additions & 5 deletions ocean_navigation_simulator/data_sources/DataSource.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,14 +170,12 @@ def array_subsetting_sanity_check(array: xr, x_interval: List[float], y_interval
# Step 2: Data partially not in the array check
if array.coords['lat'].data[0] > y_interval[0] or array.coords['lat'].data[-1] < y_interval[1]:
logger.warning(
f"Part of the y requested area is outside of file(file: [{array.coords['lat'].data[0]}, {array.coords['lat'].data[-1]}], requested: [{y_interval[0]}, {y_interval[1]}]).",
RuntimeWarning)
f"Part of the y requested area is outside of file(file: [{array.coords['lat'].data[0]}, {array.coords['lat'].data[-1]}], requested: [{y_interval[0]}, {y_interval[1]}]).")
if array.coords['lon'].data[0] > x_interval[0] or array.coords['lon'].data[-1] < x_interval[1]:
logger.warning(
f"Part of the x requested area is outside of file (file: [{array.coords['lon'].data[0]}, {array.coords['lon'].data[-1]}], requested: [{x_interval[0]}, {x_interval[1]}]).",
RuntimeWarning)
f"Part of the x requested area is outside of file (file: [{array.coords['lon'].data[0]}, {array.coords['lon'].data[-1]}], requested: [{x_interval[0]}, {x_interval[1]}]).")
if units.get_datetime_from_np64(array.coords['time'].data[-1]) < t_interval[1]:
logger.warning("The final time is not part of the subset.".format(t_interval[1]), RuntimeWarning)
logger.warning("The final time {} is not part of the subset.".format(t_interval[1]))

def plot_data_at_time_over_area(self, time: Union[datetime.datetime, float],
x_interval: List[float], y_interval: List[float],
Expand Down
4 changes: 3 additions & 1 deletion ocean_navigation_simulator/data_sources/OceanCurrentField.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from ocean_navigation_simulator.data_sources.DataField import DataField
from ocean_navigation_simulator.data_sources.OceanCurrentSource.OceanCurrentSource import OceanCurrentSource
from ocean_navigation_simulator.data_sources.OceanCurrentSource.OceanCurrentSource import \
HindcastFileSource, HindcastOpendapSource, ForecastFileSource
HindcastFileSource, HindcastOpendapSource, ForecastFileSource, LongTermAverageSource
import ocean_navigation_simulator.data_sources.OceanCurrentSource.AnalyticalOceanCurrents as AnalyticalSources


Expand Down Expand Up @@ -47,6 +47,8 @@ def instantiate_source_from_dict(source_dict: Dict) -> OceanCurrentSource:
return HindcastFileSource(source_dict)
elif source_dict['source'] == 'forecast_files':
return ForecastFileSource(source_dict)
elif source_dict['source'] == 'longterm_average':
return LongTermAverageSource(source_dict)
elif source_dict['source'] == 'analytical':
specific_analytical_current = getattr(AnalyticalSources, source_dict['source_settings']['name'])
return specific_analytical_current(source_dict)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import datetime
from multiprocessing.sharedctypes import Value
import os
from typing import List, AnyStr, Optional, Union
import logging
from calendar import month

import casadi as ca
import dask.array.core
Expand Down Expand Up @@ -44,7 +46,6 @@ def initialize_casadi_functions(self, grid: List[List[float]], array: xr) -> Non
grid: list of the 3 grids [time, y_grid, x_grid] for the xr data
array: xarray object containing the sub-setted data for the next cached round
"""

self.u_curr_func = ca.interpolant('u_curr', 'linear', grid, array['water_u'].values.ravel(order='F'))
self.v_curr_func = ca.interpolant('v_curr', 'linear', grid, array['water_v'].values.ravel(order='F'))

Expand Down Expand Up @@ -185,7 +186,7 @@ def get_data_at_point(self, spatio_temporal_point: SpatioTemporalPoint) -> Ocean

class ForecastFileSource(OceanCurrentSourceXarray):
# TODO: Make it work with multiple files for one forecast (a bit of extra logic, but possible)
"""Data Source Object that accesses and manages multiple daily HYCOM files as source."""
"""Data Source Object that accesses and manages multiple daily files as source."""

def __init__(self, source_config_dict: dict):
super().__init__(source_config_dict)
Expand All @@ -209,6 +210,18 @@ def get_data_over_area(self, x_interval: List[float], y_interval: List[float],
spatial_resolution: Optional[float] = None,
temporal_resolution: Optional[float] = None,
most_recent_fmrc_at_time: Optional[datetime.datetime] = None) -> xr:
"""Function to get the the raw current data over an x, y, and t interval.
Args:
x_interval: List of the lower and upper x area in the respective coordinate units [x_lower, x_upper]
y_interval: List of the lower and upper y area in the respective coordinate units [y_lower, y_upper]
t_interval: List of the lower and upper datetime requested [t_0, t_T] in datetime
spatial_resolution: spatial resolution in the same units as x and y interval
temporal_resolution: temporal resolution in seconds
most_recent_fmrc_at_time: if specified this is the idx of a specific forecast file to get data from it
otherwise the most recent fmrc available at t_interval[0] is used.
Returns:
data_array in xarray format that contains the grid and the values (land is NaN)
"""
# format to datetime object
if not isinstance(t_interval[0], datetime.datetime):
t_interval = [datetime.datetime.fromtimestamp(time, tz=datetime.timezone.utc) for time in t_interval]
Expand Down Expand Up @@ -266,7 +279,7 @@ def get_data_at_point(self, spatio_temporal_point: SpatioTemporalPoint) -> Ocean


class HindcastFileSource(OceanCurrentSourceXarray):
"""Data Source Object that accesses and manages one or many HYCOM files as source."""
"""Data Source Object that accesses and manages one or many daily files as source."""

def __init__(self, source_config_dict: dict):
super().__init__(source_config_dict)
Expand All @@ -289,6 +302,7 @@ def get_data_at_point(self, spatio_temporal_point: SpatioTemporalPoint) -> Ocean


class HindcastOpendapSource(OceanCurrentSourceXarray):
"""Data Source Object that accesses the data via the opendap framework directly from the server."""
def __init__(self, source_config_dict: dict):
super().__init__(source_config_dict)
# Step 1: establish the opendap connection with the settings in config_dict
Expand All @@ -310,14 +324,56 @@ def get_data_at_point(self, spatio_temporal_point: SpatioTemporalPoint) -> Ocean
v=self.v_curr_func(spatio_temporal_point.to_spatio_temporal_casadi_input()))


class LongTermAverageSource(OceanCurrentSource):
""""""
def __init__(self, source_config_dict: dict):
self.u_curr_func, self.v_curr_func = [None] * 2
self.forecast_data_source = ForecastFileSource(source_config_dict['source_settings']['forecast'])
self.monthly_avg_data_source = HindcastFileSource(source_config_dict['source_settings']['average']) # defaults currents to normal
self.source_config_dict = source_config_dict
# self.t_0 = source_config_dict['t0'] # not sure what to do here

def get_data_over_area(self, x_interval: List[float], y_interval: List[float],
t_interval: List[Union[datetime.datetime, int]],
spatial_resolution: Optional[float] = None,
temporal_resolution: Optional[float] = None) -> xr:
# Query as much forecast data as is possible
try:
forecast_dataframe = self.forecast_data_source.get_data_over_area(x_interval, y_interval, t_interval, spatial_resolution, temporal_resolution)
end_forecast_time = get_datetime_from_np64(forecast_dataframe["time"].to_numpy()[-1])
except ValueError:
monthly_average_dataframe = self.monthly_avg_data_source.get_data_over_area(x_interval, y_interval, t_interval, spatial_resolution, temporal_resolution)
return monthly_average_dataframe


if end_forecast_time >= t_interval[1]:
return forecast_dataframe

remaining_t_interval = [end_forecast_time, t_interval[1]] # may not work
monthly_average_dataframe = self.monthly_avg_data_source.get_data_over_area(x_interval, y_interval, remaining_t_interval, spatial_resolution, temporal_resolution)
return xr.concat([forecast_dataframe, monthly_average_dataframe], dim="time")

def check_for_most_recent_fmrc_dataframe(self, time: datetime.datetime) -> int:
"""Helper function to check update the self.OceanCurrent if a new forecast is available at
the specified input time.
Args:
time: datetime object
"""
return self.forecast_data_source.check_for_most_recent_fmrc_dataframe(time)

# Not sure if I can just all this
def get_data_at_point(self, spatio_temporal_point: SpatioTemporalPoint) -> OceanCurrentVector:
"""We overwrite it because we don't want that Forecast needs caching..."""
return self.forecast_data_source.get_data_at_point(spatio_temporal_point==spatio_temporal_point)

# Helper functions across the OceanCurrentSource objects
def get_file_dicts(folder: AnyStr, currents='normal') -> List[dict]:
""" Creates an list of dicts ordered according to time available, one for each nc file available in folder.
The dicts for each file contains:
{'t_range': [<datetime object>, T], 'file': <filepath> ,'y_range': [min_lat, max_lat], 'x_range': [min_lon, max_lon]}
"""
# get a list of files from the folder
files_list = [folder + f for f in os.listdir(folder) if
files_list = [folder + '/' + f for f in os.listdir(folder) if
(os.path.isfile(os.path.join(folder, f)) and f != '.DS_Store')]

# iterate over all files to extract the grids and put them in an ordered list of dicts
Expand Down
110 changes: 110 additions & 0 deletions scripts/nisha/avg_data_source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import datetime
import numpy as np
from tqdm import tqdm

from ocean_navigation_simulator.environment.ArenaFactory import ArenaFactory
from ocean_navigation_simulator.environment.Platform import PlatformState
from ocean_navigation_simulator.environment.PlatformState import SpatialPoint
from ocean_navigation_simulator.environment.NavigationProblem import NavigationProblem
from ocean_navigation_simulator.controllers.hj_planners.HJReach2DPlanner import HJReach2DPlanner
from ocean_navigation_simulator.utils import units
import matplotlib.pyplot as plt

# Initialize the Arena (holds all data sources and the platform, everything except controller)
arena = ArenaFactory.create(scenario_name='gulf_of_mexico_LongTermAverageSource')
# we can also download the respective files directly to a temp folder, then t_interval needs to be set
# % Specify Navigation Problem
x_0 = PlatformState(lon=units.Distance(deg=-82.5), lat=units.Distance(deg=23.7),
date_time=datetime.datetime(2021, 11, 24, 12, 0, tzinfo=datetime.timezone.utc))
x_T = SpatialPoint(lon=units.Distance(deg=-80.3), lat=units.Distance(deg=24.6))

problem = NavigationProblem(
start_state=x_0,
end_region=x_T,
target_radius=0.1,
timeout=datetime.timedelta(days=2),
platform_dict=arena.platform.platform_dict)

# %% Plot the problem on the map
t_interval, lat_bnds, lon_bnds = arena.ocean_field.hindcast_data_source.convert_to_x_y_time_bounds(
x_0=x_0.to_spatio_temporal_point(),
x_T=x_T, deg_around_x0_xT_box=1,
temp_horizon_in_s=3600)

# x_0.date_time = 2021-11-24 12:00:00+00:00
ax = arena.ocean_field.hindcast_data_source.plot_data_at_time_over_area(
time=x_0.date_time, x_interval=lon_bnds, y_interval=lat_bnds, return_ax=True)
problem.plot(ax=ax)
plt.show()
# %% Try to do the same as above for the longterm average
print("Longterm Average")
t_interval, lat_bnds, lon_bnds = arena.ocean_field.forecast_data_source.convert_to_x_y_time_bounds(
x_0=x_0.to_spatio_temporal_point(),
x_T=x_T, deg_around_x0_xT_box=1,
temp_horizon_in_s=3600)

# # x_0.date_time = 2021-11-24 12:00:00+00:00 -> should just plot the forecast stuff
time_forecast = datetime.datetime(2021, 11, 24, 12, 0, tzinfo=datetime.timezone.utc)
ax = arena.ocean_field.forecast_data_source.plot_data_at_time_over_area(
time=time_forecast, x_interval=lon_bnds, y_interval=lat_bnds, return_ax=True)
problem.plot(ax=ax)
plt.show()

time_average = datetime.datetime(2021, 12, 1, 12, 0, tzinfo=datetime.timezone.utc)
ax = arena.ocean_field.forecast_data_source.plot_data_at_time_over_area(
time=time_average, x_interval=lon_bnds, y_interval=lat_bnds, return_ax=True)
problem.plot(ax=ax)
plt.show()

d = arena.ocean_field.forecast_data_source.get_data_over_area(x_interval=lon_bnds, y_interval=lat_bnds, t_interval=[time_forecast, time_average])
print(d)
# %% Instantiate the HJ Planner
# specific_settings = {
# 'replan_on_new_fmrc': True,
# 'replan_every_X_seconds': False,
# 'direction': 'multi-time-reach-back',
# 'n_time_vector': 200,
# # Note that this is the number of time-intervals, the vector is +1 longer because of init_time
# 'deg_around_xt_xT_box': 1., # area over which to run HJ_reachability
# 'accuracy': 'high',
# 'artificial_dissipation_scheme': 'local_local',
# 'T_goal_in_seconds': 3600 * 24 * 5,
# 'use_geographic_coordinate_system': True,
# 'progress_bar': True,
# 'initial_set_radii': [0.1, 0.1], # this is in deg lat, lon. Note: for Backwards-Reachability this should be bigger.
# # Note: grid_res should always be bigger than initial_set_radii, otherwise reachability behaves weirdly.
# 'grid_res': 0.02, # Note: this is in deg lat, lon (HYCOM Global is 0.083 and Mexico 0.04)
# 'd_max': 0.0,
# # 'EVM_threshold': 0.3 # in m/s error when floating in forecasted vs sensed currents
# # 'fwd_back_buffer_in_seconds': 0.5, # this is the time added to the earliest_to_reach as buffer for forward-backward
# 'platform_dict': arena.platform.platform_dict
# }
# planner = HJReach2DPlanner(problem=problem, specific_settings=specific_settings)

# # % Run reachability planner
# observation = arena.reset(platform_state=x_0)
# action = planner.get_action(observation=observation)
# # %% Various plotting of the reachability computations
# planner.plot_reachability_snapshot(rel_time_in_seconds=0, granularity_in_h=5, alpha_color=1, time_to_reach=True, fig_size_inches=(12, 12), plot_in_h=True)
# # planner.plot_reachability_snapshot_over_currents(rel_time_in_seconds=0, granularity_in_h=5, time_to_reach=False)
# # planner.plot_reachability_animation(time_to_reach=True, granularity_in_h=5, filename="test_reach_animation.mp4")
# # planner.plot_reachability_animation(time_to_reach=True, granularity_in_h=5, with_opt_ctrl=True,
# # filename="test_reach_animation_w_ctrl.mp4", forward_time=True)
# #%% save planner state and reload it
# # Save it to a folder
# # planner.save_planner_state("saved_planner/")
# # Load it from the folder
# # loaded_planner = HJReach2DPlanner.from_saved_planner_state(folder="saved_planner/", problem=problem)
# # loaded_planner.last_data_source = arena.ocean_field.hindcast_data_source
# # observation = arena.reset(platform_state=x_0)
# # loaded_planner._update_current_data(observation=observation)
# # planner = loaded_planner
# # %% Let the controller run closed-loop within the arena (the simulation loop)
# for i in tqdm(range(int(3600 * 24 * 3 / 600))): # 3 days
# action = planner.get_action(observation=observation)
# observation = arena.step(action)

# #%% Plot the arena trajectory on the map
# arena.plot_all_on_map(problem=problem)
# #%% Animate the trajectory
# arena.animate_trajectory(problem=problem, temporal_resolution=7200)
2 changes: 1 addition & 1 deletion scripts/tutorial/controller/hj_planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,4 +84,4 @@
#%% Plot the arena trajectory on the map
arena.plot_all_on_map(problem=problem)
#%% Animate the trajectory
arena.animate_trajectory(problem=problem, temporal_resolution=7200)
arena.animate_trajectory(problem=problem, temporal_resolution=7200)
Loading