@@ -1214,3 +1214,127 @@ SELECT id, details FROM jobs AS j INNER JOIN cte1 ON id = job_id WHERE id = 1;
12141214	require .Equal (t , 1 , id )
12151215	require .Equal (t , 1 , details )
12161216}
1217+ 
1218+ // TestTopLevelQueryStats verifies that top-level query stats are collected 
1219+ // correctly, including when the query executes "plans inside plans". 
1220+ func  TestTopLevelQueryStats (t  * testing.T ) {
1221+ 	defer  leaktest .AfterTest (t )()
1222+ 	defer  log .Scope (t ).Close (t )
1223+ 
1224+ 	// testQuery will be updated throughout the test to the current target. 
1225+ 	var  testQuery  atomic.Value 
1226+ 	// The callback will send number of rows read and rows written (for each 
1227+ 	// ProducerMetadata.Metrics object) on these channels, respectively. 
1228+ 	rowsReadCh , rowsWrittenCh  :=  make (chan  int64 ), make (chan  int64 )
1229+ 	s , sqlDB , _  :=  serverutils .StartServer (t , base.TestServerArgs {
1230+ 		Knobs : base.TestingKnobs {
1231+ 			SQLExecutor : & ExecutorTestingKnobs {
1232+ 				DistSQLReceiverPushCallbackFactory : func (_  context.Context , query  string ) func (rowenc.EncDatumRow , coldata.Batch , * execinfrapb.ProducerMetadata ) (rowenc.EncDatumRow , coldata.Batch , * execinfrapb.ProducerMetadata ) {
1233+ 					if  target  :=  testQuery .Load (); target  ==  nil  ||  target .(string ) !=  query  {
1234+ 						return  nil 
1235+ 					}
1236+ 					return  func (row  rowenc.EncDatumRow , batch  coldata.Batch , meta  * execinfrapb.ProducerMetadata ) (rowenc.EncDatumRow , coldata.Batch , * execinfrapb.ProducerMetadata ) {
1237+ 						if  meta  !=  nil  &&  meta .Metrics  !=  nil  {
1238+ 							rowsReadCh  <-  meta .Metrics .RowsRead 
1239+ 							rowsWrittenCh  <-  meta .Metrics .RowsWritten 
1240+ 						}
1241+ 						return  row , batch , meta 
1242+ 					}
1243+ 				},
1244+ 			},
1245+ 		},
1246+ 	})
1247+ 	defer  s .Stopper ().Stop (context .Background ())
1248+ 
1249+ 	if  _ , err  :=  sqlDB .Exec (` 
1250+ CREATE TABLE t (k INT PRIMARY KEY); 
1251+ INSERT INTO t SELECT generate_series(1, 10); 
1252+ CREATE FUNCTION no_reads() RETURNS INT AS 'SELECT 1' LANGUAGE SQL; 
1253+ CREATE FUNCTION reads() RETURNS INT AS 'SELECT count(*) FROM t' LANGUAGE SQL; 
1254+ CREATE FUNCTION write(x INT) RETURNS INT AS 'INSERT INTO t VALUES (x); SELECT x' LANGUAGE SQL; 
1255+ ` ); err  !=  nil  {
1256+ 		t .Fatal (err )
1257+ 	}
1258+ 
1259+ 	for  _ , tc  :=  range  []struct  {
1260+ 		name            string 
1261+ 		query           string 
1262+ 		expRowsRead     int64 
1263+ 		expRowsWritten  int64 
1264+ 	}{
1265+ 		{
1266+ 			name :           "simple read" ,
1267+ 			query :          "SELECT k FROM t" ,
1268+ 			expRowsRead :    10 ,
1269+ 			expRowsWritten : 0 ,
1270+ 		},
1271+ 		{
1272+ 			name :           "simple write" ,
1273+ 			query :          "INSERT INTO t SELECT generate_series(11, 42)" ,
1274+ 			expRowsRead :    0 ,
1275+ 			expRowsWritten : 32 ,
1276+ 		},
1277+ 		{
1278+ 			name : "read with apply join" ,
1279+ 			query : `SELECT ( 
1280+     WITH foo AS MATERIALIZED (SELECT k FROM t AS x WHERE x.k = y.k) 
1281+     SELECT * FROM foo 
1282+   ) FROM t AS y` ,
1283+ 			expRowsRead :    84 , // scanning the table twice 
1284+ 			expRowsWritten : 0 ,
1285+ 		},
1286+ 		{
1287+ 			name :           "routine, no reads" ,
1288+ 			query :          "SELECT no_reads()" ,
1289+ 			expRowsRead :    0 ,
1290+ 			expRowsWritten : 0 ,
1291+ 		},
1292+ 		{
1293+ 			name :           "routine, reads" ,
1294+ 			query :          "SELECT reads()" ,
1295+ 			expRowsRead :    42 ,
1296+ 			expRowsWritten : 0 ,
1297+ 		},
1298+ 		{
1299+ 			name :           "routine, write" ,
1300+ 			query :          "SELECT write(43)" ,
1301+ 			expRowsRead :    0 ,
1302+ 			expRowsWritten : 1 ,
1303+ 		},
1304+ 		{
1305+ 			name :           "routine, multiple reads and writes" ,
1306+ 			query :          "SELECT reads(), write(44), reads(), write(45), write(46), reads()" ,
1307+ 			expRowsRead :    133 , // first read is 43 rows, second is 44, third is 46 
1308+ 			expRowsWritten : 3 ,
1309+ 		},
1310+ 	} {
1311+ 		t .Run (tc .name , func (t  * testing.T ) {
1312+ 			testQuery .Store (tc .query )
1313+ 			errCh  :=  make (chan  error )
1314+ 			// Spin up the worker goroutine which will actually execute the 
1315+ 			// query. 
1316+ 			go  func () {
1317+ 				defer  close (errCh )
1318+ 				_ , err  :=  sqlDB .Exec (tc .query )
1319+ 				errCh  <-  err 
1320+ 			}()
1321+ 			// In the main goroutine, loop until the query is completed while 
1322+ 			// accumulating the top-level query stats. 
1323+ 			var  rowsRead , rowsWritten  int64 
1324+ 		LOOP:
1325+ 			for  {
1326+ 				select  {
1327+ 				case  read  :=  <- rowsReadCh :
1328+ 					rowsRead  +=  read 
1329+ 				case  written  :=  <- rowsWrittenCh :
1330+ 					rowsWritten  +=  written 
1331+ 				case  err  :=  <- errCh :
1332+ 					require .NoError (t , err )
1333+ 					break  LOOP
1334+ 				}
1335+ 			}
1336+ 			require .Equal (t , tc .expRowsRead , rowsRead )
1337+ 			require .Equal (t , tc .expRowsWritten , rowsWritten )
1338+ 		})
1339+ 	}
1340+ }
0 commit comments