Skip to content

Commit

Permalink
Fix emulation error of lstm by 'backward' and 'both' direction options
Browse files Browse the repository at this point in the history
  • Loading branch information
BruceDai committed Jan 14, 2025
1 parent df555e1 commit 88a35ff
Showing 1 changed file with 73 additions and 44 deletions.
117 changes: 73 additions & 44 deletions index.bs
Original file line number Diff line number Diff line change
Expand Up @@ -5554,28 +5554,38 @@ partial dictionary MLOpSupportLimits {
builder, input, weight, recurrentWeight, steps, hiddenSize, options) {
const batchSize = input.shape[1];
const inputSize = input.shape[2];
const numDirections = (options.direction == 'both' ? 2 : 1);
const direction = options.direction || 'forward';
const numDirections = (direction == 'both' ? 2 : 1);
let hiddenState = options.initialHiddenState;
let cellState = options.initialCellState;

if (!hiddenState) {
const desc = {dataType: 'float32', shape: [numDirections, 1, hiddenSize]};
const totalSize = numDirections * hiddenSize;
const desc = {
dataType: 'float32',
shape: [numDirections, batchSize, hiddenSize]
};
const totalSize = numDirections * batchSize * hiddenSize;
hiddenState = builder.constant(desc, new Float32Array(totalSize).fill(0));
}

if (!cellState) {
const desc = {dataType: 'float32', shape: [numDirections, 1, hiddenSize]};
const totalSize = numDirections * hiddenSize;
const desc = {
dataType: 'float32',
shape: [numDirections, batchSize, hiddenSize]
};
const totalSize = numDirections * batchSize * hiddenSize;
cellState = builder.constant(desc, new Float32Array(totalSize).fill(0));
}

let sequence = null;
let currentWeight = [];
let currentRecurrentWeight = [];
let currentBias = [];
let currentRecurrentBias = [];
let currentPeepholeWeight = [];
let forwardSequence = null;
let backwardSequence = null;
let outputHidden = null;
let outputCell = null;

for (let dir = 0; dir < numDirections; ++dir) {
currentWeight.push(squeeze(
Expand Down Expand Up @@ -5605,36 +5615,26 @@ partial dictionary MLOpSupportLimits {
builder.slice(
options.peepholeWeight, [dir, 0], [1, 3 * hiddenSize]))) :
null);
}

for (let step = 0; step < steps; ++step) {
let currentHidden = [];
let currentCell = [];
let nextHidden = null;
let nextCell = null;

for (let dir = 0; dir < numDirections; ++dir) {
currentHidden.push(squeeze(
builder,
builder.slice(hiddenState, [dir, 0, 0], [1, batchSize, hiddenSize])));
currentCell.push(squeeze(
builder,
builder.slice(cellState, [dir, 0, 0], [1, batchSize, hiddenSize])));
}
let currentHidden = squeeze(
builder,
builder.slice(hiddenState, [dir, 0, 0], [1, batchSize, hiddenSize]));
let currentCell = squeeze(
builder,
builder.slice(cellState, [dir, 0, 0], [1, batchSize, hiddenSize]));

for (let dir = 0; dir < numDirections; ++dir) {
let slice =
(dir == 1 || options.direction == 'backward' ? steps - step - 1 : step);
let currentInput = squeeze(
for (let step = 0; step < steps; ++step) {
const slice = (dir == 1 || direction == 'backward' ? steps - step - 1 : step);
const currentInput = squeeze(
builder,
builder.slice(input, [slice, 0, 0], [1, batchSize, inputSize]));

let results = builder.lstmCell(
[currentHidden, currentCell] = builder.lstmCell(
currentInput,
currentWeight[dir],
currentRecurrentWeight[dir],
currentHidden[dir],
currentCell[dir],
currentHidden,
currentCell,
hiddenSize,
{
bias: currentBias[dir],
Expand All @@ -5644,27 +5644,56 @@ partial dictionary MLOpSupportLimits {
activations: options.activations
});

let output = builder.reshape(results[0], [1, batchSize, hiddenSize]);
let cell = builder.reshape(results[1], [1, batchSize, hiddenSize]);

nextHidden =
(nextHidden ? builder.concat([nextHidden, output], 0) : output);
nextCell = (nextCell ? builder.concat([nextCell, cell], 0) : cell);
if (options.returnSequence) {
// Expand currentHidden of 2D([batchSize, hiddenSize])
// to 4D([steps, numDirections, batchSize, hiddenSize])
const expandedHiddenAs4D = builder.reshape(
currentHidden, [1, 1, batchSize, hiddenSize]);

if (direction == 'forward' || (dir == 0 && direction == 'both')) {
forwardSequence = forwardSequence ?
builder.concat([forwardSequence, expandedHiddenAs4D], 0) :
expandedHiddenAs4D;
} else if (direction == 'backward' || (dir == 1 && direction == 'both')) {
backwardSequence = backwardSequence ?
builder.concat([expandedHiddenAs4D, backwardSequence], 0) :
expandedHiddenAs4D;
}
}
}

hiddenState = nextHidden;
cellState = nextCell;
// Expand currentHidden of 2D([batchSize, hiddenSize])
// to 3D([numDirections, batchSize, hiddenSize])
const expandedHiddenAs3D = builder.reshape(
currentHidden, [1, batchSize, hiddenSize]);
outputHidden = outputHidden ?
builder.concat([outputHidden, expandedHiddenAs3D], 0) :
expandedHiddenAs3D;

// Expand currentCell of 2D([batchSize, hiddenSize])
// to 3D([numDirections, batchSize, hiddenSize])
const expandedCellAs3D = builder.reshape(
currentCell, [1, batchSize, hiddenSize]);
outputCell = outputCell ?
builder.concat([outputCell, expandedCellAs3D], 0) : expandedCellAs3D;
}

if (options.returnSequence) {
nextHidden =
builder.reshape(nextHidden, [1, numDirections, batchSize, hiddenSize]);
sequence =
(sequence ? builder.concat([sequence, nextHidden], 0) : nextHidden);
if (options.returnSequence) {
let outputSequence = null;

if (direction == 'forward') {
outputSequence = forwardSequence;
} else if (direction == 'backward') {
outputSequence = backwardSequence;
} else if (direction == 'both') {
// Concat along axis 1 (numDirections dimension)
outputSequence = builder.concat([forwardSequence, backwardSequence], 1);
}
}

return (
sequence ? [hiddenState, cellState, sequence] : [hiddenState, cellState]);
return [outputHidden, outputCell, outputSequence];
} else {
return [outputHidden, outputCell];
}
}
</pre>
</details>
Expand Down

0 comments on commit 88a35ff

Please sign in to comment.