Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix lstm output sequence for the "backward" and "both" directions #123

Merged
merged 7 commits into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 49 additions & 28 deletions src/lstm.js
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,17 @@ export function lstm(input, weight, recurrentWeight, steps, hiddenSize,
initialCellState, new Array(sizeOfShape(initialCellState)).fill(0));
}

let sequence;
const currentWeight = [];
const currentRecurrentWeight = [];
const currentBias = [];
const currentRecurrentBias = [];
const currentPeepholeWeight = [];
let forwardSequence = null;
let backwardSequence = null;
let currentHidden;
let currentCell;
let outputHidden;
let outputCell;

for (let dir = 0; dir < numDirections; ++dir) {
currentWeight.push(squeeze(slice(weight, [dir, 0, 0], [1, 4 * hiddenSize, inputSize])));
Expand All @@ -63,44 +68,60 @@ export function lstm(input, weight, recurrentWeight, steps, hiddenSize,
(squeeze(slice(recurrentBias, [dir, 0], [1, 4 * hiddenSize]))) : null);
currentPeepholeWeight.push(peepholeWeight ?
(squeeze(slice(peepholeWeight, [dir, 0], [1, 3 * hiddenSize]))) : null);
}

for (let step = 0; step < steps; ++step) {
const currentHidden = [];
const currentCell = [];
let nextHidden = null;
let nextCell = null;

for (let dir = 0; dir < numDirections; ++dir) {
currentHidden.push(squeeze(slice(hiddenState, [dir, 0, 0], [1, batchSize, hiddenSize])));
currentCell.push(squeeze(slice(cellState, [dir, 0, 0], [1, batchSize, hiddenSize])));
}
currentHidden = squeeze(slice(hiddenState, [dir, 0, 0], [1, batchSize, hiddenSize]));
currentCell = squeeze(slice(cellState, [dir, 0, 0], [1, batchSize, hiddenSize]));

for (let dir = 0; dir < numDirections; ++dir) {
const slice0 = (dir == 1 || direction == 'backward' ? steps - step - 1 : step);
for (let step = 0; step < steps; ++step) {
const slice0 = dir === 1 || direction === 'backward' ? steps - step - 1 : step;
const currentInput = squeeze(slice(input, [slice0, 0, 0], [1, batchSize, inputSize]));

const results = lstmCell(
[currentHidden, currentCell] = lstmCell(
currentInput, currentWeight[dir], currentRecurrentWeight[dir],
currentHidden[dir], currentCell[dir], hiddenSize, {bias: currentBias[dir],
currentHidden, currentCell, hiddenSize, {bias: currentBias[dir],
recurrentBias: currentRecurrentBias[dir], peepholeWeight: currentPeepholeWeight[dir],
layout: layout, activations: activations});

const output = reshape(results[0], [1, null, hiddenSize]);
const cell = reshape(results[1], [1, null, hiddenSize]);

nextHidden = (nextHidden ? concat([nextHidden, output], 0) : output);
nextCell = (nextCell ? concat([nextCell, cell], 0) : cell);
if (returnSequence) {
// Expand hidden from 2D [batchSize, hiddenSize] to 4D [1, 1, batchSize, hiddenSize]
// (axes: steps, numDirections, batchSize, hiddenSize) so it can be concatenated along the steps axis.
const expandedHiddenAs4D = reshape(currentHidden, [1, 1, batchSize, hiddenSize]);
if (direction === 'forward' || (dir === 0 && direction === 'both')) {
forwardSequence = forwardSequence ?
concat([forwardSequence, expandedHiddenAs4D], 0) :
expandedHiddenAs4D;
} else if (direction === 'backward' || (dir === 1 && direction === 'both')) {
backwardSequence = backwardSequence ?
concat([expandedHiddenAs4D, backwardSequence], 0) :
expandedHiddenAs4D;
}
}
}

hiddenState = nextHidden;
cellState = nextCell;
// Expand hidden from 2D [batchSize, hiddenSize] to 3D [1, batchSize, hiddenSize] (leading axis: numDirections)
const expandHiddenAs3D = reshape(currentHidden, [1, batchSize, hiddenSize]);
// Concat along axis 0 (numDirections dimension)
outputHidden = outputHidden ? concat([outputHidden, expandHiddenAs3D], 0) : expandHiddenAs3D;

if (returnSequence) {
nextHidden = reshape(nextHidden, [1, numDirections, null, hiddenSize]);
sequence = (sequence ? concat([sequence, nextHidden], 0) : nextHidden);
}
// Expand cell from 2D [batchSize, hiddenSize] to 3D [1, batchSize, hiddenSize] (leading axis: numDirections)
const expandCellAs3D = reshape(currentCell, [1, batchSize, hiddenSize]);
// Concat along axis 0 (numDirections dimension)
outputCell = outputCell ? concat([outputCell, expandCellAs3D], 0) : expandCellAs3D;
}

return (sequence ? [hiddenState, cellState, sequence] : [hiddenState, cellState]);
if (returnSequence) {
// outputSequence: [steps, numDirections, batchSize, hiddenSize]
let outputSequence;
if (direction === 'forward') {
outputSequence = forwardSequence;
} else if (direction === 'backward') {
outputSequence = backwardSequence;
} else if (direction === 'both') {
// Concat along axis 1 (numDirections dimension)
outputSequence = concat([forwardSequence, backwardSequence], 1);
}
return [outputHidden, outputCell, outputSequence];
} else {
return [outputHidden, outputCell];
}
}
74 changes: 67 additions & 7 deletions test/lstm_test.js
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ describe('test lstm', function() {
input, weight, recurrentWeight, steps, hiddenSize,
{bias, recurrentBias, peepholeWeight, initialHiddenState,
initialCellState, returnSequence, activations});
console.log('outputs: ', outputs);
utils.checkShape(outputs[0], [numDirections, batchSize, hiddenSize]);
utils.checkShape(outputs[1], [numDirections, batchSize, hiddenSize]);
utils.checkShape(outputs[2], [steps, numDirections, batchSize, hiddenSize]);
Expand All @@ -65,7 +64,7 @@ describe('test lstm', function() {
}
});

it('lstm steps=2 direction="backward" returnSequence=true' +
it('lstm steps=2 direction="backward" returnSequence=true ' +
'activations=[relu, relu, relu]', function() {
const steps = 2;
const numDirections = 1;
Expand Down Expand Up @@ -106,22 +105,83 @@ describe('test lstm', function() {
input, weight, recurrentWeight, steps, hiddenSize,
{bias, recurrentBias, peepholeWeight, initialHiddenState,
initialCellState, direction, returnSequence, activations});
console.log('outputs: ', outputs);
utils.checkShape(outputs[0], [numDirections, batchSize, hiddenSize]);
utils.checkShape(outputs[1], [numDirections, batchSize, hiddenSize]);
utils.checkShape(outputs[2], [steps, numDirections, batchSize, hiddenSize]);
const expected = [
[10.469, 58.02899999999999, 74.529, 518.9490000000001],
[5.51, 20.009999999999998, 19.11, 75.21000000000001],
[
1,
8,
1,
8,
10.469,
58.02899999999999,
74.529,
518.9490000000001,
1,
8,
1,
8,
],
];
for (let i = 0; i < expected.length; ++i) {
utils.checkValue(outputs[i], expected[i]);
}
});

it('lstm steps=2 direction="both" returnSequence=true', function() {
const steps = 2;
const numDirections = 2;
const batchSize = 2;
const inputSize = 2;
const hiddenSize = 2;
const input = new Tensor([steps, batchSize, inputSize],
new Float32Array([1, 2, 2, 1, 3, 4, 1, 2]));
const weight = new Tensor([numDirections, 4 * hiddenSize, inputSize],
new Float32Array([
1, -1, 2, -2, 1, -1, 2, -2,
1, -1, 2, -2, 1, -1, 2, -2,
1, -1, 2, -2, 1, -1, 2, -2,
1, -1, 2, -2, 1, -1, 2, -2,
]));
const recurrentWeight = new Tensor([numDirections, 4 * hiddenSize, hiddenSize],
new Array(2 * 4 * hiddenSize * hiddenSize).fill(0.1));
const bias = new Tensor([numDirections, 4 * hiddenSize],
new Float32Array([
1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2,
]));
const recurrentBias = new Tensor([numDirections, 4 * hiddenSize],
new Float32Array([
1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2,
]));
const returnSequence = true;
const direction = 'both';
const outputs = lstm(
input, weight, recurrentWeight, steps, hiddenSize,
{bias, recurrentBias, direction, returnSequence});
utils.checkShape(outputs[0], [numDirections, batchSize, hiddenSize]);
utils.checkShape(outputs[1], [numDirections, batchSize, hiddenSize]);
utils.checkShape(outputs[2], [steps, numDirections, batchSize, hiddenSize]);
const expected = [
[
0.5764073262004139, 0.8236227651782412,
0.6612355785279247, 0.8442635760318142,
0.5764073262004139, 0.8236227651782412,
0.8635294727880538, 0.9491350760903781,
],
[
1.0171455721466105, 1.6205496282195793,
1.338846378789257, 1.7642604746965693,
1.0171455721466105, 1.6205496282195793,
1.485626937219704, 1.8449554199024933,
],
[
0.36960635293570576, 0.6082834181835157,
0.7037753329989016, 0.7586680430344475,
0.5764073262004139, 0.8236227651782412,
0.8635294727880538, 0.9491350760903781,
0.5764073262004139, 0.8236227651782412,
0.6612355785279247, 0.8442635760318142,
0.36960635293570576, 0.6082834181835157,
0.36960635293570576, 0.6082834181835157,
],
];
for (let i = 0; i < expected.length; ++i) {
Expand Down
Loading