Skip to content

Commit 324e6a5

Browse files
committed
libpasched:
- add disjoint_set data structure - add ostream operator<< for schedule_unit - rewrite simplify_order_cuts to be much more effective - fix strip_dataless_units to generate connected graphs only (the transformation can disconnect the graph !)
1 parent 113b30d commit 324e6a5

File tree

6 files changed

+225
-64
lines changed

6 files changed

+225
-64
lines changed

libpasched/include/libpasched/sched-transform.hpp

-4
Original file line numberDiff line numberDiff line change
@@ -257,10 +257,6 @@ class simplify_order_cuts : public transformation
257257

258258
virtual void transform(schedule_dag& d, const scheduler& s, schedule_chain& c,
259259
transformation_status& status) const;
260-
261-
protected:
262-
void do_transform(schedule_dag& d, const scheduler& s, schedule_chain& c,
263-
transformation_status& status, int level) const;
264260
};
265261

266262
/**

libpasched/include/libpasched/sched-unit.hpp

+2
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,8 @@ class schedule_unit
116116
virtual unsigned internal_register_pressure() const = 0;
117117
};
118118

119+
std::ostream& operator<<(std::ostream& os, const schedule_unit *u);
120+
119121
class chain_schedule_unit : public schedule_unit
120122
{
121123
public:

libpasched/include/libpasched/tools.hpp

+57
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,68 @@
66
#include <functional>
77
#include <string>
88
#include <set>
9+
#include <map>
910
#include <iostream>
1011

1112
namespace PAMAURY_SCHEDULER_NS
1213
{
1314

15+
/* Union-Find structure */
16+
template< typename T >
17+
class disjoint_set
18+
{
19+
public:
20+
disjoint_set() {}
21+
template< typename U >
22+
disjoint_set(const U& u) { set_elements(u); }
23+
template< typename U >
24+
disjoint_set(const U& begin, const U& end) { set_elements(begin, end); }
25+
26+
template< typename U >
27+
void set_elements(const U& u) { set_elements(u.begin(), u.end()); }
28+
29+
template< typename U >
30+
void set_elements(U begin, U end)
31+
{
32+
m_parent.clear();
33+
m_height.clear();
34+
while(begin != end)
35+
{
36+
m_parent[*begin] = *begin;
37+
m_height[*begin] = 1;
38+
++begin;
39+
}
40+
}
41+
42+
void merge(const T& _a, const T& _b)
43+
{
44+
T a = find(_a);
45+
T b = find(_b);
46+
47+
if(m_height[a] < m_height[b])
48+
m_parent[a] = b;
49+
else if(m_height[a] > m_height[b])
50+
m_parent[b] = a;
51+
else
52+
{
53+
m_parent[a] = b;
54+
m_height[a]++;
55+
}
56+
}
57+
58+
T find(const T& a) const
59+
{
60+
if(m_parent[a] == a)
61+
return a;
62+
m_parent[a] = find(m_parent[a]);
63+
return m_parent[a];
64+
}
65+
66+
protected:
67+
mutable std::map< T, T > m_parent;
68+
std::map< T, size_t > m_height;
69+
};
70+
1471
template<typename T>
1572
std::ostream& operator<<(std::ostream& os, const std::set< T >& s)
1673
{

libpasched/src/sched-transform.cpp

+159-57
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@ namespace PAMAURY_SCHEDULER_NS
4747
for(size_t i = __debug_old_size; i < chain.get_unit_count(); i++) \
4848
assert(container_contains(__debug_units, chain.get_unit_at(i)));
4949
#else
50-
#define DEBUG_CHECK_BEGIN_X(dag, chain)
51-
#define DEBUG_CHECK_END_X(chain)
50+
#define DEBUG_CHECK_BEGIN_X(dag, chain) schedule_dag *__cpy = dag.dup();
51+
#define DEBUG_CHECK_END_X(chain) delete __cpy;
5252
#endif
5353

5454
namespace PAMAURY_SCHEDULER_NS
@@ -665,8 +665,6 @@ void smart_fuse_two_units::transform(schedule_dag& dag, const scheduler& s, sche
665665

666666
Lgraph_changed:
667667
modified = true;
668-
if(fused.size() == 1)
669-
break;
670668
}
671669

672670
XTM_FW_STOP(smart_fuse_two_units)
@@ -750,6 +748,7 @@ bool smart_fuse_two_units::weak_fuse(schedule_dag& dag, const schedule_unit *a,
750748
* simplify_order_cuts
751749
*/
752750
XTM_FW_DECLARE(simplify_order_cuts)
751+
XTM_BW_DECLARE(simplify_order_cuts)
753752

754753
simplify_order_cuts::simplify_order_cuts()
755754
{
@@ -759,82 +758,157 @@ simplify_order_cuts::~simplify_order_cuts()
759758
{
760759
}
761760

761+
namespace
762+
{
763+
enum visit_state_t
764+
{
765+
vs_not_processed,
766+
vs_processing,
767+
vs_processed
768+
};
769+
770+
const schedule_unit *collapse_order_cycles(
771+
disjoint_set< const schedule_unit * >& uf,
772+
std::map< const schedule_unit *, std::set< const schedule_unit * > >& order_deps,
773+
schedule_dag& dag,
774+
std::map< const schedule_unit *, visit_state_t >& state,
775+
const schedule_unit *u)
776+
{
777+
if(state[u] == vs_processed)
778+
return 0;
779+
if(state[u] == vs_processing)
780+
return u;
781+
state[u] = vs_processing;
782+
783+
std::set< const schedule_unit * >::iterator it;
784+
for(it = order_deps[u].begin(); it != order_deps[u].end(); ++it)
785+
{
786+
const schedule_unit *ret = collapse_order_cycles(uf, order_deps, dag, state, *it);
787+
if(ret == 0)
788+
continue;
789+
uf.merge(u, ret);
790+
return ret;
791+
}
792+
793+
state[u] = vs_processed;
794+
return 0;
795+
}
796+
797+
bool collapse_order_cycles(
798+
disjoint_set< const schedule_unit * >& uf,
799+
schedule_dag& dag)
800+
{
801+
std::map< const schedule_unit *, visit_state_t > state;
802+
std::map< const schedule_unit *, std::set< const schedule_unit * > > order_deps;
803+
804+
for(size_t i = 0; i < dag.get_deps().size(); i++)
805+
{
806+
const schedule_dep& d = dag.get_deps()[i];
807+
if(!d.is_order())
808+
continue;
809+
if(uf.find(d.from()) != uf.find(d.to()))
810+
order_deps[uf.find(d.from())].insert(uf.find(d.to()));
811+
}
812+
813+
for(size_t i = 0; i < dag.get_units().size(); i++)
814+
if(collapse_order_cycles(uf, order_deps, dag, state, uf.find(dag.get_units()[i])) != 0)
815+
return true;
816+
return false;
817+
}
818+
}
819+
762820
void simplify_order_cuts::transform(schedule_dag& dag, const scheduler& s, schedule_chain& c,
763821
transformation_status& status) const
764822
{
765823
DEBUG_CHECK_BEGIN_X(dag, c)
766824
XTM_FW_START(simplify_order_cuts)
767-
do_transform(dag, s, c, status, 0);
768-
XTM_FW_STOP(simplify_order_cuts)
769-
DEBUG_CHECK_END_X(c)
770-
}
771825

772-
void simplify_order_cuts::do_transform(schedule_dag& dag, const scheduler& s, schedule_chain& c,
773-
transformation_status& status, int level) const
774-
{
775-
if(level == 0)
826+
/* compute a companion graph where nodes are connected components w.r.t data
827+
* deps and edges are order deps */
828+
disjoint_set< const schedule_unit * > uf(dag.get_units());
829+
830+
for(size_t i = 0; i < dag.get_deps().size(); i++)
776831
{
777-
debug() << "---> simplify_order_cuts::transform\n";
778-
status.begin_transformation();
832+
const schedule_dep& d = dag.get_deps()[i];
833+
if(d.is_data())
834+
uf.merge(d.from(), d.to());
779835
}
780-
781-
for(size_t u = 0; u < dag.get_units().size(); u++)
836+
837+
/* Collapse all cycles */
838+
while(collapse_order_cycles(uf, dag))
782839
{
783-
/* pick the first unit */
784-
const schedule_unit *unit = dag.get_units()[u];
785-
/* Compute the largest component C which
786-
* contains U and which is stable by these operations:
787-
* 1) If A is in C and A->B is a data dep, B is in C
788-
* 2) If A is in C and B->A is a data dep, B is in C
789-
* 2) If A is in C and B->A is an order dep, B is in C
790-
*/
791-
std::set< const schedule_unit * > reach = dag.get_reachable(unit,
792-
schedule_dag::rf_include_unit | schedule_dag::rf_follow_preds | schedule_dag::rf_follow_succs_data);
840+
}
793841

794-
/* handle trivial case where the whole graph is reachable */
795-
if(reach.size() == dag.get_units().size())
796-
continue;
842+
/* build reduced dag */
843+
std::map< const schedule_unit *, std::set< const schedule_unit * > > sets;
844+
for(size_t i = 0; i < dag.get_units().size(); i++)
845+
sets[uf.find(dag.get_units()[i])].insert(dag.get_units()[i]);
797846

798-
/* extract this subgraph for further analysis */
799-
schedule_dag *top = dag.dup_subgraph(reach);
800-
dag.remove_units(set_to_vector(reach));
847+
/* special trivial case where there is not cut */
848+
if(sets.size() == 1)
849+
{
850+
status.set_modified_graph(false);
851+
status.set_deadlock(false);
852+
status.set_junction(false);
801853

802-
if(level == 0)
803-
{
804-
status.set_modified_graph(true);
805-
status.set_junction(true); /* don't set deadlock then ! */
806-
}
807-
/* recursively transform top */
808-
do_transform(*top, s, c, status, level + 1);
809-
/* and then bottom */
810-
do_transform(dag, s, c, status, level + 1);
854+
XTM_FW_STOP(simplify_order_cuts)
811855

812-
if(level == 0)
813-
{
814-
status.end_transformation();
815-
debug() << "<--- simplify_order_cuts::transform\n";
816-
}
856+
s.schedule(dag, c);
857+
858+
DEBUG_CHECK_END_X(c)
817859
return;
818860
}
819861

820-
if(level == 0)
862+
/* otherwise do some work */
863+
status.set_modified_graph(true);
864+
status.set_deadlock(false);
865+
status.set_junction(true);
866+
867+
std::map< const schedule_unit *, std::set< const schedule_unit * > >::iterator it;
868+
for(it = sets.begin(); it != sets.end(); ++it)
821869
{
822-
status.set_modified_graph(false);
823-
status.set_deadlock(false);
824-
status.set_junction(false);
870+
//std::cout << " " << it->second << "\n";
871+
/* schedule each subgraph */
872+
schedule_dag *sub = dag.dup_subgraph(it->second);
873+
generic_schedule_chain gsc;
874+
875+
XTM_FW_STOP(simplify_order_cuts)
876+
877+
s.schedule(*sub, gsc);
878+
879+
XTM_FW_START(simplify_order_cuts)
880+
881+
/* create a chain unit and collapse in the reduced sub graph */
882+
chain_schedule_unit *csu = new chain_schedule_unit;
883+
csu->get_chain() = gsc.get_units();
884+
csu->set_internal_register_pressure(c.compute_rp_against_dag(*sub));
885+
886+
dag.collapse_subgraph(it->second, csu);
887+
/* release memory */
888+
delete sub;
825889
}
826890

827891
XTM_FW_STOP(simplify_order_cuts)
828-
/* otherwise, schedule the whole graph */
829-
s.schedule(dag, c);
830892

831-
XTM_FW_START(simplify_order_cuts)
893+
/* Schedule the reduced graph with the rand scheduler because there are only
894+
* order deps. We can't schedule with s otherwise we'll loop */
895+
generic_schedule_chain gsc;
896+
rand_scheduler rs;
897+
rs.schedule(dag, gsc);
898+
899+
XTM_BW_START(simplify_order_cuts)
832900

833-
if(level == 0)
901+
/* expand back chain units */
902+
for(size_t i = 0; i < gsc.get_unit_count(); i++)
834903
{
835-
status.end_transformation();
836-
debug() << "<--- simplify_order_cuts::transform\n";
904+
const chain_schedule_unit *csu = static_cast< const chain_schedule_unit * >(gsc.get_unit_at(i));
905+
c.insert_units_at(c.get_unit_count(), csu->get_chain());
906+
delete csu;
837907
}
908+
909+
XTM_BW_STOP(simplify_order_cuts)
910+
911+
DEBUG_CHECK_END_X(c)
838912
}
839913

840914
/**
@@ -1528,6 +1602,33 @@ strip_dataless_units::~strip_dataless_units()
15281602
{
15291603
}
15301604

1605+
namespace
1606+
{
1607+
void split_cc_and_schedule(const scheduler& s, schedule_dag& dag, schedule_chain& c, transformation_status& status)
1608+
{
1609+
while(dag.get_units().size() > 0)
1610+
{
1611+
/* get reachable set of the first root */
1612+
std::set< const schedule_unit * > set =
1613+
dag.get_reachable(dag.get_roots()[0],
1614+
schedule_dag::rf_follow_preds | schedule_dag::rf_follow_succs | schedule_dag::rf_include_unit);
1615+
/* stop if it's the entire graph */
1616+
if(set.size() == dag.get_units().size())
1617+
break;
1618+
/* extract subgraph */
1619+
status.set_modified_graph(true);
1620+
status.set_junction(true);
1621+
1622+
schedule_dag *sub = dag.dup_subgraph(set);
1623+
s.schedule(*sub, c);
1624+
delete sub;
1625+
/* delete from the graph */
1626+
dag.remove_units(set_to_vector(set));
1627+
}
1628+
s.schedule(dag, c);
1629+
}
1630+
}
1631+
15311632
void strip_dataless_units::transform(schedule_dag& dag, const scheduler& s, schedule_chain& c,
15321633
transformation_status& status) const
15331634
{
@@ -1609,7 +1710,8 @@ void strip_dataless_units::transform(schedule_dag& dag, const scheduler& s, sche
16091710
status.set_modified_graph(stripped.size() > 0);
16101711
status.set_deadlock(false);
16111712
status.set_junction(false);
1612-
s.schedule(dag, c);
1713+
1714+
split_cc_and_schedule(s, dag, c, status);
16131715

16141716
XTM_BW_START(strip_dataless_units)
16151717

libpasched/src/sched-unit.cpp

+5
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,11 @@ schedule_unit::~schedule_unit()
2727
{
2828
}
2929

30+
std::ostream& operator<<(std::ostream& os, const schedule_unit *u)
31+
{
32+
return os << u->to_string();
33+
}
34+
3035
/**
3136
* chain_schedule_unit
3237
*/

0 commit comments

Comments
 (0)