Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit 245f7fa

Browse files
authored
Move reshape() and cast() to kernel backends. (#893)
* wip
* wip
* wip
* cleanup
* Add missing file
1 parent 9eb9b49 commit 245f7fa

File tree

5 files changed

+73
-20
lines changed

5 files changed

+73
-20
lines changed

src/kernels/backend.ts

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
21
/**
32
* @license
4-
* Copyright 2017 Google Inc. All Rights Reserved.
3+
* Copyright 2018 Google Inc. All Rights Reserved.
54
* Licensed under the Apache License, Version 2.0 (the "License");
65
* you may not use this file except in compliance with the License.
76
* You may obtain a copy of the License at
@@ -19,7 +18,7 @@
1918
import {Conv2DInfo} from '../ops/conv_util';
2019
// tslint:disable-next-line:max-line-length
2120
import {DataId, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D} from '../tensor';
22-
import {DataType, TypedArray} from '../types';
21+
import {DataType, Rank, ShapeMap, TypedArray} from '../types';
2322

2423
// Required information for all backends.
2524
/** Timing information reported by a backend timer: kernel execution time in milliseconds. */
export interface BackendTimingInfo { kernelMs: number; }
@@ -158,6 +157,10 @@ export interface KernelBackend extends TensorStorage, BackendTimer {
158157
avgPool(x: Tensor4D, convInfo: Conv2DInfo): Tensor4D;
159158
avgPoolBackprop(dy: Tensor4D, x: Tensor4D, convInfo: Conv2DInfo): Tensor4D;
160159

160+
reshape<T extends Tensor, R extends Rank>(x: T, shape: ShapeMap[R]):
161+
Tensor<R>;
162+
cast<T extends Tensor>(x: T, dtype: DataType): T;
163+
161164
tile<T extends Tensor>(x: T, reps: number[]): T;
162165

163166
pad<T extends Tensor>(

src/kernels/backend_cpu.ts

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ import {DataType, DataTypeMap, Rank, TypedArray} from '../types';
3232
import * as util from '../util';
3333

3434
import {BackendTimingInfo, KernelBackend} from './backend';
35+
import * as backend_util from './backend_util';
3536

3637
export class MathBackendCPU implements KernelBackend {
3738
private data = new WeakMap<DataId, DataTypeMap[DataType]>();
@@ -1316,6 +1317,15 @@ export class MathBackendCPU implements KernelBackend {
13161317
return dx.toTensor();
13171318
}
13181319

1320+
cast<T extends Tensor<types.Rank>>(x: T, dtype: DataType): T {
1321+
return backend_util.castTensor(x, dtype, this);
1322+
}
1323+
1324+
reshape<T extends Tensor<types.Rank>, R extends types.Rank>(
1325+
x: T, shape: types.ShapeMap[R]): Tensor<R> {
1326+
return backend_util.reshapeTensor(x, shape);
1327+
}
1328+
13191329
minPool(x: Tensor4D, convInfo: Conv2DInfo): Tensor4D {
13201330
return this.pool(x, convInfo, 'min');
13211331
}

src/kernels/backend_util.ts

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
/**
2+
* @license
3+
* Copyright 2018 Google Inc. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import {Rank, Tensor, util} from '..';
19+
import {ArrayOps} from '../ops/array_ops';
20+
import {DataType, ShapeMap} from '../types';
21+
import {KernelBackend} from './backend';
22+
23+
export function castTensor<T extends Tensor<Rank>>(
24+
x: T, dtype: DataType, backend: KernelBackend): T {
25+
if (!util.hasEncodingLoss(x.dtype, dtype)) {
26+
// We don't change the underlying data, since we cast to higher
27+
// precision.
28+
return Tensor.make(x.shape, {dataId: x.dataId}, dtype) as T;
29+
}
30+
if (dtype === 'int32') {
31+
return backend.int(x);
32+
} else if (dtype === 'bool') {
33+
return backend.notEqual(x, ArrayOps.scalar(0, x.dtype)) as T;
34+
} else {
35+
throw new Error(`Error in Cast: unknown dtype argument (${dtype})`);
36+
}
37+
}
38+
39+
export function reshapeTensor<T extends Tensor<Rank>, R extends Rank>(
40+
x: T, shape: ShapeMap[R]): Tensor<R> {
41+
return Tensor.make(shape, {dataId: x.dataId}, x.dtype);
42+
}

src/kernels/backend_webgl.ts

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@ import {DataId, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D} from '../tensor'
2525
import * as types from '../types';
2626
import {DataType, DataTypeMap, RecursiveArray, TypedArray} from '../types';
2727
import * as util from '../util';
28+
2829
import {KernelBackend} from './backend';
30+
import * as backend_util from './backend_util';
2931
import {ArgMinMaxProgram} from './webgl/argminmax_gpu';
3032
import {AvgPool2DBackpropProgram} from './webgl/avg_pool_backprop_gpu';
3133
import {BatchNormProgram} from './webgl/batchnorm_gpu';
@@ -813,6 +815,15 @@ export class MathBackendWebGL implements KernelBackend {
813815
return this.compileAndRun(avgPoolBackpropProgram, [dy], output) as Tensor4D;
814816
}
815817

818+
cast<T extends Tensor<types.Rank>>(x: T, dtype: DataType): T {
819+
return backend_util.castTensor(x, dtype, this);
820+
}
821+
822+
reshape<T extends Tensor<types.Rank>, R extends types.Rank>(
823+
x: T, shape: types.ShapeMap[R]): Tensor<R> {
824+
return backend_util.reshapeTensor(x, shape);
825+
}
826+
816827
resizeBilinear(
817828
x: Tensor4D, newHeight: number, newWidth: number,
818829
alignCorners: boolean): Tensor4D {

src/ops/array_ops.ts

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
*/
1717

1818
import {doc} from '../doc';
19-
import {ForwardFunc} from '../engine';
19+
// import {ForwardFunc} from '../engine';
2020
import {ENV} from '../environment';
2121
// tslint:disable-next-line:max-line-length
2222
import {Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, TensorBuffer} from '../tensor';
@@ -606,7 +606,7 @@ export class ArrayOps {
606606
return {x: () => dy.reshape(x.shape)};
607607
};
608608
return ENV.engine.runKernel(
609-
backend => Tensor.make(shape, {dataId: x.dataId}, x.dtype), {x}, grad);
609+
backend => backend.reshape(x, shape), {x}, grad);
610610
}
611611

612612
/**
@@ -640,24 +640,11 @@ export class ArrayOps {
640640
@doc({heading: 'Tensors', subheading: 'Transformations'})
641641
@operation
642642
static cast<T extends Tensor>(x: T, dtype: DataType): T {
643-
const forw: ForwardFunc<T> = backend => {
644-
if (!util.hasEncodingLoss(x.dtype, dtype)) {
645-
// We don't change the underlying data, since we cast to higher
646-
// precision.
647-
return Tensor.make(x.shape, {dataId: x.dataId}, dtype) as T;
648-
}
649-
if (dtype === 'int32') {
650-
return backend.int(x);
651-
} else if (dtype === 'bool') {
652-
return backend.notEqual(x, ArrayOps.scalar(0, x.dtype)) as T;
653-
} else {
654-
throw new Error(`Error in Cast: unknown dtype argument (${dtype})`);
655-
}
656-
};
657643
const grad = (dy: T) => {
658644
return {x: () => dy.clone()};
659645
};
660-
return ENV.engine.runKernel(forw, {x}, grad) as T;
646+
return ENV.engine.runKernel(backend => backend.cast(x, dtype), {x}, grad) as
647+
T;
661648
}
662649

663650
/**

0 commit comments

Comments
 (0)