File tree Expand file tree Collapse file tree 2 files changed +28
-7
lines changed Expand file tree Collapse file tree 2 files changed +28
-7
lines changed Original file line number Diff line number Diff line change 1818
1919DEVICES = torch .device ("cpu" )
2020
21- model_weights = load_attention_model_weights ()
22-
2321
2422class GameState (IntEnum ):
2523 CooperateDefect = 2
@@ -354,13 +352,20 @@ def __init__(
354352 self ,
355353 ) -> None :
356354 super ().__init__ ()
357- self .model = PlayerModel (PlayerConfig ())
358- self .model .load_state_dict (model_weights )
359- self .model .to (DEVICES )
360- self .model .eval ()
355+ self .model = None
356+
357+ def load_model (self ) -> None :
358+ """Load the model weights."""
359+ if self .model is None :
360+ self .model = PlayerModel (PlayerConfig ())
361+ self .model .load_state_dict (load_attention_model_weights ())
362+ self .model .to (DEVICES )
363+ self .model .eval ()
361364
362365 def strategy (self , opponent : Player ) -> Action :
363366 """Actual strategy definition that determines player's action."""
367+ # Load the model if not already loaded
368+ self .load_model ()
364369 # Compute features
365370 features = compute_features (self , opponent ).unsqueeze (0 ).to (DEVICES )
366371
Original file line number Diff line number Diff line change 11"""Tests for the Attention strategies."""
22
33import unittest
4+ from unittest .mock import patch
45
56import torch
67
78import axelrod as axl
9+ from axelrod .load_data_ import load_attention_model_weights
810from axelrod .strategies .attention import (
911 MEMORY_LENGTH ,
1012 GameState ,
@@ -89,7 +91,21 @@ class TestEvolvedAttention(TestPlayer):
8991 def test_model_initialization (self ):
9092 """Test that the model is initialized correctly."""
9193 player = self .player ()
92- self .assertIsInstance (player .model , PlayerModel )
94+ self .assertIsNone (player .model )
95+
96+ def test_load_model (self ):
97+ """Test that the model can be loaded correctly."""
98+ with patch (
99+ "axelrod.strategies.attention.load_attention_model_weights" ,
100+ wraps = load_attention_model_weights ,
101+ ) as load_attention_model_weights_spy :
102+ player = self .player ()
103+ self .assertIsNone (player .model )
104+ player .load_model ()
105+ self .assertIsInstance (player .model , PlayerModel )
106+ player .load_model ()
107+ self .assertIsInstance (player .model , PlayerModel )
108+ load_attention_model_weights_spy .assert_called_once ()
93109
94110 def test_versus_cooperator (self ):
95111 actions = [(C , C )] * 5
You can’t perform that action at this time.
0 commit comments