Skip to content

Ability to convert Tensor to String representation #268

Open
@cowwoc

Description

@cowwoc

Per our discussion on Gitter, here is a possible implementation for converting tensors to a String representation. It is still missing some important features, such as collapsing long arrays with ellipses, but it can serve as a stepping stone. The functionality is meant to ease troubleshooting and debugging, so performance is not a concern.

import org.tensorflow.Session;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.buffer.DataBuffer;
import org.tensorflow.ndarray.buffer.DoubleDataBuffer;
import org.tensorflow.ndarray.buffer.FloatDataBuffer;
import org.tensorflow.ndarray.buffer.IntDataBuffer;
import org.tensorflow.ndarray.buffer.LongDataBuffer;
import org.tensorflow.ndarray.buffer.ShortDataBuffer;
import org.tensorflow.types.TFloat16;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TFloat64;
import org.tensorflow.types.TInt32;
import org.tensorflow.types.TInt64;
import org.tensorflow.types.TUint8;

import java.util.StringJoiner;
import java.util.function.LongFunction;

/**
 * Debugging helpers that render the contents of a tensor as nested, bracketed text.
 * <p>
 * Every dimension is rendered as a bracketed list. Each sub-array of an outer dimension is put on its own line, indented
 * with one tab per nesting level. The innermost dimension is rendered on a single line as comma-separated values. A
 * rank-0 (scalar) tensor is rendered as its single value, without brackets.
 */
public final class Tensors
{
	// NOTE(review): currently not used by any of the formatting methods below.
	private final Session session;

	/**
	 * @param session the session used by all operations
	 */
	public Tensors(Session session)
	{
		this.session = session;
	}

	/**
	 * @param tensor a tensor
	 * @return the String representation of the tensor
	 */
	public String toString(TFloat64 tensor)
	{
		DoubleDataBuffer doubles = tensor.asRawTensor().data().asDoubles();
		return format(tensor.shape(), tensor.rank(), index -> String.valueOf(doubles.getDouble(index)));
	}

	/**
	 * @param tensor a tensor
	 * @return the String representation of the tensor
	 */
	public String toString(TFloat32 tensor)
	{
		FloatDataBuffer floats = tensor.asRawTensor().data().asFloats();
		return format(tensor.shape(), tensor.rank(), index -> String.valueOf(floats.getFloat(index)));
	}

	/**
	 * @param tensor a tensor
	 * @return the String representation of the tensor
	 */
	public String toString(TFloat16 tensor)
	{
		// Each element is a 16-bit half-precision value. Viewing the raw bytes as 32-bit floats would pair up adjacent
		// elements, so read 16-bit patterns and decode them explicitly.
		ShortDataBuffer halves = tensor.asRawTensor().data().asShorts();
		return format(tensor.shape(), tensor.rank(), index -> String.valueOf(halfToFloat(halves.getShort(index))));
	}

	/**
	 * @param tensor a tensor
	 * @return the String representation of the tensor
	 */
	public String toString(TInt64 tensor)
	{
		LongDataBuffer longs = tensor.asRawTensor().data().asLongs();
		return format(tensor.shape(), tensor.rank(), index -> String.valueOf(longs.getLong(index)));
	}

	/**
	 * @param tensor a tensor
	 * @return the String representation of the tensor
	 */
	public String toString(TInt32 tensor)
	{
		IntDataBuffer ints = tensor.asRawTensor().data().asInts();
		return format(tensor.shape(), tensor.rank(), index -> String.valueOf(ints.getInt(index)));
	}

	/**
	 * @param tensor a tensor
	 * @return the String representation of the tensor
	 */
	public String toString(TUint8 tensor)
	{
		// Each element occupies a single byte. Read it as a byte (not a short) and widen it without sign extension so
		// that values 128..255 are not printed as negative numbers.
		var bytes = tensor.asRawTensor().data();
		return format(tensor.shape(), tensor.rank(), index -> String.valueOf(Byte.toUnsignedInt(bytes.getByte(index))));
	}

	/**
	 * Renders a whole tensor.
	 *
	 * @param shape   the shape of the tensor
	 * @param rank    the number of dimensions of the tensor
	 * @param element returns the String representation of the element at a flat (row-major) index
	 * @return the String representation of the tensor
	 */
	private static String format(Shape shape, int rank, LongFunction<String> element)
	{
		if (rank == 0)
		{
			// A scalar holds exactly one element and has no dimension to bracket.
			return element.apply(0);
		}
		return formatDimension(element, shape, 0, 0, rank).text();
	}

	/**
	 * Renders one dimension of a tensor, recursing into the inner dimensions.
	 *
	 * @param element   returns the String representation of the element at a flat (row-major) index
	 * @param shape     the shape of the tensor
	 * @param index     the flat index of the first element covered by this dimension
	 * @param dimension the current dimension (0-based)
	 * @param rank      the number of dimensions of the tensor (at least 1)
	 * @return the String representation of the {@code dimension} and the number of elements it covers
	 */
	private static ToStringResponse formatDimension(LongFunction<String> element, Shape shape, long index,
		int dimension, int rank)
	{
		String indent = "\t".repeat(dimension);
		long size = shape.size(dimension);
		long numElements = 0;
		StringJoiner joiner;
		if (dimension < rank - 1)
		{
			// Outer dimension: one sub-array per line.
			joiner = new StringJoiner(",\n", indent + "[\n", "\n" + indent + "]");
			for (long i = 0; i < size; ++i)
			{
				ToStringResponse response = formatDimension(element, shape, index + numElements, dimension + 1, rank);
				joiner.add(response.text());
				numElements += response.numElements();
			}
		}
		else
		{
			// Innermost dimension: the elements themselves, on a single line.
			joiner = new StringJoiner(",", indent + "[", "]");
			for (long i = 0; i < size; ++i)
			{
				joiner.add(element.apply(index + numElements));
				++numElements;
			}
		}
		// A zero-length dimension renders as "[]" rather than as brackets around a blank line.
		joiner.setEmptyValue(indent + "[]");
		return new ToStringResponse(joiner.toString(), numElements);
	}

	/**
	 * Decodes an IEEE 754 half-precision (binary16) bit pattern.
	 *
	 * @param bits the 16-bit pattern
	 * @return the equivalent {@code float} value (exact; every binary16 value is representable as a float)
	 */
	private static float halfToFloat(short bits)
	{
		int half = bits & 0xFFFF;
		int exponent = (half >>> 10) & 0x1F;
		int mantissa = half & 0x3FF;
		float magnitude;
		if (exponent == 0)
		{
			// Zero or subnormal: mantissa * 2^-24.
			magnitude = mantissa * 0x1p-24f;
		}
		else if (exponent == 0x1F)
		{
			magnitude = mantissa == 0 ? Float.POSITIVE_INFINITY : Float.NaN;
		}
		else
		{
			// Normal: (1 + mantissa / 2^10) * 2^(exponent - 15).
			magnitude = Math.scalb(1f + mantissa / 1024f, exponent - 15);
		}
		return (half & 0x8000) == 0 ? magnitude : -magnitude;
	}

	/**
	 * @param text        the string representation of a tensor dimension
	 * @param numElements the number of elements contained in {@code text}
	 */
	private record ToStringResponse(String text, long numElements)
	{
	}
}

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions