// ResizeNearest.hpp (forked from onnx/onnx-tensorrt)
/*
* Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
* THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
*/
#pragma once
#include "plugin.hpp"
#include "serialize.hpp"
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>
/// TensorRT plugin registered under the type name "ResizeNearest" with a
/// single output.
///
/// State: one float scale factor per resized dimension (`_scale`, of which the
/// first `_ndims` entries are meaningful). Serialized layout, in order:
/// base-class state, `_ndims`, then the *entire* fixed-size `_scale` array.
/// `_output_dims` is not serialized. NOTE(review): it is presumably recomputed
/// in initialize(), which is defined out of line; confirm there.
class ResizeNearestPlugin final : public onnx2trt::Plugin {
  // Number of valid entries at the front of `_scale`.
  int _ndims = 0;
  // Per-dimension scale factors. The whole array is serialized, so unused
  // trailing entries are zero-initialized. Otherwise indeterminate memory would
  // be written into the serialized engine, making its bytes non-deterministic.
  float _scale[nvinfer1::Dims::MAX_DIMS] = {};
  // Cached output shape; not part of the serialized state.
  nvinfer1::Dims _output_dims{};
protected:
  /// Restores the state written by serialize(): base-class state first, then
  /// `_ndims`, then the full `_scale` array.
  void deserialize(void const* serialData, size_t serialLength) {
    deserializeBase(serialData, serialLength);
    deserialize_value(&serialData, &serialLength, &_ndims);
    deserialize_value(&serialData, &serialLength, &_scale);
    // Catch corrupt or mismatched buffers in debug builds.
    assert(_ndims >= 0 && _ndims <= nvinfer1::Dims::MAX_DIMS);
  }
  /// Byte count produced by serialize(); must stay in sync with it and with
  /// deserialize().
  size_t getSerializationSize() override {
    return serialized_size(_ndims) + serialized_size(_scale) + getBaseSerializationSize();
  }
  /// Writes base-class state, then `_ndims`, then the full `_scale` array.
  void serialize(void *buffer) override {
    serializeBase(buffer);
    serialize_value(&buffer, _ndims);
    serialize_value(&buffer, _scale);
  }
public:
  /// Builds the plugin from per-dimension scale factors.
  ///
  /// @param scale  One scale factor per resized dimension; must contain at
  ///               most nvinfer1::Dims::MAX_DIMS entries.
  ResizeNearestPlugin(std::vector<float> const& scale) {
    assert(scale.size() <= static_cast<size_t>(nvinfer1::Dims::MAX_DIMS));
    // The assert is compiled out under NDEBUG. Clamp the copy so that
    // oversized input cannot write past the end of the fixed-size `_scale`
    // buffer.
    size_t const count =
        std::min(scale.size(), static_cast<size_t>(nvinfer1::Dims::MAX_DIMS));
    _ndims = static_cast<int>(count);
    std::copy_n(scale.begin(), count, _scale);
  }
  /// Reconstructs the plugin from a buffer previously produced by serialize().
  ResizeNearestPlugin(void const* serialData, size_t serialLength) {
    this->deserialize(serialData, serialLength);
  }
  virtual const char* getPluginType() const override { return "ResizeNearest"; }
  virtual int getNbOutputs() const override { return 1; }
  // The following members are defined out of line (not visible in this header).
  virtual nvinfer1::Dims getOutputDimensions(int index,
                                             const nvinfer1::Dims *inputs, int nbInputDims) override;
  virtual int initialize() override;
  int enqueue(int batchSize,
              const void *const *inputs, void **outputs,
              void *workspace, cudaStream_t stream) override;
};