Skip to content

Commit e1e0f08

Browse files
Feature: Implement RESHAPE() in TDI
This builtin was never implemented, but the opcode was allocated and it existed in the original documentation. The similarity to numpy's reshape function means users might actually want to use this, and throwing a terse error is less than helpful. This implements RESHAPE() with only the first two arguments of SOURCE and SHAPE. The optional arguments of PAD and ORDER described in the original documetnation are not implemented, and won't be included in the new documentation for this function. ``` TDI> reshape(1:6, [2, 3]) [[1,2], [3,4], [5,6]] TDI> reshape(1:6, [3, 2]) [[1,2,3], [4,5,6]] TDI> reshape(1:8, [2, 2, 2]) [[[1,2], [3,4]], [[5,6], [7,8]]] TDI> reshape(reshape(1:8, [2, 4]), [8]) [1,2,3,4,5,6,7,8] TDI> reshape([[1,2], [3,4]], [1, 4]) [[1], [2], [3], [4]] ```
1 parent 346e5d0 commit e1e0f08

1 file changed

Lines changed: 67 additions & 1 deletion

File tree

tdishr/TdiReshape.c

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2828
#include <mds_stdarg.h>
2929
#include <mdsdescrip.h>
3030
#include <tdishr.h>
31+
#include <mdsshr.h>
3132

3233
#include <string.h>
3334

@@ -49,7 +50,72 @@ int Tdi1Reshape(opcode_t opcode, int narg, struct descriptor *list[],
4950
default:
5051
return TdiNO_OPC;
5152
case OPC_RESHAPE:
52-
return TdiNO_OPC;
53+
{
54+
uint32_t shape[MAX_DIMS] = { 0 };
55+
mdsdsc_a_t dsc_shape = {
56+
.length = sizeof(uint32_t),
57+
.class = CLASS_A,
58+
.dtype = DTYPE_LU,
59+
.arsize = sizeof(shape),
60+
.pointer = (char *)shape,
61+
};
62+
63+
RETURN_IF_NOT_OK(TdiConvert((mdsdsc_a_t *)list[1], (mdsdsc_a_t *)&dsc_shape));
64+
65+
size_t dimct = 0;
66+
size_t original_count = (arr->arsize / arr->length);
67+
size_t new_count = 1;
68+
for (size_t i = 0; i < MAX_DIMS; ++i) {
69+
if (shape[i] == 0) {
70+
break;
71+
}
72+
73+
++dimct;
74+
new_count *= shape[i];
75+
}
76+
77+
if (original_count != new_count) {
78+
return TdiMISMATCH;
79+
}
80+
81+
if (dimct <= arr->dimct) {
82+
arr->dimct = dimct;
83+
for (size_t i = 0; i < dimct; ++i) {
84+
arr->m[i] = shape[i];
85+
}
86+
}
87+
else {
88+
// Need to resize the array descriptor to fit the new dims
89+
90+
array_coeff new_array = {
91+
.length = arr->length,
92+
.dtype = arr->dtype,
93+
.class = arr->class,
94+
.pointer = arr->pointer,
95+
.scale = arr->scale,
96+
.digits = arr->digits,
97+
.aflags = arr->aflags,
98+
.dimct = dimct,
99+
.arsize = arr->arsize,
100+
};
101+
102+
new_array.aflags.coeff = 1;
103+
new_array.dimct = dimct;
104+
new_array.a0 = new_array.pointer;
105+
for (size_t i = 0; i < dimct; ++i) {
106+
new_array.m[i] = shape[i];
107+
}
108+
109+
mdsdsc_xd_t tmp_out_ptr = MDSDSC_XD_INITIALIZER;
110+
MdsCopyDxXd((mdsdsc_t *)&new_array, &tmp_out_ptr);
111+
112+
// Use the new array, not the original one
113+
MdsFree1Dx(out_ptr, NULL);
114+
memcpy(out_ptr, &tmp_out_ptr, sizeof(mdsdsc_xd_t));
115+
}
116+
117+
break;
118+
}
53119
case OPC_FLATTEN:
54120
{
55121
if (arr->dimct > 1)

0 commit comments

Comments
 (0)