Skip to content

Conversation

@jithunnair-amd
Copy link
Collaborator

@jithunnair-amd jithunnair-amd commented Oct 27, 2025

Isalia20 and others added 30 commits October 20, 2025 18:56
Slightly faster cholesky, removed one redundant simdgroup_multiply
<img width="721" height="593" alt="Screenshot 2025-10-19 at 22 00 19" src="https://github.com/user-attachments/assets/e3a9005b-9347-4e62-a24d-16ba5e28849a" />

Generate benchmarks with(measured on M1 Pro):
```
import torch
import numpy as np
import time
import csv

matrix_sizes = [512, 1024, 2048, 4096]
batch_sizes = [1, 2, 4, 8, 16]
num_runs = 10
warmup_runs = 3

def create_spd_matrix(n, batch_size):
    torch.manual_seed(42)
    A = torch.randn(batch_size, n, n, dtype=torch.float32)
    return A @ A.transpose(-2, -1) + n * torch.eye(n).expand(batch_size, -1, -1)

def run_cholesky_mps(A):
    torch.mps.synchronize()
    start = time.perf_counter()
    b = torch.linalg.cholesky(A, upper=False)
    torch.mps.synchronize()
    end = time.perf_counter()
    return b, end - start

results = {
    'N': [],
    'batch_size': [],
    'mean_time': [],
    'std_time': []
}

for n in matrix_sizes:
    for batch_size in batch_sizes:
        print(f"\nBenchmarking N={n}, batch_size={batch_size}")

        try:
            A_cpu = create_spd_matrix(n, batch_size)
            A_mps = A_cpu.to("mps")

            for _ in range(warmup_runs):
                _, _ = run_cholesky_mps(A_mps)

            times = []
            for _ in range(num_runs):
                _, t = run_cholesky_mps(A_mps)
                times.append(t)

            mean_time = np.mean(times)
            std_time = np.std(times)

            results['N'].append(n)
            results['batch_size'].append(batch_size)
            results['mean_time'].append(mean_time)
            results['std_time'].append(std_time)

            print(f"Mean time: {mean_time:.4f}s ± {std_time:.4f}s")

        except RuntimeError as e:
            print(f"Error for N={n}, batch_size={batch_size}: {e}")
            continue

with open('cholesky_benchmark_times.csv', 'w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(['N', 'batch_size', 'mean_time', 'std_time'])
    for i in range(len(results['N'])):
        writer.writerow([
            results['N'][i],
            results['batch_size'][i],
            results['mean_time'][i],
            results['std_time'][i]
        ])
```
Pull Request resolved: pytorch#165867
Approved by: https://github.com/malfet
Initial autotuning support for foreach kernels, 4x improvement for some kernels in internal workload. More improvements can surely be made here in the future. Removing num_warps for definition to enable autotune support in generated wrapper code.

Before:
triton_for_fused_18.kd 🔍 | 4.986 ms | 4.986 ms | 2.493 ms | 2 |
triton_for_fused_6.kd 🔍 | 0.098 ms | 0.098 ms | 0.049 ms | 2 |
triton_for_fused_7.kd 🔍 | 0.036 ms | 0.036 ms | 0.018 ms | 2 |

After:
triton_for_fused_18.kd 🔍 | 1.273 ms | 1.273 ms | 0.636 ms | 2 |
triton_for_fused_6.kd 🔍 | 0.044 ms | 0.044 ms | 0.022 ms | 2 |
triton_for_fused_7.kd 🔍 | 0.024 ms | 0.024 ms | 0.012 ms | 2 |

Pull Request resolved: pytorch#162053
Approved by: https://github.com/mlazos, https://github.com/naromero77amd
Add logging for debugging annotation bugs. Log will show with `TORCH_LOGS="+annotation" `

Pull Request resolved: pytorch#165797
Approved by: https://github.com/ezyang, https://github.com/Skylion007, https://github.com/SherlockNoMad
This reverts commit 779296a.

Reverted pytorch#162053 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](pytorch#162053 (comment)))
…rnels (pytorch#165815)

Scheduler relies on node.last_usage to free buffers. `last_usage` may contain a buffer that is allocated in previous graph partition AND not directly accessed in the current graph partition.

## Example
```python
def f(x):
    y = x + 1
    z = torch.ops.aten.view.dtype(y, torch.float8_e4m3fn)
    z_cpu = z.cpu()
    u_cuda = z_cpu.cuda()
    return u_cuda
```

In the generated code, we have
```
def partition_0(args):
    ...
    # Topologically Sorted Source Nodes: [y, z], Original ATen: [aten.add, aten.view]
    buf1 = torch.ops.aten.view.dtype(buf0, torch.float8_e4m3fn) # < ------ buf1 is a view of buf0
    buf2 = buf1 # <------- buf2 is buf1
    assert_size_stride(buf2, (8, ), (1, ), 'torch.ops.aten.view.dtype')
    assert_alignment(buf2, 16, 'torch.ops.aten.view.dtype')
    return (buf2, )

def call(self, args):
    ...
    (buf2,) = self.partitions[0](partition0_args)
    ...
    buf3.copy_(buf2, False)
    del buf0
    del buf1
    del buf2  # <---- `del buf2` leads to `del buf0`. BUT `buf0` is not returned from partition_0.
    ...
```

Note: view is treated as a fallback kernel due to its special dtype.
https://github.com/pytorch/pytorch/blob/de09bab4b66002a8a9a2195f50f96a78868a3d39/torch/_inductor/lowering.py#L841-L843

## Fix

This PR fixes the issue by also returning these buffers to be freed later.

Pull Request resolved: pytorch#165815
Approved by: https://github.com/eellison
… backend_config) (pytorch#165433)

Replace assert statements with explicit if/raise patterns in:

- torch/ao/quantization/~
- torch/ao/quantization/quantizer/
- torch/ao/quantization/backend_config/

fix partialy pytorch#164878

Pull Request resolved: pytorch#165433
Approved by: https://github.com/albanD
fix for pytorch#165624 - we were applying pre pass multiple times.

Pull Request resolved: pytorch#165917
Approved by: https://github.com/bdhirsh
Replaces 78 assert statements across 10 files in torch.autograd with explicit if-checks raising AssertionError to prevent assertions from being disabled with Python -O flag. This ensures error checking remains active in optimized builds.

fix partially pytorch#164878

Pull Request resolved: pytorch#165627
Approved by: https://github.com/albanD
…d torch/ao/quantization/pt2e/* (pytorch#165317)

Replace assert statements with explicit if/raise patterns in:
- torch/ao/quantization/experimental/* (11 errors)
- torch/ao/quantization/pt2e/* (68 errors)

fix partialy pytorch#164878
Pull Request resolved: pytorch#165317
Approved by: https://github.com/albanD
…h#165750)

This series of changes try to cover C style casts into C++ alternatives.

Pull Request resolved: pytorch#165750
Approved by: https://github.com/Skylion007
…h#165410)

Including:
- `torch/utils/*.py`

Fixes part of pytorch#164878

Pull Request resolved: pytorch#165410
Approved by: https://github.com/albanD
Add final stage of aot_stage2_compile for autograd and inference.

Differential Revision: D84844699

Pull Request resolved: pytorch#165668
Approved by: https://github.com/zhxchen17, https://github.com/tugsbayasgalan
This should allow us to move gfx1100 workflow to a lower frequency and also allow it to be triggered on PRs via a dedicated label, for any PRs that target Navi fixes such as [this](pytorch#165630) or [this](pytorch#165625).

Pull Request resolved: pytorch#165699
Approved by: https://github.com/jeffdaily
Typical warp shuffle reduction has the following pattern:
<img width="1138" height="501" alt="image" src="https://github.com/user-attachments/assets/3bd176dc-0ad2-4df6-90c7-06e467337166" />

which is exhibited in Triton generated by torch.compile:
<img width="663" height="403" alt="image" src="https://github.com/user-attachments/assets/7f9f36cd-b9eb-44c1-879e-b469668a2ea8" />

Switch the warp shuffle order to make bitwise equivalence between the 2 easier.
PTX difference between old and new, we see a few extra instructions: https://www.diffchecker.com/h6ly3INC/

Comparing the performance on different reduction operations, we see minimal differences. New represents the changes in this PR, old represents the past warp shuffle order:
```
Tensor Shape              Operation            New all dims (ms)       New dim=0 (ms)      New dim=1 (ms)     Old all dims (ms)    Old dim=0 (ms)      Old dim=1 (ms)
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1024, 1024)              mean                 0.015817             0.016259             0.013642             0.015990             0.016258             0.013631
(1024, 1024)              sum                  0.015917             0.015906             0.013359             0.015707             0.016266             0.013226
(1024, 1024)              min                  0.016021             0.024625             0.015631             0.015761             0.024485             0.015317
(1024, 1024)              max                  0.016349             0.024971             0.015972             0.015771             0.025001             0.015314
(1024, 1024)              argmin               0.018070             0.024448             0.015578             0.018135             0.025370             0.015322
(1024, 1024)              argmax               0.018427             0.024859             0.015932             0.018164             0.024452             0.015639
(1024, 1024)              var                  0.020078             0.026413             0.020295             0.020199             0.026381             0.020214
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(2048, 2048)              mean                 0.023826             0.023726             0.022273             0.023236             0.023776             0.022248
(2048, 2048)              sum                  0.023840             0.023355             0.021974             0.023294             0.023354             0.021884
(2048, 2048)              min                  0.024519             0.041263             0.024620             0.023292             0.041491             0.024358
(2048, 2048)              max                  0.024509             0.041670             0.024277             0.023334             0.041231             0.024395
(2048, 2048)              argmin               0.026125             0.041282             0.024567             0.026772             0.041773             0.024296
(2048, 2048)              argmax               0.026117             0.041487             0.024572             0.026412             0.041477             0.024273
(2048, 2048)              var                  0.026603             0.048581             0.031308             0.027587             0.048603             0.030860
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(4096, 4096)              mean                 0.053927             0.057070             0.054073             0.053028             0.057544             0.053935
(4096, 4096)              sum                  0.053604             0.057410             0.054451             0.053076             0.057033             0.054266
(4096, 4096)              min                  0.054293             0.109122             0.058363             0.053821             0.108689             0.058382
(4096, 4096)              max                  0.054258             0.108035             0.058703             0.053492             0.110552             0.058376
(4096, 4096)              argmin               0.056805             0.111167             0.058301             0.056836             0.112325             0.058292
(4096, 4096)              argmax               0.056488             0.110958             0.058636             0.056844             0.111000             0.057928
(4096, 4096)              var                  0.058936             0.141755             0.068693             0.059735             0.141284             0.068500
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(8192, 8192)              mean                 0.145552             0.148082             0.138647             0.145364             0.147818             0.138207
(8192, 8192)              sum                  0.145985             0.147900             0.138714             0.145755             0.148031             0.138616
(8192, 8192)              min                  0.146566             0.205359             0.192739             0.145611             0.205237             0.182335
(8192, 8192)              max                  0.146526             0.204844             0.193050             0.146073             0.205457             0.182697
(8192, 8192)              argmin               0.150190             0.206605             0.192543             0.150654             0.206847             0.182007
(8192, 8192)              argmax               0.150481             0.206368             0.192535             0.150845             0.206430             0.182022
(8192, 8192)              var                  0.150884             0.184546             0.203900             0.151594             0.184172             0.197983
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1, 1024, 128)            mean                 0.014293             0.008119             0.014533             0.013861             0.008022             0.014449
(1, 1024, 128)            sum                  0.014039             0.007877             0.014111             0.014219             0.008227             0.014045
(1, 1024, 128)            min                  0.014159             0.011354             0.023493             0.014271             0.010862             0.023644
(1, 1024, 128)            max                  0.014154             0.011027             0.023368             0.014259             0.011234             0.023692
(1, 1024, 128)            argmin               0.016403             0.005677             0.023328             0.016273             0.005683             0.024073
(1, 1024, 128)            argmax               0.016734             0.005675             0.023437             0.016580             0.005318             0.023331
(1, 1024, 128)            var                  0.018338             0.009549             0.025538             0.018528             0.009391             0.024777
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(5, 1024, 128)            mean                 0.014873             0.010131             0.015546             0.015123             0.010131             0.015481
(5, 1024, 128)            sum                  0.015334             0.009673             0.015824             0.014736             0.009671             0.015438
(5, 1024, 128)            min                  0.015047             0.013252             0.024573             0.014803             0.013163             0.024551
(5, 1024, 128)            max                  0.015050             0.013339             0.024197             0.014810             0.013525             0.024230
(5, 1024, 128)            argmin               0.017341             0.012737             0.024306             0.017471             0.012379             0.024991
(5, 1024, 128)            argmax               0.017345             0.012411             0.024421             0.017422             0.012471             0.024237
(5, 1024, 128)            var                  0.019973             0.011453             0.026188             0.020050             0.011438             0.026282
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(10, 1024, 128)           mean                 0.016976             0.011575             0.016831             0.016722             0.011927             0.017173
(10, 1024, 128)           sum                  0.017039             0.011841             0.017159             0.016385             0.011860             0.016753
(10, 1024, 128)           min                  0.017036             0.015331             0.026770             0.016944             0.015205             0.027166
(10, 1024, 128)           max                  0.017369             0.015348             0.027077             0.016531             0.015716             0.026819
(10, 1024, 128)           argmin               0.019203             0.014447             0.026813             0.018994             0.014497             0.027313
(10, 1024, 128)           argmax               0.019563             0.014795             0.027140             0.019460             0.014912             0.026733
(10, 1024, 128)           var                  0.020529             0.014316             0.030405             0.020719             0.013960             0.029964
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(100, 1024, 128)          mean                 0.045046             0.039168             0.046082             0.044839             0.039217             0.045782
(100, 1024, 128)          sum                  0.045094             0.039150             0.045777             0.044496             0.039542             0.046083
(100, 1024, 128)          min                  0.045768             0.054466             0.076244             0.044915             0.053943             0.076599
(100, 1024, 128)          max                  0.045748             0.054459             0.076188             0.044931             0.053949             0.076856
(100, 1024, 128)          argmin               0.048275             0.054046             0.076647             0.048694             0.054105             0.077004
(100, 1024, 128)          argmax               0.048267             0.054395             0.077401             0.048691             0.054131             0.076751
(100, 1024, 128)          var                  0.049710             0.043254             0.083077             0.050971             0.043251             0.082378
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1000, 1000, 100)         mean                 0.202312             0.196723             0.197765             0.201774             0.196641             0.197459
(1000, 1000, 100)         sum                  0.202651             0.196682             0.197736             0.202175             0.196313             0.197523
(1000, 1000, 100)         min                  0.203022             0.264762             0.269200             0.202729             0.264129             0.268694
(1000, 1000, 100)         max                  0.202864             0.264396             0.269388             0.202486             0.263896             0.268720
(1000, 1000, 100)         argmin               0.226727             0.263781             0.268651             0.226597             0.264676             0.268983
(1000, 1000, 100)         argmax               0.226412             0.264469             0.269090             0.226570             0.264595             0.269178
(1000, 1000, 100)         var                  0.243223             0.204079             0.216096             0.241942             0.204079             0.215925
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(10000, 100)              mean                 0.016193             0.020277             0.014316             0.016152             0.020324             0.013712
(10000, 100)              sum                  0.016289             0.020237             0.014034             0.016168             0.020265             0.013708
(10000, 100)              min                  0.016046             0.030872             0.019609             0.016208             0.030867             0.018627
(10000, 100)              max                  0.016369             0.030835             0.019257             0.016218             0.030861             0.018209
(10000, 100)              argmin               0.017957             0.031171             0.019517             0.018050             0.031556             0.018077
(10000, 100)              argmax               0.017961             0.031658             0.019521             0.018060             0.031564             0.018087
(10000, 100)              var                  0.020393             0.035652             0.019339             0.020144             0.035987             0.019171
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(100000, 10)              mean                 0.015718             0.016576             0.016555             0.015999             0.016246             0.014869
(100000, 10)              sum                  0.015833             0.016247             0.016572             0.016007             0.016627             0.014872
(100000, 10)              min                  0.015888             0.020510             0.023920             0.015671             0.020821             0.021417
(100000, 10)              max                  0.015889             0.020479             0.023918             0.016077             0.020386             0.021421
(100000, 10)              argmin               0.018233             0.020863             0.023647             0.017574             0.020864             0.021103
(100000, 10)              argmax               0.017896             0.020527             0.023296             0.017569             0.020447             0.021098
(100000, 10)              var                  0.020005             0.024198             0.024372             0.020075             0.024167             0.022415
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1023, 1023, 1023)        mean                 1.874816             1.963506             1.903909             1.873279             1.963859             1.903230
(1023, 1023, 1023)        sum                  1.875030             1.965716             1.902458             1.873566             1.960730             1.901642
(1023, 1023, 1023)        min                  1.878563             2.473455             2.179092             1.875174             2.482086             2.183027
(1023, 1023, 1023)        max                  1.879128             2.474803             2.178895             1.874831             2.482253             2.183884
(1023, 1023, 1023)        argmin               1.921800             2.476629             2.174831             1.923987             2.472641             2.170453
(1023, 1023, 1023)        argmax               1.922605             2.476688             2.177927             1.923366             2.472808             2.172979
(1023, 1023, 1023)        var                  1.972606             3.088695             2.758797             1.978679             3.095658             2.762243
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1023, 1023, 255)         mean                 0.489984             0.500954             0.492957             0.489891             0.500654             0.491971
(1023, 1023, 255)         sum                  0.490228             0.500764             0.492289             0.489624             0.501089             0.492824
(1023, 1023, 255)         min                  0.491457             0.563560             0.553334             0.490355             0.564709             0.554754
(1023, 1023, 255)         max                  0.491396             0.563628             0.553345             0.490017             0.565004             0.554947
(1023, 1023, 255)         argmin               0.503666             0.561512             0.551831             0.503845             0.560972             0.551017
(1023, 1023, 255)         argmax               0.503602             0.561185             0.551407             0.504328             0.561267             0.551448
(1023, 1023, 255)         var                  0.510844             0.709452             0.701630             0.512693             0.710365             0.701965
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
(1023, 1023, 377)         mean                 0.707439             0.727646             0.712019             0.706769             0.727101             0.711632
(1023, 1023, 377)         sum                  0.707780             0.727453             0.711554             0.706807             0.726656             0.711729
(1023, 1023, 377)         min                  0.709423             0.819809             0.794379             0.707847             0.822086             0.796664
(1023, 1023, 377)         max                  0.709297             0.819780             0.794308             0.707566             0.821913             0.796690
(1023, 1023, 377)         argmin               0.725028             0.817088             0.791695             0.726039             0.816445             0.790828
(1023, 1023, 377)         argmax               0.725301             0.817011             0.791420             0.726040             0.816917             0.791143
(1023, 1023, 377)         var                  0.740859             1.034165             1.006712             0.743413             1.035506             1.007638
```

Differential Revision: [D85022826](https://our.internmc.facebook.com/intern/diff/D85022826)
Pull Request resolved: pytorch#164790
Approved by: https://github.com/ngimel, https://github.com/eqy
Fix pytorch#159251

Add an optional argument `return_outputs` to the schedule `step`

Pull Request resolved: pytorch#165822
Approved by: https://github.com/wconstab
…h#164882)

Fixes pytorch#163343.

After some consideration, I propose we remove the anonymous namespace around from/to in favor of:
1. Adding inline to the function implementations, assuming that they will not change in the near future
2. If we decide to change them, we will wrap the code in inline versioned namespaces such that the implementations within any versioned namespace will be guaranteed identical.

Note that:
- We eventually intend to abstract away usage of `from`/`to` (related: @lw's TORCH_BOX work)
- The from/to implementations are now powered through class template specializations, where adding a specialization does not change the from/to signatures.

I do plan to deprecate top-level from/to in favor of torch::stable::details::from/to consequently. This way we can stop polluting the global namespace.

Pull Request resolved: pytorch#164882
Approved by: https://github.com/lw, https://github.com/albanD
Fixes pytorch#165911

- Add message to Attribute error so we see `  Developer debug context: raised exception AttributeError(["'Linear' object has no attribute 'w'"])` instead of just `Developer debug context: raised exception AttributeError([])`
- Add stack trace in `ObservedException` so we display the inner most error stack trace back to user code

Output:

```
/data/users/shangdiy/pytorch/torch/__init__.py:2641: UserWarning: You are calling torch.compile inside torch.export region. To capture an useful graph, we will implicitly switch to torch.compile(backend=eager)
  warnings.warn(
Traceback (most recent call last):
  File "/data/users/shangdiy/pytorch/torch/_dynamo/variables/user_defined.py", line 1385, in var_getattr
    subobj = self._getattr_static(name)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/shangdiy/pytorch/torch/_dynamo/variables/user_defined.py", line 1256, in _getattr_static
    subobj = type(self.value).__getattribute__(self.value, name)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'Linear' object has no attribute 'w'

During handling of the above exception, another exception occurred:

torch._dynamo.exc.ObservedAttributeError: 'Linear' object has no attribute 'w'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/data/users/shangdiy/pytorch/test.py", line 34, in <module>
    mod = torch._dynamo.functional_export._dynamo_graph_capture_for_export(Model())(x)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/shangdiy/pytorch/torch/_dynamo/functional_export.py", line 481, in inner
    out = fullgraph_capture(
          ^^^^^^^^^^^^^^^^^^
  File "/data/users/shangdiy/pytorch/torch/_dynamo/convert_frame.py", line 1053, in fullgraph_capture
    return _fullgraph_capture_frame(
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/shangdiy/pytorch/torch/_dynamo/convert_frame.py", line 1115, in _fullgraph_capture_frame
    raise e.with_traceback(None) from e.__cause__  # User compiler error
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.Unsupported: Observed exception
  Explanation: Dynamo found no exception handler at the top-level compiled function when encountering an exception. Exception will propagate outside the compiled region.
  Hint: Dynamo has detected that tracing the code will result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled.
  Hint: It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues.

  Developer debug context: raised exception AttributeError(["'Linear' object has no attribute 'w'"])

 For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0088.html

from user code:
   File "/data/users/shangdiy/pytorch/torch/_dynamo/functional_export.py", line 171, in forward
    res = self._export_root(*args, **kwargs)
  File "/data/users/shangdiy/pytorch/test.py", line 31, in forward
    weight = self.linear.w

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

```

Pull Request resolved: pytorch#165930
Approved by: https://github.com/anijain2305
… CI (pytorch#165938)

In recent change LocalTensor introduced dependency on Numpy and has broken Torchao CI.
This dependency cna be made optional and required only when Local Tensor is used.

Pull Request resolved: pytorch#165938
Approved by: https://github.com/atalman
…ytorch#164625)

Pybind's API entails a small unnecessary overhead when working with args. (Similarly, we should probably be using vectorcall, but that's a bigger change for both us and pybind11.)

Pull Request resolved: pytorch#164625
Approved by: https://github.com/albanD
ghstack dependencies: pytorch#164624
…ytorch#165787)

We moved the method to get root mesh into class in pytorch#164510. This is to further clean code up.

Differential Revision: [D85090191](https://our.internmc.facebook.com/intern/diff/D85090191)
Pull Request resolved: pytorch#165787
Approved by: https://github.com/fegin
Original PR that did this was reverted due to merge conflicts.

Trying it again

Pull Request resolved: pytorch#165918
Approved by: https://github.com/oulgen
To not pollute the global namespace, we should move the `from`/`to` APIs into torch::stable::detail. We are also following our normal deprecation cycle and choosing to continue exposing the global `from`/`to` for the time being as people who onboard their extensions onto 2.9 would not be able to build with 2.10 otherwise.

Note that this means that within libtorch, we do not get the luxury of tacking on a `using torch::stable::detail::from` because then it leads to build time ambiguous calls --> both the global and namespace APIs are exposed, which one do I want? So that is why you see every local site is updated.

Note that the update is _not_ necessary from a custom op writer point of view. FA3 can continue to build on torch nightlies without changing any code. (Since this is a header change, this PR has no implication on runtime, a previously built FA3 ABI stable wheel will continue to work fine with newer torch versions after this PR.)

Once TORCH_BOX lands, we would be free to remove these global APIs when the deprecation cycle is up (April 2026) and encourage people to use TORCH_BOX and avoid from/to entirely.

Pull Request resolved: pytorch#164956
Approved by: https://github.com/malfet
ghstack dependencies: pytorch#164882
…than 7.2 (pytorch#165789)

The `-amdgpu-coerce-illegal-types=1` flag is for LLVM that is in ROCm 6.3, 6.4, 7.0, and 7.1. It will not be in ROCm7.2. It was added to enable performance improvements for composable kernel. ROCm7.2 and newer changed the compiler so that the flag isn't needed to achieve those performance improvements. Keeping the flag with ROCm 7.2 breaks the PyTorch build.

Pull Request resolved: pytorch#165789
Approved by: https://github.com/jithunnair-amd, https://github.com/jeffdaily
This PR enables `PLW0127` in ruff, which checks self-assignment of variables with the form `var=var`.

Pull Request resolved: pytorch#165851
Approved by: https://github.com/Lucaskabela
…#165481)

* Moving rocm.yml from using persistent non-ARC runners from the combined MI2xx (MI210 + MI250) cluster to the ARC runners from the MI250 cluster. This halves the number of nodes, but provides access to approximately 4 times the runners, since every 8-GPU MI250 node now provides 8 1-GPU runners. This should help with concurrent capacity and queueing on the MI2xx jobs.

Tested here successfully: https://github.com/pytorch/pytorch/actions/runs/18620814622/job/53092469720

Pull Request resolved: pytorch#165481
Approved by: https://github.com/jeffdaily

Co-authored-by: Jithun Nair <[email protected]>
… build (pytorch#165708)

Audit: To prevent future issues with functools.partial or callable objects.

Pull Request resolved: pytorch#165708
Approved by: https://github.com/Lucaskabela
Add a new 'reduction' tag to tags.yaml and apply it to 98 reduction
operator variants across 21 operator families (sum, mean, min, max,
argmin, argmax, amin, amax, aminmax, prod, all, any, norm, var, std,
std_mean, var_mean, nansum, logsumexp, count_nonzero, linalg_vector_norm).

This tag categorizes operators that perform reduction operations,
computing aggregate values across one or more dimensions of input
tensor(s).

Based on PR pytorch#153342 - co-written with @AlonSardas.

Just as we have pointwise tag - this can be useful for compiler passes, or for opting into sharding rules.

Pull Request resolved: pytorch#165155
Approved by: https://github.com/ezyang, https://github.com/zou3519, https://github.com/mlazos
anijain2305 and others added 17 commits October 27, 2025 16:47
pytorch#165707)

Audit: To prevent future issues with functools.partial or callable
objects.

Pull Request resolved: pytorch#165707
Approved by: https://github.com/Lucaskabela
ghstack dependencies: pytorch#166251
This reverts commit e67e3d9.

Reverted pytorch#161370 on behalf of https://github.com/atalman due to Sorry this is failing libtorch nightly builds [pytorch/pytorch/actions/runs/18800131287/job/53653414136](https://github.com/pytorch/pytorch/actions/runs/18800131287/job/53653414136) ([comment](pytorch#161370 (comment)))
Summary:
Attempting to forward fix failures from D85405167 (PR
pytorch#166021)

This is devmates suggestion and seems to work, but idk if it's a good idea or not.  Devmate says it's getting resolved to at::min which is host only, and it doesn't happen in OSS is likely because `AT_PER_OPERATOR_HEADERS` is defined in OSS but not internally.

```
In file included from .../ATen/native/hip/Normalization.hip:11:
.../ATen/native/hip/Normalization.cuh:302:37: error: no matching function for call to 'min'
  302 |         v_[u] = input[batch][plane][min(x+u*blockDim.x, input.size(2)-1)];
      |                                     ^~~
```

Differential Revision: D85463674

Pull Request resolved: pytorch#166195
Approved by: https://github.com/Camyll, https://github.com/malfet, https://github.com/eqy
The MAX_NUM_ARGS of ComboKernel is currently a fixed number. We need to tune this number to avoid large fusion for MTIA, thus making it configurable.

Differential Revision: [D85509352](https://our.internmc.facebook.com/intern/diff/D85509352/)

Pull Request resolved: pytorch#166274
Approved by: https://github.com/eellison
A few workspace API changes:
1. return outer name when creating. Usually a use case does not care about outer name. But for mix-order-reduction (stacked PR), we need it to do the next-layer of reduction on the workspace tensor
2. be able to override workspace tensor dtype
3. be able to delay the deallocation of workspace tensors in TritonKernel.call_kernel since they may be used after the call. The lifetime of the workspace tensors are only enlarged a little bit. They would be deallocated once the next layer reduction is done.

Test with the stacked PR.

Pull Request resolved: pytorch#166204
Approved by: https://github.com/jansel
Differential Revision: D85446553

Internal builds failing after pytorch#161369

```
buck-headers/ATen/Context.h:22:10: fatal error: 'ATen/detail/XLAHooksInterface.h' file not found
   22 | #include <ATen/detail/XLAHooksInterface.h>
      |          ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
1 error generated.
```

Changes similar to that PR also change the build_variables file, which I've done here.  I'm not sure why this wasn't caught by the bazel build we have?

Sanity checked that some of the previously failing builds pass after this change
Pull Request resolved: pytorch#166179
Approved by: https://github.com/Camyll
Silences existing errors on main to keep errors and noise from the type checker to a minimum

Pull Request resolved: pytorch#166312
Approved by: https://github.com/Skylion007
)

Summary:
Conversion from/to float16 was not getting covered by conversion templates, because these used float16_t as data type instead of the custom at::Half.

We are adding a shim that makes conversion routines use autovec code for float16

We observed the following performance improvements when compiling targeting armv9-a+sve2+fp16

before:

float16_t->uint8->float16_t ===> 657.489us
float16_t->int8->float16_t ===> 656.518us
float16_t->int16->float16_t ===> 668.998us
float16_t->int64->float16_t ===> 618.444us
float16_t->double->float16_t ===> 439.728us

after

float16_t->uint8->float16_t ===> 181.216us  ----> 263% higher throughput
float16_t->int8->float16_t ===> 179.821us  -----> 265% higher throughput
float16_t->int16->float16_t ===> 183.417us  ----> 265% higher throughput
float16_t->int64->float16_t ===> 459.897us  ----> 35% higher throughput
float16_t->double->float16_t ===> 351.276us  ---> 25% higher throughput

Test Plan:
Correctness:

buck2 test mode/opt //caffe2/test:test_ops
buck2 test mode/opt //caffe2/test:torch

Performance:

buck2 run mode/opt //caffe2/benchmarks/operator_benchmark/fb:operator_benchmark_test

Differential Revision: D85533271

Pull Request resolved: pytorch#166306
Approved by: https://github.com/mcfi, https://github.com/ezyang
This diff moves export run_decompositions to use aot_export_joint_with_descriptors instead of aot_export_module. Doing so, i ran into 2 main bugs:
1) aot_export_joint_with_descriptors don't correctly pass in record_nn_module_stack flag that is needed to populate nn_module_stack by switching the internal tracer.
2) When creating symint with negative inputs, we need to pass in positive=False. This didn't matter before because aot_autograd directly returns integer inputs instead of creating symint.

Pull Request resolved: pytorch#165931
Approved by: https://github.com/zhxchen17
- Fixes `s/#pragma onces/#pragma once` typoe

All methods in the headers must be inline, otherwise one gets barrage of following warnings
```
/Users/malfet/git/pytorch/pytorch/c10/metal/utils.h:337:7: warning: unused function 'conj<half __attribute__((ext_vector_type(2)))>' [-Wunused-function]
half2 conj(half2 a) {
      ^
/Users/malfet/git/pytorch/pytorch/c10/metal/utils.h:342:8: warning: unused function 'conj<float __attribute__((ext_vector_type(2)))>' [-Wunused-function]
float2 conj(float2 a) {
       ^
2 warnings generated.
```
Pull Request resolved: pytorch#166315
Approved by: https://github.com/seemethere, https://github.com/atalman
A few internal jobs are observing very high guard overhead for DTensor.
Since we own DTensor, we can make those guards way faster.

Pull Request resolved: pytorch#165824
Approved by: https://github.com/Lucaskabela, https://github.com/bdhirsh
…sting_IFU_2025-10-27

# Conflicts:
#	.ci/docker/build.sh
#	.ci/docker/ci_commit_pins/triton.txt
#	.ci/docker/libtorch/build.sh
#	CMakeLists.txt
#	aten/src/ATen/native/sparse/cuda/SparseMatMul.cu
#	requirements-build.txt
#	test/dynamo/test_structured_trace.py
#	test/inductor/test_cuda_repro.py
#	test/inductor/test_decompose_mem_bound_mm.py
#	test/inductor/test_max_autotune.py
#	test/test_linalg.py
#	test/test_matmul_cuda.py
#	torch/_inductor/runtime/coordinate_descent_tuner.py
#	torch/_inductor/runtime/triton_heuristics.py
#	torch/testing/_internal/common_utils.py
@rocm-repo-management-api
Copy link

rocm-repo-management-api bot commented Oct 27, 2025

Jenkins build for 35615a5d10e8df9160f162cdf11f6e67432a1eee commit finished as FAILURE
Links: Blue Ocean view / Build artifacts

Detected error during base docker image building:

#61 17.94 + sudo -E -H -u jenkins env -u SUDO_UID -u SUDO_GID -u SUDO_COMMAND -u SUDO_USER env PATH=/opt/rocm/llvm/bin:/opt/rocm/opencl/bin:/opt/rocm/hip/bin:/opt/rocm/hcc/bin:/opt/rocm/bin:/opt/conda/envs/py_3.12/bin:/opt/conda/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin LD_LIBRARY_PATH= git clone --recursive https://github.com/ROCm/triton triton
#61 17.95 Cloning into 'triton'...
#61 29.57 + cd triton
#61 29.57 + as_jenkins git checkout '<<<<<<<' HEAD d704bc6e69c1a588c8edd3cbb67505d554ed65f6 ======= 7416ffcb92cdbe98d9f97e4e6f95247e46dfc9fd '>>>>>>>' upstream/main
#61 29.57 + sudo -E -H -u jenkins env -u SUDO_UID -u SUDO_GID -u SUDO_COMMAND -u SUDO_USER env PATH=/opt/rocm/llvm/bin:/opt/rocm/opencl/bin:/opt/rocm/hip/bin:/opt/rocm/hcc/bin:/opt/rocm/bin:/opt/conda/envs/py_3.12/bin:/opt/conda/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin LD_LIBRARY_PATH= git checkout '<<<<<<<' HEAD d704bc6e69c1a588c8edd3cbb67505d554ed65f6 ======= 7416ffcb92cdbe98d9f97e4e6f95247e46dfc9fd '>>>>>>>' upstream/main
#61 29.58 error: pathspec '<<<<<<<' did not match any file(s) known to git
#61 29.58 error: pathspec 'HEAD' did not match any file(s) known to git
#61 29.58 error: pathspec 'd704bc6e69c1a588c8edd3cbb67505d554ed65f6' did not match any file(s) known to git
#61 29.58 error: pathspec '=======' did not match any file(s) known to git
#61 29.58 error: pathspec '7416ffcb92cdbe98d9f97e4e6f95247e46dfc9fd' did not match any file(s) known to git
#61 29.58 error: pathspec '>>>>>>>' did not match any file(s) known to git

@rocm-repo-management-api
Copy link

rocm-repo-management-api bot commented Oct 28, 2025

Jenkins build for 8f578a1057dd510b6145b9b0df8425cb0e42c091 commit finished as FAILURE
Links: Blue Ocean view / Build artifacts

@pragupta pragupta force-pushed the rocm7.1_internal_testing_IFU_2025-10-27 branch from 8f578a1 to 18f870f Compare October 28, 2025 22:17
@rocm-repo-management-api
Copy link

rocm-repo-management-api bot commented Oct 28, 2025

Jenkins build for 18f870f55b6a1b6399cc9febc4d1b1a93131ac04 commit finished as FAILURE
Links: Blue Ocean view / Build artifacts

@rocm-repo-management-api
Copy link

rocm-repo-management-api bot commented Oct 28, 2025

Jenkins build for 18f870f55b6a1b6399cc9febc4d1b1a93131ac04 commit finished as FAILURE
Links: Blue Ocean view / Build artifacts

@rocm-repo-management-api
Copy link

rocm-repo-management-api bot commented Oct 29, 2025

Jenkins build for 18f870f55b6a1b6399cc9febc4d1b1a93131ac04 commit finished as FAILURE
Links: Blue Ocean view / Build artifacts

@pragupta
Copy link
Collaborator

Ignore this PR as it has merge conflict, closing this one in preference to #2769

@pragupta pragupta closed this Oct 29, 2025
@pragupta pragupta deleted the rocm7.1_internal_testing_IFU_2025-10-27 branch October 29, 2025 20:58
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.