-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathSegSort.h
106 lines (91 loc) · 2.68 KB
/
SegSort.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
#pragma once
#include <utility>
#include <chrono>
using namespace std::chrono;
#include <webgpu/webgpu_cpp.h>
#include "wgpu/WGPUHelpers.h"
// Compile-time offset constant with the value 8192.
// NOTE(review): the unit (bytes vs. elements) and the buffer this offset applies to
// are not visible in this header - presumably related to the copy stage; confirm
// against the implementation/shaders.
constexpr int COPY_STATUS_OFFSET = 8192;
/// Plain parameter record for the segmented sort: nine consecutive uint32_t fields.
///
/// Every field defaults to zero, so a default-constructed Param never carries
/// indeterminate values. Field order and types are unchanged, so the in-memory
/// layout is the same as before and aggregate initialization keeps working.
///
/// NOTE(review): presumably this mirrors a shader-side parameter struct and is
/// copied into a GPU buffer - if so, do not reorder or resize the fields without
/// updating the shaders; confirm against the implementation.
struct Param {
    uint32_t count = 0;               // NOTE(review): presumably the element count - confirm
    uint32_t nt = 0;                  // same name as SegmentedSort::nt - presumably filled from it
    uint32_t vt = 0;                  // same name as SegmentedSort::vt - presumably filled from it
    uint32_t nt2 = 0;                 // same name as SegmentedSort::nt2 - presumably filled from it
    uint32_t num_partitions = 0;
    uint32_t num_segments = 0;
    uint32_t num_ranges = 0;
    uint32_t num_partition_ctas = 0;
    uint32_t max_num_passes = 0;
};
/// Segmented sort driven by WebGPU compute pipelines.
///
/// This header only declares the interface; the method bodies live elsewhere, so
/// the notes below describe the declared signatures and hedge anything that the
/// declarations alone do not establish.
///
/// All scalar bookkeeping members and `params` are zero-initialized in-class
/// (matching `previousCount`), so an object that has not been through Init() yet
/// holds well-defined values instead of indeterminate ones.
class SegmentedSort {
public:
    /// NOTE(review): presumably releases the GPU resources held by this object -
    /// confirm against the implementation.
    void Dispose();

    /// Prepares the sorter for the given device and buffers.
    /// @param device          WebGPU device used for setup.
    /// @param inputBuffer     Buffer holding the data to sort.
    /// @param maxInputSize    Upper bound for the input size.
    ///                        NOTE(review): unit (elements vs. bytes) not visible here.
    /// @param segmentBuffer   Buffer holding the segment description.
    /// @param maxSegmentSize  Upper bound related to the segments.
    ///                        NOTE(review): exact meaning not visible here - confirm.
    void Init(
        const wgpu::Device& device,
        const wgpu::Buffer& inputBuffer,
        uint32_t maxInputSize,
        const wgpu::Buffer& segmentBuffer,
        uint32_t maxSegmentSize
    );

    /// Takes a command encoder; NOTE(review): presumably records the clear pass
    /// (see clearPipeline / clearBindGroup) - confirm.
    void Clear(const wgpu::CommandEncoder& encoder);

    /// Takes the current element and segment counts; NOTE(review): presumably
    /// refreshes `params` and writes it to `paramBuffer` - confirm.
    void Upload(
        const wgpu::Device& device,
        uint32_t count,
        uint32_t segmentCount);

    /// Takes a command encoder, a query set and the current counts;
    /// NOTE(review): presumably records the sort passes, with the query set used
    /// for timing - confirm.
    void Sort(const wgpu::CommandEncoder& encoder, const wgpu::QuerySet& querySet, uint32_t count, uint32_t segmentCount);

private:
    // Fixed tuning constants. nv equals nt * vt (128 * 15 = 1920).
    // NOTE(review): presumably threads per workgroup (nt, nt2) and values per
    // thread (vt) - confirm against the shaders.
    // Being non-static const members, they make the class non-copy-assignable.
    const uint32_t nt = 128;
    const uint32_t nt2 = 64;
    const uint32_t vt = 15;
    const uint32_t nv = 1920;

    // Capacity bookkeeping; zero until set by the implementation.
    uint32_t maxCount = 0;
    uint32_t maxNumPasses = 0;
    uint32_t maxNumSegments = 0;
    uint32_t maxNumCtas = 0;
    uint32_t maxCapacity = 0;
    uint32_t previousCount = 0;

    // Per-stage setup helpers; each receives the device plus the caller-owned
    // buffers that stage needs.
    void InitBuffers(const wgpu::Device& device);
    void InitClear(const wgpu::Device& device);
    void InitPartition(
        const wgpu::Device& device,
        const wgpu::Buffer& inputBuffer
    );
    void InitCopy(
        const wgpu::Device& device,
        const wgpu::Buffer& inputBuffer
    );
    void InitBlock(
        const wgpu::Device& device,
        const wgpu::Buffer& inputBuffer,
        const wgpu::Buffer& segmentsBuffer
    );
    void InitBinarySearch(
        const wgpu::Device& device,
        const wgpu::Buffer& segmentsBuffer
    );
    void InitMerge(
        const wgpu::Device& device,
        const wgpu::Buffer& inputBuffer
    );

    // Buffers held by the sorter.
    wgpu::Buffer inputBufferCopy;
    wgpu::Buffer paramBuffer;
    wgpu::Buffer partitionBuffer;
    wgpu::Buffer passCountBuffer;
    wgpu::Buffer compressedRangesBuffer;
    wgpu::Buffer mergeRangesBuffer;
    wgpu::Buffer copyListBuffer;
    wgpu::Buffer opCounterBuffer;
    wgpu::Buffer mergeListBuffer;

    // Compute pipelines, one per stage (two variants for the block stage).
    wgpu::ComputePipeline blockPipeline[2];
    wgpu::ComputePipeline partitionPipeline;
    wgpu::ComputePipeline mergePipeline;
    wgpu::ComputePipeline binarySearchPipeline;
    wgpu::ComputePipeline copyPipeline;
    wgpu::ComputePipeline clearPipeline;

    // Bind groups for the stages above.
    // NOTE(review): the two-element arrays presumably alternate between the input
    // buffer and inputBufferCopy (ping-pong) - confirm.
    wgpu::BindGroup copyBindGroups[2];
    wgpu::BindGroup binarySearchBindGroup;
    wgpu::BindGroup blockBindGroups[2];
    wgpu::BindGroup partitionBindGroups[2];
    wgpu::BindGroup mergeBindGroups[2];
    wgpu::BindGroup clearBindGroup;

    // CPU-side copy of the sort parameters; value-initialized so all fields start at zero.
    Param params{};
};