@@ -54,21 +54,19 @@ op_sp_sum::apply(Mat<typename T1::elem_type>& out, const mtSpReduceOp<typename T
5454
5555 const uword N = p.get_n_nonzero ();
5656
57- for (uword i=0 ; i < N; ++i)
58- {
59- out_mem[it.col ()] += (*it);
60- ++it;
61- }
57+ for (uword i=0 ; i < N; ++i) { out_mem[it.col ()] += (*it); ++it; }
6258 }
6359 else
6460 {
61+ const eT* values = p.get_values ();
62+ const uword* colptrs = p.get_col_ptrs ();
63+
6564 for (uword col = 0 ; col < p_n_cols; ++col)
6665 {
67- out_mem[col] = arrayops::accumulate
68- (
69- &p.get_values ()[p.get_col_ptrs ()[col]],
70- p.get_col_ptrs ()[col + 1 ] - p.get_col_ptrs ()[col]
71- );
66+ const eT* coldata = &(values[ colptrs[col] ]);
67+ const uword N = colptrs[col + 1 ] - colptrs[col];
68+
69+ out_mem[col] = arrayops::accumulate (coldata, N);
7270 }
7371 }
7472 }
@@ -79,11 +77,69 @@ op_sp_sum::apply(Mat<typename T1::elem_type>& out, const mtSpReduceOp<typename T
7977
8078 const uword N = p.get_n_nonzero ();
8179
82- for (uword i=0 ; i < N; ++i)
80+ for (uword i=0 ; i < N; ++i) { out_mem[it.row ()] += (*it); ++it; }
81+ }
82+ }
83+
84+
85+
86+ template <typename T1>
87+ inline
88+ void
89+ op_sp_sum::apply (Mat<typename T1::elem_type>& out, const mtSpReduceOp<typename T1::elem_type, SpOp<T1, spop_square>, op_sp_sum>& in)
90+ {
91+ arma_debug_sigprint ();
92+
93+ typedef typename T1::elem_type eT;
94+
95+ const uword dim = in.aux_uword_a ;
96+
97+ arma_conform_check ( (dim > 1 ), " sum(): parameter 'dim' must be 0 or 1" );
98+
99+ const SpProxy<T1> p (in.m .m );
100+
101+ const uword p_n_rows = p.get_n_rows ();
102+ const uword p_n_cols = p.get_n_cols ();
103+
104+ if (dim == 0 ) { out.zeros (1 , p_n_cols); }
105+ if (dim == 1 ) { out.zeros (p_n_rows, 1 ); }
106+
107+ if (p.get_n_nonzero () == 0 ) { return ; }
108+
109+ eT* out_mem = out.memptr ();
110+
111+ if (dim == 0 ) // find the sum of squares in each column
112+ {
113+ if (SpProxy<T1>::use_iterator)
83114 {
84- out_mem[it.row ()] += (*it);
85- ++it;
115+ typename SpProxy<T1>::const_iterator_type it = p.begin ();
116+
117+ const uword N = p.get_n_nonzero ();
118+
119+ for (uword i=0 ; i < N; ++i) { const eT val = (*it); out_mem[it.col ()] += (val*val); ++it; }
86120 }
121+ else
122+ {
123+ const eT* values = p.get_values ();
124+ const uword* colptrs = p.get_col_ptrs ();
125+
126+ for (uword col = 0 ; col < p_n_cols; ++col)
127+ {
128+ const eT* coldata = &(values[ colptrs[col] ]);
129+ const uword N = colptrs[col + 1 ] - colptrs[col];
130+
131+ out_mem[col] = op_dot::direct_dot (N, coldata, coldata);
132+ }
133+ }
134+ }
135+ else
136+ if (dim == 1 ) // find the sum of squares in each row
137+ {
138+ typename SpProxy<T1>::const_iterator_type it = p.begin ();
139+
140+ const uword N = p.get_n_nonzero ();
141+
142+ for (uword i=0 ; i < N; ++i) { const eT val = (*it); out_mem[it.row ()] += (val*val); ++it; }
87143 }
88144 }
89145
0 commit comments