Setting all the optimizers to have useLocking = True #310
Changes from 2 commits
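For context, the useLocking option on the optimizers maps to the use_locking attribute of the underlying apply* training ops; when it is true, the variable update is guarded by a lock. Below is a minimal sketch (not part of this diff) of what enabling it looks like when building such an op by hand, assuming the generated org.tensorflow.op.train.ApplyGradientDescent op and its useLocking option; the variable and gradient values are illustrative only.

import org.tensorflow.Graph;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Variable;
import org.tensorflow.op.train.ApplyGradientDescent;
import org.tensorflow.types.TFloat32;

class UseLockingSketch {
  static void build() {
    try (Graph g = new Graph()) {
      Ops tf = Ops.create(g);
      // Illustrative variable and gradient; a real optimizer derives the gradient from a loss.
      Variable<TFloat32> var = tf.variable(tf.constant(new float[] {1.0f, 2.0f}));
      ApplyGradientDescent<TFloat32> update =
          tf.train.applyGradientDescent(
              var,
              tf.constant(0.01f),                    // learning rate
              tf.constant(new float[] {0.1f, 0.1f}), // gradient
              ApplyGradientDescent.useLocking(true)); // the locking behavior this PR enables
    }
  }
}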
@@ -2,19 +2,34 @@

import org.junit.jupiter.api.*;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.framework.initializers.Glorot;
import org.tensorflow.framework.initializers.VarianceScaling;
import org.tensorflow.framework.utils.TestSession;
import org.tensorflow.ndarray.FloatNdArray;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.buffer.DataBuffers;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Assign;
import org.tensorflow.op.core.Constant;
import org.tensorflow.op.core.Init;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.op.core.Variable;
import org.tensorflow.op.math.Add;
import org.tensorflow.op.math.Mean;
import org.tensorflow.op.nn.Relu;
import org.tensorflow.proto.framework.ConfigProto;
import org.tensorflow.proto.framework.GraphDef;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.family.TType;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;

/** Test cases for GradientDescent Optimizer */
@@ -97,4 +112,158 @@ public void testBasic() {
      session.evaluate(expectedVar1, var1);
    }
  }

  // This test fails due to initialization and gradient issues. It should not, but it seems to be a
  // problem
  // in TF-core.
Review comment: reformat this comment?
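One possible rewrap of that comment (a suggestion only, not what the branch currently contains):

  // This test fails due to initialization and gradient issues. It should not, but it seems to be
  // a problem in TF-core.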
  @Disabled
  @Test
  public void testDeterminism() {
    ConfigProto config =
        ConfigProto.newBuilder()
            .setIntraOpParallelismThreads(1)
            .setInterOpParallelismThreads(1)
            .build();

    GraphDef def;
    String initName;
    String trainName;
    String lossName;

    String fcWeightName, fcBiasName, outputWeightName, outputBiasName;

    try (Graph g = new Graph()) {
      Ops tf = Ops.create(g);

      Glorot<TFloat32> initializer =
          new Glorot<>(tf, VarianceScaling.Distribution.TRUNCATED_NORMAL, 1L);
      // Inputs
      Placeholder<TFloat32> input =
          tf.withName("input").placeholder(TFloat32.class, Placeholder.shape(Shape.of(-1, 20)));

      // Fully connected layer
      Variable<TFloat32> fcWeights =
          tf.variable(initializer.call(tf.array(20L, 200L), TFloat32.class));
      fcWeightName = fcWeights.op().name();
      Variable<TFloat32> fcBiases = tf.variable(tf.fill(tf.array(200), tf.constant(0.1f)));
      fcBiasName = fcBiases.op().name();
      Relu<TFloat32> relu = tf.nn.relu(tf.math.add(tf.linalg.matMul(input, fcWeights), fcBiases));

      // Output layer
      Variable<TFloat32> outputWeights =
          tf.variable(initializer.call(tf.array(200L, 2L), TFloat32.class));
      outputWeightName = outputWeights.op().name();
      Variable<TFloat32> outputBiases = tf.variable(tf.fill(tf.array(2L), tf.constant(0.1f)));
      outputBiasName = outputBiases.op().name();
      Add<TFloat32> output = tf.math.add(tf.linalg.matMul(relu, outputWeights), outputBiases);

      // Loss
      Placeholder<TFloat32> placeholder =
          tf.withName("output").placeholder(TFloat32.class, Placeholder.shape(Shape.of(-1, 2)));
      Mean<TFloat32> loss =
          tf.math.mean(
              tf.nn.raw.softmaxCrossEntropyWithLogits(output, placeholder).loss(), tf.constant(0));
      lossName = loss.op().name();

      GradientDescent gd = new GradientDescent(g, 10.0f);
      Op trainingOp = gd.minimize(loss);
      trainName = trainingOp.op().name();

      // Create the init op
      Init init = tf.init();
      initName = init.op().name();

      def = g.toGraphDef();
    }

    float[] data =
        new float[] {
          1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, -8.0f, -9.0f, 10.0f, 11.0f, 12.0f, 13.0f,
          -14.0f, -15.0f, 0.16f, 0.17f, 0.18f, 1.9f, 0.2f
        };
    TFloat32 dataTensor = TFloat32.tensorOf(Shape.of(1, 20), DataBuffers.of(data));
    float[] target = new float[] {0.2f, 0.8f};
    TFloat32 targetTensor = TFloat32.tensorOf(Shape.of(1, 2), DataBuffers.of(target));

    int numRuns = 20;
    List<List<Tensor>> initialized = new ArrayList<>(numRuns);
    List<List<Tensor>> trained = new ArrayList<>(numRuns);
    float[] initialLoss = new float[numRuns];
    float[] postTrainingLoss = new float[numRuns];

    for (int i = 0; i < numRuns; i++) {
      try (Graph g = new Graph();
          Session s = new Session(g, config)) {
        g.importGraphDef(def);
        s.run(initName);

        initialized.add(
            s.runner()
                .fetch(fcWeightName)
                .fetch(fcBiasName)
                .fetch(outputWeightName)
                .fetch(outputBiasName)
                .run());
        System.out.println("Initialized - " + ndArrToString((TFloat32) initialized.get(i).get(3)));
Review comment: Can we avoid the verbosity in the unit test? Aren't the equality checks enough to validate? I'm fine with just commenting these out.
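A minimal way to follow that suggestion (hypothetical edit, not in this diff) is to comment the debug output out:

        // System.out.println(
        //     "Initialized - " + ndArrToString((TFloat32) initialized.get(i).get(3)));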

        TFloat32 lossVal = (TFloat32) s.runner()
            .addTarget(trainName)
            .feed("input", dataTensor)
            .feed("output", targetTensor)
            .fetch(lossName)
            .run().get(0);
        initialLoss[i] = lossVal.getFloat();
        lossVal.close();

        trained.add(
            s.runner()
                .fetch(fcWeightName)
                .fetch(fcBiasName)
                .fetch(outputWeightName)
                .fetch(outputBiasName)
                .run());
        System.out.println("Initialized - " + ndArrToString((TFloat32) initialized.get(i).get(3)));
        System.out.println("Trained - " + ndArrToString((TFloat32) trained.get(i).get(3)));

        lossVal = (TFloat32) s.runner()
            .addTarget(trainName)
            .feed("input", dataTensor)
            .feed("output", targetTensor)
            .fetch(lossName)
            .run().get(0);
        postTrainingLoss[i] = lossVal.getFloat();
        lossVal.close();
      }
    }

    for (int i = 1; i < numRuns; i++) {
      assertEquals(initialLoss[0],initialLoss[i]);
Review comment: Super nit: spaces after commas.
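With the spacing addressed, those assertions would read (suggested formatting only, behavior unchanged):

      assertEquals(initialLoss[0], initialLoss[i]);
      assertEquals(postTrainingLoss[0], postTrainingLoss[i]);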
      assertEquals(postTrainingLoss[0],postTrainingLoss[i]);
      // Because the weights are references not copies.
      assertEquals(initialized.get(i),trained.get(i));
      assertEquals(
          initialized.get(0),
          initialized.get(i),
          "Variables not initialized identically (0," + i + ")");
      assertEquals(
          trained.get(0), trained.get(i), "Variables not trained identically (0," + i + ")");
    }

    for (List<Tensor> curInit : initialized) {
      for (Tensor t : curInit) {
        t.close();
      }
    }
    for (List<Tensor> curTrained : trained) {
      for (Tensor t : curTrained) {
        t.close();
      }
    }
  }

  private static String ndArrToString(FloatNdArray ndarray) {
    StringBuffer sb = new StringBuffer();
    ndarray.scalars().forEachIndexed((idx,array) -> sb.append(Arrays.toString(idx)).append(" = ").append(array.getFloat()).append("\n"));
    return sb.toString();
  }
}
Review comment: Super nit: any formatter will probably complain about the missing space after a comma.
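The lambda in ndArrToString appears to be the offending spot; a formatted version (suggestion only, same behavior) would be:

  private static String ndArrToString(FloatNdArray ndarray) {
    StringBuffer sb = new StringBuffer();
    ndarray
        .scalars()
        .forEachIndexed(
            (idx, array) ->
                sb.append(Arrays.toString(idx)).append(" = ").append(array.getFloat()).append("\n"));
    return sb.toString();
  }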