[matmul] Last dimension of first input with shape ... must match second to last dimension of second input with shape ... #5

@ckuethe

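Environment: mlx-community/GLM-Z1-32B-0414-4bit, b1006, ROCm 7.12, gfx1151

The diagnostics pass through TEST 4; TEST 5 (the full forward pass) aborts with a matmul shape error. Full log: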

=== INFERENCE PIPELINE DIAGNOSTICS ===
Loading model: models--mlx-community--GLM-Z1-32B-0414-4bit/snapshots/c5ae15e13a9bada45e9f8c94631a79229b5b2ebf
Model loaded.

--- TEST 1: Basic GPU ops ---
[DIAG] matmul(ones, 2*ones) expect=8	   shape=(4,4) min=8.000000 max=8.000000 mean=8.000000 |mean|=8.000000
[VALS] matmul result: [8.0000, 8.0000, 8.0000, 8.0000, 8.0000, 8.0000, 8.0000, 8.0000]
[hipBLASLt] first call
[hipBLASLt] M=4 N=4 K=4 ta=0 tb=0 lda=4 ldb=4 ldc=4
[DIAG] bf16 matmul expect=8		   shape=(4,4) min=8.000000 max=8.000000 mean=8.000000 |mean|=8.000000

--- TEST 2: quantized_matmul vs dequant ---
[DIAG] weight=(6144,768) uint32 scales=(6144,96) biases=(6144,96)
[DIAG] scales				   shape=(6144,96) min=-0.011841 max=0.011292 mean=-0.000052 |mean|=0.006060
[DIAG] biases				   shape=(6144,96) min=-0.104980 max=0.113281 mean=0.000408 |mean|=0.050628
[DIAG] bits=4 group_size=64
[DIAG] dequantized_q_proj		   shape=(6144,6144) min=-0.104980 max=0.113281 mean=0.000008 |mean|=0.015630
[VALS] dequant row0: [-0.0074, -0.0074, -0.0222, -0.0074, 0.0000, 0.0369, 0.0222, 0.0148, 0.0222, 0.0222, 0.0074, 0.0148, 0.0000, -0.0222, -0.0295, 0.0222, -0.0074, -0.0074, -0.0074, 0.0074]
[DIAG] REF(ones)			   shape=(1,1,6144) min=-5.779968 max=6.099457 mean=0.047313 |mean|=1.235951
[VALS] REF(ones): [2.3609, 3.7737, 1.0787, 1.0020, -1.3180, -1.7491, 0.9193, -1.2952, -1.2418, -1.1465, 2.6213, 3.3636, -1.7790, -0.2222, -1.3958, 1.9211, -1.7549, 0.7846, -2.4981, -2.6777]
[DIAG] QMM(ones)			   shape=(1,1,6144) min=-5.781250 max=6.093750 mean=0.047250 |mean|=1.235988
[VALS] QMM(ones): [2.3750, 3.7812, 1.0859, 1.0000, -1.3125, -1.7500, 0.9180, -1.2969, -1.2500, -1.1484, 2.6250, 3.3594, -1.7812, -0.2207, -1.3984, 1.9141, -1.7500, 0.7891, -2.5000, -2.6719]
[DIAG] DIFF(ones)			   shape=(1,1,6144) min=0.000000 max=0.019043 mean=0.003092 |mean|=0.003092
[DIAG] MAX DIFF(ones) = 0.019043
[DIAG] DIFF(random)			   shape=(1,1,6144) min=0.000000 max=0.018132 mean=0.002676 |mean|=0.002676
[DIAG] MAX DIFF(random) = 0.018132
[DIAG] DIFF(batch=3)			   shape=(1,3,6144) min=0.000000 max=0.023514 mean=0.002557 |mean|=0.002557
[DIAG] MAX DIFF(batch=3) = 0.023514
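
The TEST 2 shapes are consistent with each other. A minimal sketch of the arithmetic, assuming 4-bit values are packed eight to a `uint32` along the input dimension and that there is one scale and one bias per group of 64 inputs (the helper name `packed_shapes` is hypothetical, not part of any library):

```python
def packed_shapes(out_features, in_features, bits=4, group_size=64, word_bits=32):
    """Return (packed weight shape, scales/biases shape) for a quantized layer.

    Assumption: values are packed along the input dimension, word_bits // bits
    per storage word, with one scale and one bias per group of group_size inputs.
    """
    per_word = word_bits // bits  # 8 four-bit values per uint32
    if in_features % per_word or in_features % group_size:
        raise ValueError("in_features must be divisible by per_word and group_size")
    weight_shape = (out_features, in_features // per_word)
    group_shape = (out_features, in_features // group_size)
    return weight_shape, group_shape


weight_shape, group_shape = packed_shapes(6144, 6144)
print(weight_shape, group_shape)  # (6144, 768) (6144, 96)
```

So a packed width of 768 corresponds to a logical width of 6144, which matches the `dequantized_q_proj` shape in the log.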

--- TEST 3: RMS Norm ---
[DIAG] rms_norm([1,2,3,4])		   shape=(1,1,4) min=0.365148 max=1.460593 mean=0.912871 |mean|=0.912871
[VALS] rms_norm([1,2,3,4]) expect≈[.365,.730,1.095,1.461]: [0.3651, 0.7303, 1.0954, 1.4606]
[DIAG] rms_norm(rand bf16 4096)		   shape=(1,3,4096) min=-4.093750 max=4.187500 mean=-0.005609 |mean|=0.799228
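
For reference, the expected values in the `rms_norm([1,2,3,4])` line can be reproduced on the CPU. This is a minimal sketch, assuming a unit weight vector and a small `eps` (the value `1e-5` is an assumption, not taken from the log):

```python
import math


def rms_norm(x, eps=1e-5):
    """Compute x / sqrt(mean(x^2) + eps), treating the learned weight as 1."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]


reference = rms_norm([1.0, 2.0, 3.0, 4.0])
print([round(v, 4) for v in reference])  # [0.3651, 0.7303, 1.0954, 1.4606]
```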

--- TEST 4: RoPE ---
[DIAG] rope(ones, off=0)		   shape=(1,1,1,128) min=1.000000 max=1.000000 mean=1.000000 |mean|=1.000000
[VALS] rope(ones, off=0): [1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000]
[DIAG] rope(ones, off=100)		   shape=(1,1,1,128) min=-1.414062 max=1.406250 mean=0.586235 |mean|=0.953492
[VALS] rope(ones, off=100): [1.3672, 1.3438, -1.3672, -1.3516, 0.7305, -1.3828, -1.4062, -0.9219, 1.3594, -1.1719, 1.3750, -1.1094, -0.5898, 1.2109, 1.1406, -0.0040, -0.9805, -1.3906, -1.3516, -1.0781]
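
Two properties of the RoPE output can be checked without knowing the exact kernel configuration: offset 0 is the identity, and for an all-ones input every output is `cos θ ± sin θ`, so its magnitude cannot exceed √2 (the log's min of -1.414062 sits at that bound). A rough sketch, assuming base 10000, rotation over all 128 dimensions, and interleaved pairs; these are assumptions, so individual values are not expected to match the log:

```python
import math


def rope_pair(x1, x2, pos, i, dims=128, base=10000.0):
    """Rotate one (x1, x2) pair by the angle for position pos and pair index i."""
    theta = pos * base ** (-2.0 * i / dims)
    c, s = math.cos(theta), math.sin(theta)
    return x1 * c - x2 * s, x1 * s + x2 * c


def rope_ones(pos, dims=128):
    """Apply the rotation to an all-ones vector (interleaved pair layout assumed)."""
    out = []
    for i in range(dims // 2):
        out.extend(rope_pair(1.0, 1.0, pos, i, dims))
    return out


at_zero = rope_ones(0)
at_hundred = rope_ones(100)
print(all(v == 1.0 for v in at_zero))                   # True
print(max(abs(v) for v in at_hundred) <= math.sqrt(2) + 1e-12)  # True
```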

--- TEST 5: Full forward pass ---
terminate called after throwing an instance of 'std::invalid_argument'
  what():  [matmul] Last dimension of first input with shape (1,1,6144) must match second to last dimension of second input with shape (768,151552).
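
One possible reading of the shapes in this error, not confirmed: 768 × 8 = 6144, the same packed-to-logical ratio as the 4-bit weight in TEST 2, so the second operand may be a still-packed quantized weight reaching a plain `matmul`. A small sketch of that check; the helper `matmul_compatible` is hypothetical and only mirrors the rule stated in the error message:

```python
def matmul_compatible(a_shape, b_shape):
    """Last dim of the first input must equal the second-to-last dim of the second."""
    return a_shape[-1] == b_shape[-2]


x_shape = (1, 1, 6144)    # first input from the error message
w_shape = (768, 151552)   # second input from the error message
bits = 4                  # from TEST 2
per_word = 32 // bits     # assumption: 4-bit values packed into uint32 words

print(matmul_compatible(x_shape, w_shape))   # False, this is the reported failure
print(w_shape[0] * per_word == x_shape[-1])  # True, 768 packed words cover 6144 inputs
```

If that reading is right, the operand would need to go through the quantized matmul path (or be dequantized first) instead of being passed to `matmul` directly.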
