Skip to content

Commit

Permalink
Change inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobbieker committed Jan 24, 2024
1 parent f93f39c commit 595c1df
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions pvnet/models/multimodal/multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,10 @@ def forward(self, x):
modes["pv"] = self.pv_encoder(x)
else:
# Target is PV, so only take the history
x[BatchKey.pv] = x[BatchKey.pv][:, : self.history_len_30]
modes["pv"] = self.pv_encoder(x)
# Copy batch
x_tmp = x.copy()
x_tmp[BatchKey.pv] = x_tmp[BatchKey.pv][:, : self.history_len_30]
modes["pv"] = self.pv_encoder(x_tmp)

# *********************** GSP Data ************************************
# add gsp yield history
Expand All @@ -283,9 +285,10 @@ def forward(self, x):
modes["wind"] = self.wind_encoder(x)
else:
# Have to be its own Batch format
x[BatchKey.wind] = x[BatchKey.wind][:, : self.history_len_30]
x_tmp = x.copy()
x_tmp[BatchKey.wind] = x_tmp[BatchKey.wind][:, : self.history_len_30]
# This needs to be a Batch as input
modes["wind"] = self.wind_encoder(x)
modes["wind"] = self.wind_encoder(x_tmp)

if self.include_sun:
sun = torch.cat(
Expand Down

0 comments on commit 595c1df

Please sign in to comment.