@@ -1714,6 +1714,14 @@ def _report_duration(
                 tbe_id=self.uuid,
             )
 
+    def _get_tensor_memory(self, tensor_name: str) -> int:
+        """Get memory usage of a tensor in bytes."""
+        if not hasattr(self, tensor_name):
+            self.log(f"Tensor '{tensor_name}' not found, using 0 bytes")
+            return 0
+        tensor = getattr(self, tensor_name)
+        return tensor.numel() * tensor.element_size()
+
     @torch.jit.ignore
     def _report_tbe_mem_usage(self) -> None:
         if self.stats_reporter is None:
@@ -1724,18 +1732,17 @@ def _report_tbe_mem_usage(self) -> None:
             return
 
         total_mem_usage = sum(
-            param.numel() * param.element_size() for param in self.parameters()
-        ) + sum(buffer.numel() * buffer.element_size() for buffer in self.buffers())
+            p.numel() * p.element_size() for p in self.parameters()
+        ) + sum(b.numel() * b.element_size() for b in self.buffers())
+
         if self.use_cpu:
             total_hbm_usage = 0
             total_uvm_usage = total_mem_usage
         else:
-            # hbm usage is total usage minus uvm usage
             total_uvm_usage = sum(
-                getattr(self, tensor_name).numel()
-                * getattr(self, tensor_name).element_size()
-                for tensor_name in self._uvm_tensors_log
-                if hasattr(self, tensor_name)
+                self._get_tensor_memory(name)
+                for name in self._uvm_tensors_log
+                if hasattr(self, name)
             )
             total_hbm_usage = total_mem_usage - total_uvm_usage
 
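For context, the new helper's per-tensor accounting is simply numel() * element_size(). Below is a minimal, self-contained sketch of the same idea outside the TBE class; the _Demo class, its tensor names, and the tensor_bytes method are illustrative assumptions for this sketch, not part of the change.

import torch


class _Demo:
    # Toy stand-in (not from this PR) holding a couple of named tensors.

    def __init__(self) -> None:
        self.weights = torch.zeros(1024, 128, dtype=torch.float32)  # 1024 * 128 * 4 bytes
        self.offsets = torch.zeros(1024, dtype=torch.int64)  # 1024 * 8 bytes

    def tensor_bytes(self, tensor_name: str) -> int:
        # Mirrors the shape of _get_tensor_memory: a missing attribute counts as 0 bytes.
        if not hasattr(self, tensor_name):
            return 0
        t = getattr(self, tensor_name)
        return t.numel() * t.element_size()


demo = _Demo()
print(demo.tensor_bytes("weights"))  # 524288
print(demo.tensor_bytes("offsets"))  # 8192
print(demo.tensor_bytes("lxu_cache_weights"))  # 0 (attribute not present on _Demo)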