Open
Description
Per our discussion on Gitter, here is a possible implementation for converting Tensors to a String representation. It is still missing some important features, like collapsing long arrays using ellipses, but this can serve as a stepping stone. The functionality is meant to ease troubleshooting/debugging so performance should not be an issue.
import org.tensorflow.Session;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.buffer.DataBuffer;
import org.tensorflow.ndarray.buffer.DoubleDataBuffer;
import org.tensorflow.ndarray.buffer.FloatDataBuffer;
import org.tensorflow.ndarray.buffer.IntDataBuffer;
import org.tensorflow.ndarray.buffer.LongDataBuffer;
import org.tensorflow.ndarray.buffer.ShortDataBuffer;
import org.tensorflow.types.TFloat16;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TFloat64;
import org.tensorflow.types.TInt32;
import org.tensorflow.types.TInt64;
import org.tensorflow.types.TUint8;
import java.util.StringJoiner;
import java.util.function.LongFunction;
/**
 * Debugging helper that renders tensors as nested, bracketed text.
 * <p>
 * Scalars are rendered as the bare element value. For higher ranks, every dimension is wrapped in
 * {@code [ ]}; the innermost dimension is written on a single line with comma-separated elements,
 * outer dimensions place each child on its own line, and each nesting level is indented with one
 * additional tab.
 */
public final class Tensors
{
    // Stored by the constructor; none of the methods in this class currently read it.
    private final Session session;

    /**
     * @param session the session used by all operations
     */
    public Tensors(Session session)
    {
        this.session = session;
    }

    /**
     * @param tensor a tensor
     * @return the String representation of the tensor
     */
    public String toString(TFloat64 tensor)
    {
        DoubleDataBuffer values = tensor.asRawTensor().data().asDoubles();
        return toString(index -> String.valueOf(values.getObject(index)), tensor.shape(), 0, 0,
            tensor.rank()).text();
    }

    /**
     * @param tensor a tensor
     * @return the String representation of the tensor
     */
    public String toString(TFloat32 tensor)
    {
        FloatDataBuffer values = tensor.asRawTensor().data().asFloats();
        return toString(index -> String.valueOf(values.getObject(index)), tensor.shape(), 0, 0,
            tensor.rank()).text();
    }

    /**
     * @param tensor a tensor
     * @return the String representation of the tensor
     */
    public String toString(TFloat16 tensor)
    {
        // Half-precision elements occupy 16 bits each, so the raw data is viewed as shorts and every
        // bit pattern is widened to a float; viewing the raw data as 32-bit floats would misread it.
        // NOTE(review): the ndarray library may provide a float16 data layout that could replace
        // halfToFloat - confirm before switching.
        ShortDataBuffer halves = tensor.asRawTensor().data().asShorts();
        return toString(index -> String.valueOf(halfToFloat(halves.getObject(index))), tensor.shape(),
            0, 0, tensor.rank()).text();
    }

    /**
     * @param tensor a tensor
     * @return the String representation of the tensor
     */
    public String toString(TInt64 tensor)
    {
        LongDataBuffer values = tensor.asRawTensor().data().asLongs();
        return toString(index -> String.valueOf(values.getObject(index)), tensor.shape(), 0, 0,
            tensor.rank()).text();
    }

    /**
     * @param tensor a tensor
     * @return the String representation of the tensor
     */
    public String toString(TInt32 tensor)
    {
        IntDataBuffer values = tensor.asRawTensor().data().asInts();
        return toString(index -> String.valueOf(values.getObject(index)), tensor.shape(), 0, 0,
            tensor.rank()).text();
    }

    /**
     * @param tensor a tensor
     * @return the String representation of the tensor, with elements printed as unsigned values
     */
    public String toString(TUint8 tensor)
    {
        // Assumes the raw data is a byte buffer holding one element per byte - TODO confirm.
        // Each byte is printed as an unsigned value (0-255) rather than Java's signed byte.
        DataBuffer<Byte> bytes = tensor.asRawTensor().data();
        return toString(index -> String.valueOf(Byte.toUnsignedInt(bytes.getObject(index))),
            tensor.shape(), 0, 0, tensor.rank()).text();
    }

    /**
     * Renders one dimension of a tensor, recursing into nested dimensions.
     *
     * @param elementAt returns the text of the element at a flat (row-major) offset
     * @param shape     the shape of the tensor
     * @param index     the flat offset of the first element covered by this call
     * @param dimension the current dimension
     * @param rank      the number of dimensions of the tensor
     * @return the String representation of the {@code dimension} and the number of elements it
     * consumed
     */
    private static ToStringResponse toString(LongFunction<String> elementAt, Shape shape, long index,
        int dimension, int rank)
    {
        if (rank == 0)
        {
            // A scalar has no dimension to iterate over: it is exactly one element.
            return new ToStringResponse(elementAt.apply(index), 1);
        }

        String indent = "\t".repeat(dimension);
        long size = shape.size(dimension);
        long numElements = 0;
        StringJoiner joiner;
        if (dimension < rank - 1)
        {
            // Outer dimension: one nested block per line.
            joiner = new StringJoiner(",\n", indent + "[\n", "\n" + indent + "]");
            // Avoid printing a pair of blank lines for an empty dimension.
            joiner.setEmptyValue(indent + "[]");
            for (long i = 0; i < size; ++i)
            {
                ToStringResponse response = toString(elementAt, shape, index + numElements,
                    dimension + 1, rank);
                joiner.add(response.text());
                numElements += response.numElements();
            }
        }
        else
        {
            // Innermost dimension: all elements on a single line.
            joiner = new StringJoiner(",", indent + "[", "]");
            for (long i = 0; i < size; ++i)
            {
                joiner.add(elementAt.apply(index + numElements));
                ++numElements;
            }
        }
        return new ToStringResponse(joiner.toString(), numElements);
    }

    /**
     * Widens an IEEE 754 binary16 bit pattern to a {@code float}.
     *
     * @param half the 16 bits of a half-precision value
     * @return the equivalent single-precision value
     */
    private static float halfToFloat(short half)
    {
        int bits = half & 0xFFFF;
        int sign = (bits & 0x8000) << 16;
        int exponent = (bits >>> 10) & 0x1F;
        int mantissa = bits & 0x3FF;
        if (exponent == 0x1F)
        {
            // Infinity or NaN: saturate the float exponent and keep the payload.
            return Float.intBitsToFloat(sign | 0x7F800000 | (mantissa << 13));
        }
        if (exponent == 0)
        {
            if (mantissa == 0)
            {
                // Signed zero.
                return Float.intBitsToFloat(sign);
            }
            // Subnormal half: value = mantissa * 2^-24, exactly representable as a float.
            float magnitude = mantissa * 0x1p-24f;
            return sign == 0 ? magnitude : -magnitude;
        }
        // Normal number: rebias the exponent (127 - 15 = 112) and widen the mantissa.
        return Float.intBitsToFloat(sign | ((exponent + 112) << 23) | (mantissa << 13));
    }

    /**
     * @param text        the string representation of a tensor dimension
     * @param numElements the number of elements contained in {@code text}
     */
    private record ToStringResponse(String text, long numElements)
    {
    }
}
Metadata
Metadata
Assignees
Labels
No labels