@@ -64,15 +64,20 @@ defmodule EXLA.MLIR.Value do
64
64
% { type: rhs_type } = get_typespec ( rhs )
65
65
66
66
comparison_type =
67
- if Nx.Type . float? ( lhs_type ) or Nx.Type . float? ( rhs_type ) do
68
- attr_comparison_type ( :totalorder )
69
- else
70
- attr_comparison_type ( :notype )
67
+ cond do
68
+ Nx.Type . complex? ( lhs_type ) or Nx.Type . complex? ( rhs_type ) ->
69
+ attr_comparison_type ( :float )
70
+
71
+ Nx.Type . float? ( lhs_type ) or Nx.Type . float? ( rhs_type ) ->
72
+ attr_comparison_type ( :float )
73
+
74
+ true ->
75
+ attr_comparison_type ( :notype )
71
76
end
72
77
73
78
attributes = [
74
79
comparison_direction: attr_comparison_direction ( direction ) ,
75
- comparison_type : comparison_type
80
+ compare_type : comparison_type
76
81
]
77
82
78
83
result_types = typespecs_to_mlir_types ( [ Typespec . to_type ( typespec , { :pred , 8 } ) ] )
@@ -929,7 +934,7 @@ defmodule EXLA.MLIR.Value do
929
934
defp attr_comparison_direction ( value ) when value in [ :eq , :lt , :le , :gt , :ge , :ne ] ,
930
935
do: attr_enum ( "stablehlo" , "comparison_direction" , value )
931
936
932
- defp attr_comparison_type ( value ) when value in [ :totalorder , :notype ] ,
937
+ defp attr_comparison_type ( value ) when value in [ :float , : totalorder, :notype ] ,
933
938
do: attr_enum ( "stablehlo" , "comparison_type" , value )
934
939
935
940
defp attr_precision ( value ) when value in [ :default , :high , :highest ] ,
0 commit comments