1717
1818package org .tensorflow .ndarray ;
1919
20+ import java .util .ArrayList ;
2021import java .util .Arrays ;
22+ import java .util .List ;
2123
2224/**
2325 * The shape of a Tensor or {@link NdArray}.
@@ -74,8 +76,8 @@ public static Shape scalar() {
7476 * Shape scalar = Shape.of()
7577 * }</pre>
7678 *
77- * @param dimensionSizes number of elements in each dimension of this shape, if any, or
78- * {@link Shape#UNKNOWN_SIZE} if unknown.
79+ * @param dimensionSizes number of elements in each dimension of this shape, if any, or {@link
80+ * Shape#UNKNOWN_SIZE} if unknown.
7981 * @return a new shape
8082 */
8183 public static Shape of (long ... dimensionSizes ) {
@@ -108,13 +110,34 @@ public long size() {
108110 * an unknown size, {@link Shape#UNKNOWN_SIZE} is returned.
109111 *
110112 * @param i the index of the dimension to get the size for. If this Shape has a known number of
111- * dimensions, it must be < {@link Shape#numDimensions()}. The index may be negative, in which
112- * case the position is counted from the end of the shape. E.g.: {@code size(-1)} returns the
113- * size of the last dimension, {@code size(-2)} the size of the second to last dimension etc.
113+ * dimensions, it must be < {@link Shape#numDimensions()}. The index may be negative, in
114+ * which case the position is counted from the end of the shape. E.g.: {@code size(-1)}
115+ * returns the size of the last dimension, {@code size(-2)} the size of the second to last
116+ * dimension etc.
114117 * @return The size of the dimension with the given index if known, {@link Shape#UNKNOWN_SIZE}
115118 * otherwise.
119+ * @deprecated Renamed to {@link #get(int)}.
116120 */
117- public long size (int i ) {
121+ @ Deprecated
122+ public long size (int i ){
123+ return get (i );
124+ }
125+
126+ /**
127+ * The size of the dimension with the given index.
128+ *
129+ * <p>If {@link Shape#isUnknown()} is true or the size of the dimension with the given index has
130+ * an unknown size, {@link Shape#UNKNOWN_SIZE} is returned.
131+ *
132+ * @param i the index of the dimension to get the size for. If this Shape has a known number of
133+ * dimensions, it must be < {@link Shape#numDimensions()}. The index may be negative, in
134+ * which case the position is counted from the end of the shape. E.g.: {@code size(-1)}
135+ * returns the size of the last dimension, {@code size(-2)} the size of the second to last
136+ * dimension etc.
137+ * @return The size of the dimension with the given index if known, {@link Shape#UNKNOWN_SIZE}
138+ * otherwise.
139+ */
140+ public long get (int i ) {
118141 if (dimensionSizes == null ) {
119142 return UNKNOWN_SIZE ;
120143 } else if (i >= 0 ) {
@@ -177,6 +200,24 @@ public long[] asArray() {
177200 }
178201 }
179202
203+ /**
204+ * Returns a defensive copy of the this Shape's axes. Changes to the returned list do not change
205+ * this Shape's state. Returns null if {@link Shape#isUnknown()} is true.
206+ */
207+ public List <Long > toListOrNull () {
208+ long [] array = asArray ();
209+ if (array == null ) {
210+ return null ;
211+ }
212+
213+ List <Long > list = new ArrayList <>(array .length );
214+ for (long l : array ) {
215+ list .add (l );
216+ }
217+
218+ return list ;
219+ }
220+
180221 @ Override
181222 public int hashCode () {
182223 return dimensionSizes != null ? Arrays .hashCode (dimensionSizes ) : super .hashCode ();
@@ -186,6 +227,7 @@ public int hashCode() {
186227 * Equals implementation for Shapes. Two Shapes are considered equal iff:
187228 *
188229 * <p>
230+ *
189231 * <ul>
190232 * <li>the number of dimensions is defined and equal for both
191233 * <li>the size of each dimension is defined and equal for both
@@ -236,7 +278,8 @@ public Shape head() {
236278 * Returns an n-dimensional Shape with the dimensions matching the first n dimensions of this
237279 * shape
238280 *
239- * @param n the number of leading dimensions to get, must be <= than {@link Shape#numDimensions()}
281+ * @param n the number of leading dimensions to get, must be <= than {@link
282+ * Shape#numDimensions()}
240283 * @return an n-dimensional Shape with the first n dimensions matching the first n dimensions of
241284 * this Shape
242285 */
@@ -252,7 +295,9 @@ public Shape take(int n) {
252295
253296 /** Returns a new Shape, with this Shape's first dimension removed. */
254297 public Shape tail () {
255- if (dimensionSizes .length < 2 ) return Shape .of ();
298+ if (dimensionSizes .length < 2 ) {
299+ return Shape .of ();
300+ }
256301 return Shape .of (Arrays .copyOfRange (dimensionSizes , 1 , dimensionSizes .length ));
257302 }
258303
@@ -276,15 +321,21 @@ public Shape takeLast(int n) {
276321 }
277322
278323 /**
279- * Return a {@code end - begin} dimensional shape with dimensions matching this Shape from {@code begin} to {@code end}.
324+ * Return a {@code end - begin} dimensional shape with dimensions matching this Shape from {@code
325+ * begin} to {@code end}.
326+ *
280327 * @param begin Where to start the sub-shape.
281328 * @param end Where to end the sub-shape, exclusive.
282329 * @return the sub-shape bounded by begin and end.
283330 */
284- public Shape subShape (int begin , int end ){
331+ public Shape subShape (int begin , int end ) {
285332 if (end > numDimensions ()) {
286333 throw new ArrayIndexOutOfBoundsException (
287- "End index " + end + " out of bounds: shape only has " + numDimensions () + " dimensions." );
334+ "End index "
335+ + end
336+ + " out of bounds: shape only has "
337+ + numDimensions ()
338+ + " dimensions." );
288339 }
289340 if (begin < 0 ) {
290341 throw new ArrayIndexOutOfBoundsException (
@@ -423,7 +474,7 @@ public boolean isCompatibleWith(Shape shape) {
423474 return false ;
424475 }
425476 for (int i = 0 ; i < numDimensions (); i ++) {
426- if (!isCompatible (size (i ), shape .size (i ))) {
477+ if (!isCompatible (get (i ), shape .get (i ))) {
427478 return false ;
428479 }
429480 }
0 commit comments