forked from NVIDIA-RTX/RTXNTC-Library
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathWeightLayout.h
More file actions
56 lines (45 loc) · 1.25 KB
/
WeightLayout.h
File metadata and controls
56 lines (45 loc) · 1.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: LicenseRef-NvidiaProprietary
*
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
* property and proprietary rights in and to this material, related
* documentation and any modifications thereto. Any use, reproduction,
* disclosure or distribution of this material and related documentation
* without an express license agreement from NVIDIA CORPORATION or
* its affiliates is strictly prohibited.
*/
#pragma once
#include <libntc/ntc.h>
namespace ntc
{
// Forward declaration only: this header refers to GraphicsResources solely through a pointer
// parameter (see MakeQuantizedWeightLayout), so the full definition is not needed here.
class GraphicsResources;
// Identifies the element type stored in a Span.
// The numeric values are written out explicitly; they equal the implicit sequential values.
// NOTE(review): the exact bit widths/encodings are implied by the enumerator names only —
// confirm against the implementation of GetDataTypeSize.
enum class DataType
{
    None  = 0, // No type assigned; this is the default value of Span::type.
    Int8  = 1,
    Int32 = 2,
    FP8   = 3,
    FP16  = 4,
    FP32  = 5
};
// A typed range described by a start offset, a size, and an element type.
// A default-constructed Span is empty: offset 0, size 0, type DataType::None.
// NOTE(review): the units of offset/size are not stated in this header (presumably bytes
// within the buffer sized by WeightLayout::bufferSize) — confirm in the implementation.
struct Span
{
    size_t   offset{0};
    size_t   size{0};
    DataType type{DataType::None};
};
// Collection of Spans describing where the per-layer weights, scales and biases live,
// plus a total buffer size. Every member is zero/empty-initialized by default.
// NOTE(review): the relationship between the per-layer spans and the "combined" spans
// (e.g. whether the combined span covers all per-layer spans) is not visible in this
// header — confirm in the layout-building functions.
struct WeightLayout
{
    // One span per MLP layer (NTC_MLP_LAYERS entries).
    Span weights[NTC_MLP_LAYERS] = {};
    // Presumably the range covering all layers' weights — TODO confirm.
    Span combinedWeights{};
    // One span per MLP layer for scales and for biases.
    Span scales[NTC_MLP_LAYERS] = {};
    Span biases[NTC_MLP_LAYERS] = {};
    // Presumably the range covering all scale and bias data — TODO confirm.
    Span combinedScaleBias{};
    // Total size of the buffer this layout describes; 0 until a layout is built.
    size_t bufferSize{0};
};
// Returns the size associated with the given DataType.
// NOTE(review): units (presumably bytes per element) and the result for DataType::None
// are defined by the implementation, which is not visible in this header — confirm there.
size_t GetDataTypeSize(DataType type);
bool MakeQuantizedWeightLayout(GraphicsResources const* resources,
InferenceWeightType weightType, WeightLayout& outLayout);
// Builds the FP16 weight layout and writes it to outLayout.
// Unlike MakeQuantizedWeightLayout, this takes no resources/weight-type arguments and
// returns no status.
// NOTE(review): which members of outLayout are populated is defined by the implementation,
// which is not visible in this header.
void MakeFP16WeightLayout(WeightLayout& outLayout);
} // namespace ntc