Skip to content

Commit 38bc042

Browse files
Properly set comparison type attribute on MLIR comparison ops (#1502)
1 parent bb61c58 commit 38bc042

File tree

1 file changed

+11
-6
lines changed

1 file changed

+11
-6
lines changed

exla/lib/exla/mlir/value.ex

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -64,15 +64,20 @@ defmodule EXLA.MLIR.Value do
6464
%{type: rhs_type} = get_typespec(rhs)
6565

6666
comparison_type =
67-
if Nx.Type.float?(lhs_type) or Nx.Type.float?(rhs_type) do
68-
attr_comparison_type(:totalorder)
69-
else
70-
attr_comparison_type(:notype)
67+
cond do
68+
Nx.Type.complex?(lhs_type) or Nx.Type.complex?(rhs_type) ->
69+
attr_comparison_type(:float)
70+
71+
Nx.Type.float?(lhs_type) or Nx.Type.float?(rhs_type) ->
72+
attr_comparison_type(:float)
73+
74+
true ->
75+
attr_comparison_type(:notype)
7176
end
7277

7378
attributes = [
7479
comparison_direction: attr_comparison_direction(direction),
75-
comparison_type: comparison_type
80+
compare_type: comparison_type
7681
]
7782

7883
result_types = typespecs_to_mlir_types([Typespec.to_type(typespec, {:pred, 8})])
@@ -929,7 +934,7 @@ defmodule EXLA.MLIR.Value do
929934
defp attr_comparison_direction(value) when value in [:eq, :lt, :le, :gt, :ge, :ne],
930935
do: attr_enum("stablehlo", "comparison_direction", value)
931936

932-
defp attr_comparison_type(value) when value in [:totalorder, :notype],
937+
defp attr_comparison_type(value) when value in [:float, :totalorder, :notype],
933938
do: attr_enum("stablehlo", "comparison_type", value)
934939

935940
defp attr_precision(value) when value in [:default, :high, :highest],

0 commit comments

Comments
 (0)