pointwise scheduler fails to validate reference tv #3513

Merged: 77 commits merged into main from pw_scheduler_reference_find_patch on Dec 18, 2024.

Commits (77)
f0ce0e3  will this work? (jjsjann123, Dec 2, 2024)
70e31bf  errr (jjsjann123, Dec 2, 2024)
5f09e36  missed a few renaming (jjsjann123, Dec 2, 2024)
9ad9edb  WIP (jjsjann123, Dec 2, 2024)
6540201  test added (jjsjann123, Dec 2, 2024)
89dc741  Merge branch 'main' into pw_scheduler_reference_find_patch (jjsjann123, Dec 4, 2024)
ed56c75  WIP (jjsjann123, Dec 4, 2024)
f1e7e0a  WIP (jjsjann123, Dec 4, 2024)
9d174c9  declaration (jjsjann123, Dec 4, 2024)
bf425eb  WIP (jjsjann123, Dec 4, 2024)
aef13ac  WIP (jjsjann123, Dec 4, 2024)
a9ae516  refactor the traversal (jjsjann123, Dec 4, 2024)
d9e8dc0  WIP (jjsjann123, Dec 4, 2024)
7333806  scratch that, it's getting out of hand (jjsjann123, Dec 4, 2024)
f6ad363  Revert "scratch that, it's getting out of hand" (jjsjann123, Dec 4, 2024)
cef0b83  try focus on expanded dimensions (jjsjann123, Dec 4, 2024)
a557a8b  wip (jjsjann123, Dec 5, 2024)
65cb621  Merge branch 'main' into pw_scheduler_reference_find_patch (jjsjann123, Dec 5, 2024)
f88ebf7  lintrunner (jjsjann123, Dec 5, 2024)
ea89b69  comment added (jjsjann123, Dec 5, 2024)
0fc0dc1  fixing (jjsjann123, Dec 5, 2024)
07797a4  Merge branch 'main' into pw_scheduler_reference_find_patch (jjsjann123, Dec 5, 2024)
94e2ddf  Apply suggestions from code review (jjsjann123, Dec 6, 2024)
3e2b43e  Merge branch 'main' into pw_scheduler_reference_find_patch (jjsjann123, Dec 6, 2024)
63284b6  reverting unintended changes (jjsjann123, Dec 6, 2024)
e39ec58  adding unit tests (jjsjann123, Dec 6, 2024)
fa4d8ab  WIP (jjsjann123, Dec 6, 2024)
66bc533  unit test (jjsjann123, Dec 6, 2024)
26054c3  WIP (jjsjann123, Dec 6, 2024)
3d2b926  WIP, seems to found another issue here (jjsjann123, Dec 7, 2024)
bb659f8  revert unsafe exception (jjsjann123, Dec 7, 2024)
45bb785  moving tests to uniform (jjsjann123, Dec 7, 2024)
b744086  Revert "moving tests to uniform" (jjsjann123, Dec 7, 2024)
3a16c65  do not use random for validation (jjsjann123, Dec 7, 2024)
3b9c97f  fixing tests (jjsjann123, Dec 7, 2024)
54176a7  fixing tests and comments (jjsjann123, Dec 7, 2024)
5b668d6  skip the check for transpose scheduler to ensure no performance regre… (jjsjann123, Dec 9, 2024)
3112ebd  allowing unmatched broadcast dimension (jjsjann123, Dec 9, 2024)
db4cabc  CLANGFORMAT (jjsjann123, Dec 9, 2024)
a60bdc6  Merge branch 'main' into pw_scheduler_reference_find_patch (jjsjann123, Dec 9, 2024)
55ddfc8  TYPO (jjsjann123, Dec 9, 2024)
0972ca2  Merge remote-tracking branch 'origin/pw_scheduler_reference_find_patc… (jjsjann123, Dec 9, 2024)
3e4cf86  Merge branch 'main' into pw_scheduler_reference_find_patch (jjsjann123, Dec 9, 2024)
325f5bb  lifting the broadcast exception, in case we change how expand is mode… (jjsjann123, Dec 9, 2024)
8a92a15  Merge remote-tracking branch 'origin/pw_scheduler_reference_find_patc… (jjsjann123, Dec 9, 2024)
880d73e  Merge branch 'main' into pw_scheduler_reference_find_patch (jjsjann123, Dec 9, 2024)
22a7561  fixing false negative tests (jjsjann123, Dec 9, 2024)
2bc97bf  Merge branch 'main' into pw_scheduler_reference_find_patch (jjsjann123, Dec 10, 2024)
49767da  WIP addressing review comments (jjsjann123, Dec 10, 2024)
2cb3372  typo (jjsjann123, Dec 10, 2024)
73c66f8  refactor the logic per review comments/discussions (jjsjann123, Dec 11, 2024)
dbd5995  fixing signature (jjsjann123, Dec 11, 2024)
b7f628f  Merge branch 'main' into pw_scheduler_reference_find_patch (jjsjann123, Dec 11, 2024)
4ba5baa  updating tests, removing asserts (jjsjann123, Dec 11, 2024)
1d021c7  Merge remote-tracking branch 'origin/pw_scheduler_reference_find_patc… (jjsjann123, Dec 11, 2024)
742f7f3  removing checks that are not exposed by scheduler (jjsjann123, Dec 11, 2024)
145d902  renaming things (jjsjann123, Dec 11, 2024)
19291c8  err somehow I missed this one (jjsjann123, Dec 11, 2024)
526e6b7  updating tests (jjsjann123, Dec 12, 2024)
0ecc1f6  adding another test (jjsjann123, Dec 12, 2024)
6abaa1d  test fixing (jjsjann123, Dec 12, 2024)
e8a4ddd  fixing tests (jjsjann123, Dec 12, 2024)
6ada657  CLANGFORMAT (jjsjann123, Dec 12, 2024)
d797df8  removing python test since it's already covered in cpp test (jjsjann123, Dec 12, 2024)
25362cd  oops, assert was placed in the wrong spot (jjsjann123, Dec 12, 2024)
a129e72  CLANGFORMAT (jjsjann123, Dec 12, 2024)
b7f2efb  adding naoya's example (jjsjann123, Dec 12, 2024)
307569f  I was padding the wrong dimension here (jjsjann123, Dec 12, 2024)
af315c7  made a small refactor to avoid regression (jjsjann123, Dec 12, 2024)
7668d4e  Merge branch 'main' into pw_scheduler_reference_find_patch (jjsjann123, Dec 12, 2024)
d46323c  committing something so I can trigger CI again (jjsjann123, Dec 12, 2024)
c70f160  Merge remote-tracking branch 'origin/pw_scheduler_reference_find_patc… (jjsjann123, Dec 12, 2024)
4eee40b  Merge branch 'main' into pw_scheduler_reference_find_patch (jjsjann123, Dec 17, 2024)
d7d62a9  Merge remote-tracking branch 'origin/main' into HEAD (jjsjann123, Dec 17, 2024)
9a75f04  fixing renaming and updating scope (jjsjann123, Dec 17, 2024)
b9b030b  fixing test include (jjsjann123, Dec 17, 2024)
89b7c20  missed another scope change (jjsjann123, Dec 17, 2024)
4 changes: 2 additions & 2 deletions csrc/scheduler/pointwise.cpp
@@ -493,11 +493,11 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) {

int64_t max_dims = 0;
for (auto inp : input_tvs) {
- max_dims = std::max(pointwise_utils::nRootDims(inp), max_dims);
+ max_dims = std::max(pointwise_utils::nLogicalDims(inp), max_dims);
}

for (auto out : output_tvs) {
- max_dims = std::max(pointwise_utils::nRootDims(out), max_dims);
+ max_dims = std::max(pointwise_utils::nLogicalDims(out), max_dims);
}

// If everything is zero dim tensors, just return.
2 changes: 1 addition & 1 deletion csrc/scheduler/pointwise_utils.cpp
@@ -19,7 +19,7 @@ TensorView* PointwiseDomainMap::findReferenceTensor(
if (isValidReference(output_tv) &&
hasMinimumSize(output_tv, minimum_num_axes) &&
!output_tv->isFusionInput()) {
- int64_t n_dims = pointwise_utils::nRootDims(output_tv);
+ int64_t n_dims = nLogicalDims(output_tv);
if (n_dims > max_dims) {
result = output_tv;
max_dims = n_dims;
2 changes: 1 addition & 1 deletion csrc/scheduler/pointwise_utils.h
@@ -19,7 +19,7 @@ namespace pointwise_utils {

// Returns the number of non-reduction/non-broadcast/non-device dims in the
// logical domain
-inline int64_t nRootDims(const TensorView* tv) {
+inline int64_t nLogicalDims(const TensorView* tv) {
auto logical_dom = tv->getLogicalDomain();
int64_t tv_n_dims = 0;
for (auto dim : logical_dom) {
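The diff view truncates the rest of this function. A plausible completion, consistent with the comment above — a sketch only; the isReduction()/isBroadcast()/isDeviceDim() predicates are assumptions about the IterDomain API rather than lines from this diff:

// Sketch (assumed API): count only logical-domain IDs that are not
// reduction, broadcast, or device dims, matching the comment above.
inline int64_t nLogicalDims(const TensorView* tv) {
  auto logical_dom = tv->getLogicalDomain();
  int64_t tv_n_dims = 0;
  for (auto dim : logical_dom) {
    if (dim->isReduction() || dim->isBroadcast() || dim->isDeviceDim()) {
      continue;
    }
    tv_n_dims++;
  }
  return tv_n_dims;
}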
147 changes: 146 additions & 1 deletion csrc/scheduler/tools/domain_map.cpp
@@ -137,6 +137,137 @@ bool DomainMap::areAllInputIdsMappedTo(TensorView* input_tv, TensorView* tv)
return in_concrete_ids.empty();
}

// Note: ideally we would want to check that reference_tv contains all iter
// domains in target_tv, so that transformations applied on reference_tv can be
// propagated to target_tv. But we don't have an easy way to check that.
// Instead, this function checks that all source iter domains involved in
// transformations on target_tv are covered by reference_tv. Source iter
// domains of a TensorView are IDs that don't have a definition and are
// producers of the IDs on the logical domain of the given TensorView.
//
// ------
//
// e.g. 0:
// T34 [i0, i1]
// T185 [i0, b2, i1] = broadcast(T34)
// T192 [i0, b3(ex), i1] = expand(T185)
// T198 [i0, b3(ex)*i1] = reshape(T192)
// output(T34)
// output(T198)
//
// If we take T34 as reference_tv, then T198 is the target_tv. We can't replay
// T34's transform of merging all the dimensions onto T198, since b3(ex)*i1
// can't be reversed. The check in this function gives us T34 with sources
// i0, i1, while T198 has sources i0, b3, i1; b3 isn't contained in T34's
// sources, hence we reject this reference_tv.
//
// ------
//
// e.g. 1:
// T0 [i0, i1]
// T1 [i2, i0, i1]
// T2 [i0*i1] = reshape(T0)
// T3 [b3, i0, i1] = broadcast(T0)
// T4 [i2, i0, i1] = add(T1, T3)
// output(T2)
// output(T4)
//
// The example above should be able to pick T4 as reference_tv: T2's sources
// i0, i1 are both contained in T4's sources, so this example can be scheduled
// as a single fusion.
bool DomainMap::areAllTargetIdsCoveredBy(
TensorView* target_tv,
TensorView* reference_tv) const {
auto get_source_iter_domains = [this](TensorView* tv) {
// Traverse back to collect all disjoint-set producer IDs for each ID in the
// logical domain of tv.
VectorOfUniqueEntries<std::shared_ptr<VectorOfUniqueEntries<IterDomain*>>>
all_producer_sets;
std::for_each(
tv->getLogicalDomain().begin(),
tv->getLogicalDomain().end(),
[&](IterDomain* tv_logical_id) {
all_producer_sets.pushBack(
ca_map_.disjointSetOf(tv_logical_id, IdMappingMode::EXACT));
});
all_producer_sets.pushBack(
ca_map_.getAllDisjointSetProducers(all_producer_sets));

std::vector<IterDomain*> source_ids;
// Filter for producer IDs with no definition; these are the source iter
// domains.
std::for_each(
all_producer_sets.vector().begin(),
all_producer_sets.vector().end(),
[&source_ids,
this](const std::shared_ptr<VectorOfUniqueEntries<IterDomain*>>&
producer_set_ptr) {
IterDomain* producer_id = producer_set_ptr->front();
if (ca_map_.uniqueExactDefinitions(producer_id).empty()) {
source_ids.push_back(producer_id);
}
});
return source_ids;
};

// This set contains all source iter domains covered by reference_tv, so it's
// safe for target_tv to have them.
std::unordered_set<IterDomain*> covered_source_ids;
for (IterDomain* source_id_ref : get_source_iter_domains(reference_tv)) {
covered_source_ids.insert(source_id_ref);
}
// It's safe to have unmapped broadcast IterDomains. There are quite a few
// tests expecting the pointwise scheduler to handle this pattern.
for (IterDomain* id_out : target_tv->getLogicalDomain()) {
if (id_out->isBroadcast()) {
NVF_ERROR(
id_out->definition() == nullptr ||
id_out->definition()->isA<Resize>());

// Note that ideally we should also be able to handle merge/split on
// broadcast IDs, so we should really move this skip inside the loop below
// over `get_source_iter_domains(target_tv)` and skip broadcast source IDs.
// Currently we have the issue that split/merge does not preserve expanded
// broadcasts; see issue: https://github.com/NVIDIA/Fuser/issues/1126
covered_source_ids.insert(id_out);
}
}
// Note: there are certain cases where it's safe to have dangling IDs,
// e.g.
// T34 [i0, i1]
// T185 [i0, b2, i1] = broadcast(T34)
// T192 [i0, b3(ex), i1] = expand(T185)
// It's safe to propagate T34 to T192, since b3(ex) is not involved in the
// propagation. But this isn't generally safe. If the above example is changed
// to, e.g.,
// T34 [i0, i1]
// T185 [i0, b2, i1] = broadcast(T34)
// T186 [i0, i4, i1] = ones({i0, i4, i1})
// T193 [i0, i4, i1] = add(T185, T186)
// it's unsafe to propagate from T34 to T193; see issue
// https://github.com/NVIDIA/Fuser/issues/3542

// Check all source iter domains involved in producing target_tv.
for (IterDomain* source_id_out : get_source_iter_domains(target_tv)) {
// NOTE: we use the concrete ID instead. This allows us to link indirect
// broadcasts. So in the example below:
// T2[i0, i1] = T0[i0, b0] + T1[i0, i1]
// T3[i0, i9] = pad(T0[i0, b0])
// we have i9 in T3
// -> source ID b0
// -> concrete mapped to i1,
// so T3 is contained by T2. See test `PointwiseTest.DomainMapPad1`.
auto concrete_source_id_out =
ca_map_.getConcreteMappedID(source_id_out, IdMappingMode::PERMISSIVE);
// If we find any source_id_out that's not contained, our propagation may
// fail, since transformations involving this iter domain can't be resolved.
if (!getMappedInputConcreteID(covered_source_ids, concrete_source_id_out)) {
return false;
}
}
return true;
}
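
To make the coverage rule concrete, here is a minimal, self-contained sketch of the same idea using plain stand-in types (the Id alias and isCovered helper are hypothetical; the real check walks ComputeAtMap disjoint sets as shown above):

#include <string>
#include <unordered_set>
#include <vector>

// Stand-in for IterDomains: just names. This sketch only models the
// set-coverage logic, not the exact/permissive ID mapping.
using Id = std::string;

// Every source ID feeding the target must be covered by the reference's
// sources; broadcasts surviving to the target's logical domain are also
// allowed to stay unmapped, mirroring the skip above.
bool isCovered(
    const std::vector<Id>& reference_sources,
    const std::vector<Id>& target_logical_broadcasts,
    const std::vector<Id>& target_sources) {
  std::unordered_set<Id> covered(
      reference_sources.begin(), reference_sources.end());
  covered.insert(
      target_logical_broadcasts.begin(), target_logical_broadcasts.end());
  for (const Id& id : target_sources) {
    if (covered.count(id) == 0) {
      return false; // propagation could fail on this unresolved source ID
    }
  }
  return true;
}

int main() {
  // e.g. 0 above: reference T34 has sources {i0, i1}; target T198 has sources
  // {i0, b3, i1}. The reshape merged b3 away, so b3 does not survive to
  // T198's logical domain and cannot be skipped: the reference is rejected.
  bool covered = isCovered(
      {"i0", "i1"}, /*target_logical_broadcasts=*/{}, {"i0", "b3", "i1"});
  return covered ? 1 : 0; // covered == false here, matching e.g. 0 above
}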

// Reference domains must exactly match with the input domains. See
// also PR #661
IterDomain* DomainMap::getMappedInputConcreteID(
@@ -228,7 +359,7 @@ IterDomain* DomainMap::anyMapped(
}

// Determine if output TensorView is a valid reference tensor for this fusion.
-// The reference tensor must map to all the iterDomains in each input.
+// The reference tensor must map to all the iterDomains in each input and
+// output.
bool DomainMap::isValidReference(TensorView* tv) const {
for (auto input_tv : ir_utils::filterByType<TensorView>(fusion_->inputs())) {
if (input_tv->uses().empty()) {
Expand All @@ -240,6 +371,20 @@ bool DomainMap::isValidReference(TensorView* tv) const {
return false;
}
}
// The check on outputs is optional; the transpose scheduler might propose a
// secondary reference that only applies to a subset of IO tensors. Ideally we
// should have a more robust check and consider the IO groups instead of
// blindly skipping outputs.
for (auto output_tv :
ir_utils::filterByType<TensorView>(fusion_->outputs())) {
// No need to check the tensor against itself.
if (output_tv == tv) {
continue;
}
if (!areAllTargetIdsCoveredBy(output_tv, tv)) {
return false;
}
}
return true;
}

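As a usage illustration, "e.g. 1" from the note above can be written as an nvFuser-style fusion definition. This is a sketch: the helper names (makeConcreteTensor, broadcast, reshape, add) are assumptions based on common nvFuser test utilities, and concrete sizes stand in for i0, i1, i2:

// Inside a test body. T4 should qualify as the reference: T2's sources
// (i0, i1) are contained in T4's sources (i2, i0, i1).
Fusion fusion;
FusionGuard fg(&fusion);

auto tv0 = makeConcreteTensor({4, 8});    // T0 [i0, i1]
auto tv1 = makeConcreteTensor({2, 4, 8}); // T1 [i2, i0, i1]
fusion.addInput(tv0);
fusion.addInput(tv1);

auto tv2 = reshape(tv0, {4, 8}, {32});           // T2 [i0*i1]
auto tv3 = broadcast(tv0, {true, false, false}); // T3 [b3, i0, i1]
auto tv4 = add(tv1, tv3);                        // T4 [i2, i0, i1]
fusion.addOutput(tv2);
fusion.addOutput(tv4);
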
9 changes: 8 additions & 1 deletion csrc/scheduler/tools/domain_map.h
@@ -32,14 +32,21 @@ class DomainMap {
}

// Determine if a TensorView is a valid reference tensor for this fusion.
-// The reference tensor must map to all the iterDomains in each input.
+// The reference tensor must map to all the iterDomains in each input and
+// output.
bool isValidReference(TensorView* tv) const;

protected:
// Determine if all IterDomains are mapped between input and the given tvs
bool areAllInputIdsMappedTo(TensorView* input_tv, TensorView* output_tv)
const;

// Determine if all source IterDomains in target_tv are contained by
// reference_tv; this ensures transformations from reference_tv can be
// propagated to target_tv.
bool areAllTargetIdsCoveredBy(TensorView* target_tv, TensorView* reference_tv)
const;

virtual IterDomain* getMappedInputConcreteID(
const std::unordered_set<IterDomain*>& in_concrete_ids,
IterDomain* out_id) const;
18 changes: 13 additions & 5 deletions csrc/scheduler/transpose.cpp
@@ -170,8 +170,16 @@ class TransposeDomainMap : public scheduler_tools::DomainMap {
TensorView* result = nullptr;
int64_t max_dims = -1;
for (auto tv : group) {
// Since the transpose scheduler has a different set of references, we skip
// the ID coverage check of the reference on outputs of the fusion. Note that
// this is not ideal; we would want to instead have the reference tensor
// checked against all of its target IO tensors.
// TODO: open an issue for this one. The transpose scheduler is not supposed
// to reuse pointwise_utils::DomainMap::isValidReference; this function is
// too restrictive and doesn't align well with the scheme of the transpose
// scheduler.
if (isValidReference(tv)) {
- int64_t dims = (int64_t)pointwise_utils::nRootDims(tv);
+ int64_t dims = (int64_t)pointwise_utils::nLogicalDims(tv);
if (dims > max_dims) {
result = tv;
max_dims = dims;
@@ -992,12 +1000,12 @@ std::unique_ptr<TransposeParams> getTransposeHeuristics(
<< "max_io_dtype_size: " << max_io_dtype_size << "\n"
<< "group 1: " << ir_utils::toString(grouped_inputs_outputs[0])
<< "\n"
<< "reference1: " << reference1 << "\n"
<< "reference1: " << reference1->toString() << "\n"
<< "inner_most_id1 position: " << inner_most_pos1_in_ref1
<< " (in reference 1)\n"
<< "group 2: " << ir_utils::toString(grouped_inputs_outputs[1])
<< "\n"
<< "reference2: " << reference2 << "\n"
<< "reference2: " << reference2->toString() << "\n"
<< "inner_most_id2 position: " << inner_most_pos2_in_ref1
<< " (in reference 1)" << std::endl;
if (hasSmallTransposeDimensions(tparams)) {
@@ -1047,11 +1055,11 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams* tparams) {

int64_t max_dims = 0;
for (auto inp : input_tvs) {
- max_dims = std::max(pointwise_utils::nRootDims(inp), max_dims);
+ max_dims = std::max(pointwise_utils::nLogicalDims(inp), max_dims);
}

for (auto out : output_tvs) {
- max_dims = std::max(pointwise_utils::nRootDims(out), max_dims);
+ max_dims = std::max(pointwise_utils::nLogicalDims(out), max_dims);
}

// If everything is zero dim tensors, just return.
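
The hunks above show only the comment; the mechanism the transpose scheduler uses to skip the output coverage check is not visible in this diff. One plausible shape for such a skip, purely as a sketch with hypothetical names (not the PR's actual wiring):

struct TensorView; // opaque here; only pointers are passed around

// Hypothetical sketch: a virtual predicate would let a derived DomainMap opt
// out of the output coverage loop added to isValidReference above.
class DomainMapSketch {
 public:
  virtual ~DomainMapSketch() = default;

  bool isValidReference(TensorView* /*tv*/) const {
    // Input checks would run here (areAllInputIdsMappedTo above).
    if (!checkOutputCoverage()) {
      return true; // skip the output loop entirely
    }
    // Output loop would run here (areAllTargetIdsCoveredBy per output).
    return true;
  }

 protected:
  virtual bool checkOutputCoverage() const {
    return true; // pointwise: enforce the new output coverage check
  }
};

class TransposeDomainMapSketch : public DomainMapSketch {
 protected:
  bool checkOutputCoverage() const override {
    return false; // transpose: skip, avoiding the regressions noted above
  }
};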