-
Notifications
You must be signed in to change notification settings - Fork 14
/
Copy pathprpack_solver.h
177 lines (172 loc) · 6.77 KB
/
prpack_solver.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
#ifndef PRPACK_SOLVER
#define PRPACK_SOLVER
#include "prpack_base_graph.h"
#include "prpack_csc.h"
#include "prpack_csr.h"
#include "prpack_edge_list.h"
#include "prpack_preprocessed_ge_graph.h"
#include "prpack_preprocessed_gs_graph.h"
#include "prpack_preprocessed_scc_graph.h"
#include "prpack_preprocessed_schur_graph.h"
#include "prpack_result.h"
// Compile-time iteration cap for the solvers (judging by its name).
// NOTE(review): the macro is only defined here; its use lives in the
// implementation file. Presumably the iterative solvers stop after this many
// iterations even if `tol` has not been reached -- confirm there.
// TODO Make this a user configurable variable
#define PRPACK_SOLVER_MAX_ITERS 1000000
namespace prpack {

    // Solver class.
    //
    // Builds internal graph representations from one input graph and
    // produces a prpack_result via solve(). The input can be a CSC matrix, a
    // 64-bit-index CSC matrix, a CSR matrix, an edge list, an existing
    // prpack_base_graph, or a file name plus a format string.
    //
    // The private static solve_via_* functions are the individual algorithms.
    // solve() presumably selects among them using its `method` argument; the
    // dispatch lives in the implementation file.
    //
    // NOTE(review): the abbreviations below are read from identifier names
    // only and should be confirmed against prpack_solver.cpp:
    //   ge    ~ Gaussian elimination on a dense matrix
    //   gs    ~ Gauss-Seidel iteration on a sparse graph
    //   schur ~ Schur-complement-style preprocessing
    //   scc   ~ strongly-connected-component preprocessing
    //
    // Instances are not copyable (see the private copy operations below).
    class prpack_solver {
        private:
            // instance variables
            // NOTE(review): read_time presumably records the time spent
            // reading/building the input graph -- confirm in the .cpp.
            double read_time;
            // Input graph plus lazily/optionally built preprocessed views of
            // it, one per algorithm family. The destructor presumably deletes
            // these -- confirm in the .cpp.
            prpack_base_graph* bg;
            prpack_preprocessed_ge_graph* geg;
            prpack_preprocessed_gs_graph* gsg;
            prpack_preprocessed_schur_graph* sg;
            prpack_preprocessed_scc_graph* sccg;

            // Copying is disabled. The class stores raw pointers and has a
            // user-declared destructor. An implicitly generated copy
            // constructor or assignment would copy those pointers shallowly,
            // so destroying both objects could release the same graphs
            // twice. These are declared private and intentionally left
            // undefined (the pre-C++11 equivalent of `= delete`), so any
            // accidental copy fails to compile or link.
            prpack_solver(const prpack_solver&);
            prpack_solver& operator=(const prpack_solver&);

            // methods

            // Shared constructor setup (body in the implementation file).
            void initialize();

            // Dense-matrix solvers. `matrix` is a dense num_vs-sized system
            // (layout defined by prpack_preprocessed_ge_graph -- confirm).
            // The plain variant takes a single vector `uv`. The _uv variant
            // takes separate `u` and `v` plus `d`.
            static prpack_result* solve_via_ge(
                    const double alpha,
                    const double tol,
                    const int num_vs,
                    const double* matrix,
                    const double* uv);
            static prpack_result* solve_via_ge_uv(
                    const double alpha,
                    const double tol,
                    const int num_vs,
                    const double* matrix,
                    const double* d,
                    const double* u,
                    const double* v);

            // Sparse solvers over an edge array given by heads/tails
            // (num_es entries). The _err variant takes no `vals`/`d` arrays.
            static prpack_result* solve_via_gs(
                    const double alpha,
                    const double tol,
                    const int num_vs,
                    const int num_es,
                    const int* heads,
                    const int* tails,
                    const double* vals,
                    const double* ii,
                    const double* d,
                    const double* num_outlinks,
                    const double* u,
                    const double* v);
            static prpack_result* solve_via_gs_err(
                    const double alpha,
                    const double tol,
                    const int num_vs,
                    const int num_es,
                    const int* heads,
                    const int* tails,
                    const double* ii,
                    const double* num_outlinks,
                    const double* u,
                    const double* v);

            // Sparse solvers on a reordered graph. `encoding`/`decoding` map
            // between original and reordered vertex ids. The counts of
            // vertices without in-/out-links are passed separately.
            static prpack_result* solve_via_schur_gs(
                    const double alpha,
                    const double tol,
                    const int num_vs,
                    const int num_no_in_vs,
                    const int num_no_out_vs,
                    const int num_es,
                    const int* heads,
                    const int* tails,
                    const double* vals,
                    const double* ii,
                    const double* d,
                    const double* num_outlinks,
                    const double* uv,
                    const int* encoding,
                    const int* decoding,
                    const bool should_normalize = true);
            static prpack_result* solve_via_schur_gs_uv(
                    const double alpha,
                    const double tol,
                    const int num_vs,
                    const int num_no_in_vs,
                    const int num_no_out_vs,
                    const int num_es,
                    const int* heads,
                    const int* tails,
                    const double* vals,
                    const double* ii,
                    const double* d,
                    const double* num_outlinks,
                    const double* u,
                    const double* v,
                    const int* encoding,
                    const int* decoding);

            // Sparse solvers on a component-partitioned graph. Edges are
            // split into those inside a component and those between
            // components. `divisions` (num_comps entries) presumably gives
            // the component boundaries in the reordered numbering -- confirm.
            static prpack_result* solve_via_scc_gs(
                    const double alpha,
                    const double tol,
                    const int num_vs,
                    const int num_es_inside,
                    const int* heads_inside,
                    const int* tails_inside,
                    const double* vals_inside,
                    const int num_es_outside,
                    const int* heads_outside,
                    const int* tails_outside,
                    const double* vals_outside,
                    const double* ii,
                    const double* d,
                    const double* num_outlinks,
                    const double* uv,
                    const int num_comps,
                    const int* divisions,
                    const int* encoding,
                    const int* decoding,
                    const bool should_normalize = true);
            static prpack_result* solve_via_scc_gs_uv(
                    const double alpha,
                    const double tol,
                    const int num_vs,
                    const int num_es_inside,
                    const int* heads_inside,
                    const int* tails_inside,
                    const double* vals_inside,
                    const int num_es_outside,
                    const int* heads_outside,
                    const int* tails_outside,
                    const double* vals_outside,
                    const double* ii,
                    const double* d,
                    const double* num_outlinks,
                    const double* u,
                    const double* v,
                    const int num_comps,
                    const int* divisions,
                    const int* encoding,
                    const int* decoding);

            // Dense linear-system helper. `A` and `b` are non-const, so they
            // may be overwritten in place. Presumably solves an sz-by-sz
            // system by Gaussian elimination -- confirm.
            static void ge(const int sz, double* A, double* b);
            // Operates in place on x[0..length). Presumably rescales it to
            // unit sum -- confirm.
            static void normalize(const int length, double* x);
            // Merges two partial results (ret_u, ret_v) into a new result.
            // NOTE(review): ownership of the inputs and of the returned
            // object is defined in the implementation -- confirm.
            static prpack_result* combine_uv(
                    const int num_vs,
                    const double* d,
                    const double* num_outlinks,
                    const int* encoding,
                    const double alpha,
                    const prpack_result* ret_u,
                    const prpack_result* ret_v);

        public:
            // constructors
            prpack_solver(const prpack_csc* g);
            prpack_solver(const prpack_int64_csc* g);
            prpack_solver(const prpack_csr* g);
            prpack_solver(const prpack_edge_list* g);
            // Takes a non-const pointer. NOTE(review): whether the solver
            // takes ownership of `g` is decided in the implementation --
            // confirm before deleting `g` yourself.
            prpack_solver(prpack_base_graph* g);
            // Reads the graph from `filename`, interpreted per `format`.
            prpack_solver(const char* filename, const char* format, const bool weighted);

            // destructor
            ~prpack_solver();

            // methods

            // Number of vertices in the input graph.
            int get_num_vs();

            // Computes a result for the stored graph.
            //   alpha  - presumably the PageRank damping factor; confirm.
            //   tol    - convergence tolerance passed to the chosen solver.
            //   method - algorithm selector; accepted strings are defined in
            //            the implementation file.
            // NOTE(review): the caller presumably owns (and must delete) the
            // returned prpack_result -- confirm.
            prpack_result* solve(const double alpha, const double tol, const char* method);

            // As above, with explicit vectors `u` and `v` (presumably the
            // personalization/teleport vectors; whether null is accepted is
            // defined in the implementation -- confirm).
            prpack_result* solve(
                    const double alpha,
                    const double tol,
                    const double* u,
                    const double* v,
                    const char* method);
    };

}  // namespace prpack
#endif