Skip to content

Commit 5ba1402

Browse files
Sequential Module Train to Eval mode update (#1443)
* update train and eval methods * update unit test * update releasenotes * add comments for test and fix releasenotes
1 parent 3760ba3 commit 5ba1402

File tree

4 files changed

+46
-0
lines changed

4 files changed

+46
-0
lines changed

RELEASENOTES.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
## TorchSharp Release Notes
22

33
Releases, starting with 9/2/2021, are listed with the most recent release at the top.
4+
# NuGet Version 0.105.1
45

6+
__Bug Fixes__:
7+
8+
#1426 Sequential.eval() does not put model into eval mode<br/>
59
# NuGet Version 0.105.0
610

711
Move to libtorch 2.5.1. As with the 2.4.0 release, MacOS / Intel is no longer supported by libtorch, so TorchSharp doesn, either.

src/TorchSharp/NN/GenericSequential.cs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,7 @@ protected override void Dispose(bool disposing)
221221
/// </remarks>
222222
public override void train(bool on = true)
223223
{
224+
base.train(on);
224225
foreach (var m in _modules) { ((torch.nn.Module)m).train(on); }
225226
}
226227

@@ -232,6 +233,7 @@ public override void train(bool on = true)
232233
/// </remarks>
233234
public override void eval()
234235
{
236+
base.eval();
235237
foreach (var m in _modules) { ((torch.nn.Module)m).eval(); }
236238
}
237239

@@ -465,6 +467,7 @@ protected override void Dispose(bool disposing)
465467
/// </remarks>
466468
public override void train(bool on = true)
467469
{
470+
base.train(on);
468471
foreach (var m in _modules) { ((torch.nn.Module)m).train(on); }
469472
}
470473

@@ -476,6 +479,7 @@ public override void train(bool on = true)
476479
/// </remarks>
477480
public override void eval()
478481
{
482+
base.eval();
479483
foreach (var m in _modules) { ((torch.nn.Module)m).eval(); }
480484
}
481485

@@ -702,6 +706,7 @@ protected override void Dispose(bool disposing)
702706
/// </remarks>
703707
public override void train(bool on = true)
704708
{
709+
base.train(on);
705710
foreach (var m in _modules) { ((torch.nn.Module)m).train(on); }
706711
}
707712

@@ -713,6 +718,7 @@ public override void train(bool on = true)
713718
/// </remarks>
714719
public override void eval()
715720
{
721+
base.eval();
716722
foreach (var m in _modules) { ((torch.nn.Module)m).eval(); }
717723
}
718724

@@ -939,6 +945,7 @@ protected override void Dispose(bool disposing)
939945
/// </remarks>
940946
public override void train(bool on = true)
941947
{
948+
base.train(on);
942949
foreach (var m in _modules) { ((torch.nn.Module)m).train(on); }
943950
}
944951

@@ -950,6 +957,7 @@ public override void train(bool on = true)
950957
/// </remarks>
951958
public override void eval()
952959
{
960+
base.eval();
953961
foreach (var m in _modules) { ((torch.nn.Module)m).eval(); }
954962
}
955963

@@ -1176,6 +1184,7 @@ protected override void Dispose(bool disposing)
11761184
/// </remarks>
11771185
public override void train(bool on = true)
11781186
{
1187+
base.train(on);
11791188
foreach (var m in _modules) { ((torch.nn.Module)m).train(on); }
11801189
}
11811190

@@ -1187,6 +1196,7 @@ public override void train(bool on = true)
11871196
/// </remarks>
11881197
public override void eval()
11891198
{
1199+
base.eval();
11901200
foreach (var m in _modules) { ((torch.nn.Module)m).eval(); }
11911201
}
11921202

@@ -1413,6 +1423,7 @@ protected override void Dispose(bool disposing)
14131423
/// </remarks>
14141424
public override void train(bool on = true)
14151425
{
1426+
base.train(on);
14161427
foreach (var m in _modules) { ((torch.nn.Module)m).train(on); }
14171428
}
14181429

@@ -1424,6 +1435,7 @@ public override void train(bool on = true)
14241435
/// </remarks>
14251436
public override void eval()
14261437
{
1438+
base.eval();
14271439
foreach (var m in _modules) { ((torch.nn.Module)m).eval(); }
14281440
}
14291441

src/TorchSharp/NN/Sequential.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,7 @@ protected override void Dispose(bool disposing)
253253
/// </remarks>
254254
public override void train(bool on = true)
255255
{
256+
base.train(on);
256257
foreach (var m in _modules) { ((torch.nn.Module)m).train(on); }
257258
}
258259

@@ -264,6 +265,7 @@ public override void train(bool on = true)
264265
/// </remarks>
265266
public override void eval()
266267
{
268+
base.eval();
267269
foreach (var m in _modules) { ((torch.nn.Module)m).eval(); }
268270
}
269271

test/TorchSharpTest/TestTraining.cs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,34 @@ namespace TorchSharp
1616
public class TestTraining
1717
{
1818

19+
// <summary>
20+
/// Check if sequential module goes from training to eval mode
21+
// </summary>
22+
[Fact]
23+
public void TestTrainingEvalModes()
24+
{
25+
var sequential = Sequential(
26+
("lin1", Linear(100, 10)),
27+
("lin2", Linear(10, 5))
28+
);
29+
sequential.eval();
30+
// The entire sequential module and its layers should be in evaluation mode, not in training
31+
var firstLayer = (torch.nn.Module)sequential[0];
32+
Assert.False(firstLayer.training);
33+
34+
var secondLayer = (torch.nn.Module)sequential[1];
35+
Assert.False(secondLayer.training);
36+
37+
Assert.False(sequential.training);
38+
39+
sequential.train();
40+
// The entire sequential module and its layers should be in training mode, not in evaluation
41+
Assert.True(firstLayer.training);
42+
Assert.True(secondLayer.training);
43+
Assert.True(sequential.training);
44+
}
45+
46+
1947
/// <summary>
2048
/// Fully connected ReLU net with one hidden layer trained using gradient descent.
2149
/// Taken from <see href="https://pytorch.org/tutorials/beginner/examples_nn/two_layer_net_nn.html"/>.

0 commit comments

Comments
 (0)