@@ -12,24 +12,45 @@ import { DataViewTableHead } from '../DataViewTableHead';
1212import { DataViewTh , DataViewTrTree , isDataViewTdObject } from '../DataViewTable' ;
1313import { DataViewState } from '../DataView/DataView' ;
1414
15- const getDescendants = ( node : DataViewTrTree ) : DataViewTrTree [ ] => ( ! node . children || ! node . children . length ) ? [ node ] : node . children . flatMap ( getDescendants ) ;
16-
17- const isNodeChecked = ( node : DataViewTrTree , isSelected : ( node : DataViewTrTree ) => boolean ) => {
18- let allSelected = true ;
19- let someSelected = false ;
20-
21- for ( const descendant of getDescendants ( node ) ) {
22- const selected = ! ! isSelected ?.( descendant ) ;
23-
24- someSelected ||= selected ;
25- allSelected &&= selected ;
26-
27- if ( ! allSelected && someSelected ) { return null }
28- }
29-
30- return allSelected ;
15+ const getNodesAffectedBySelection = (
16+ allRows : DataViewTrTree [ ] ,
17+ node : DataViewTrTree ,
18+ isChecking : boolean ,
19+ isSelected ?: ( item : DataViewTrTree ) => boolean
20+ ) : DataViewTrTree [ ] => {
21+
22+ const getDescendants = ( node : DataViewTrTree ) : DataViewTrTree [ ] =>
23+ node . children ? node . children . flatMap ( getDescendants ) . concat ( node ) : [ node ] ;
24+
25+ const findParent = ( child : DataViewTrTree , rows : DataViewTrTree [ ] ) : DataViewTrTree | undefined =>
26+ rows . find ( row => row . children ?. some ( c => c === child ) ) ??
27+ rows . flatMap ( row => row . children ?? [ ] ) . map ( c => findParent ( child , [ c ] ) ) . find ( p => p ) ;
28+
29+ const getAncestors = ( node : DataViewTrTree ) : DataViewTrTree [ ] => {
30+ const ancestors : DataViewTrTree [ ] = [ ] ;
31+ let parent = findParent ( node , allRows ) ;
32+ while ( parent ) {
33+ ancestors . push ( parent ) ;
34+ parent = findParent ( parent , allRows ) ;
35+ }
36+ return ancestors ;
37+ } ;
38+
39+ const affectedNodes = new Set ( [ node , ...getDescendants ( node ) ] ) ;
40+
41+ getAncestors ( node ) . forEach ( ancestor => {
42+ const allChildrenSelected = ancestor . children ?. every ( child => isSelected ?.( child ) || affectedNodes . has ( child ) ) ;
43+ const anyChildAffected = ancestor . children ?. some ( child => affectedNodes . has ( child ) || child . id === node . id ) ;
44+
45+ if ( isChecking ? ! isSelected ?.( ancestor ) && allChildrenSelected : isSelected ?.( ancestor ) && anyChildAffected ) {
46+ affectedNodes . add ( ancestor ) ;
47+ }
48+ } ) ;
49+
50+ return Array . from ( affectedNodes ) ;
3151} ;
3252
53+
3354/** extends TableProps */
3455export interface DataViewTableTreeProps extends Omit < TableProps , 'onSelect' | 'rows' > {
3556 /** Columns definition */
@@ -83,7 +104,7 @@ export const DataViewTableTree: React.FC<DataViewTableTreeProps> = ({
83104 }
84105 const isExpanded = expandedNodeIds . includes ( node . id ) ;
85106 const isDetailsExpanded = expandedDetailsNodeNames . includes ( node . id ) ;
86- const isChecked = isSelected && isNodeChecked ( node , isSelected ) ;
107+ const isChecked = isSelected ?. ( node ) ;
87108 let icon = leafIcon ;
88109 if ( node . children ) {
89110 icon = isExpanded ? expandedIcon : collapsedIcon ;
@@ -100,7 +121,7 @@ export const DataViewTableTree: React.FC<DataViewTableTreeProps> = ({
100121 const otherDetailsExpandedNodeIds = prevDetailsExpanded . filter ( id => id !== node . id ) ;
101122 return isDetailsExpanded ? otherDetailsExpandedNodeIds : [ ...otherDetailsExpandedNodeIds , node . id ] ;
102123 } ) ,
103- onCheckChange : ( isSelectDisabled ?.( node ) || ! onSelect ) ? undefined : ( _event , isChecking ) => onSelect ?.( isChecking , getDescendants ( node ) ) ,
124+ onCheckChange : ( isSelectDisabled ?.( node ) || ! onSelect ) ? undefined : ( _event , isChecking ) => onSelect ?.( isChecking , getNodesAffectedBySelection ( rows , node , isChecking , isSelected ) ) ,
104125 rowIndex,
105126 props : {
106127 isExpanded,
0 commit comments