@@ -884,37 +884,59 @@ def jointplot(x, y, data=None, kind="scatter", stat_func=stats.pearsonr,
884
884
# Plot the data using the grid
885
885
if kind == "scatter" :
886
886
887
- grid .plot_joint (plt .scatter , color = color , ** joint_kws )
888
- grid .plot_marginals (distplot , kde = False , color = color , ** marginal_kws )
887
+ joint_kws .setdefault ("color" , color )
888
+ grid .plot_joint (plt .scatter , ** joint_kws )
889
+
890
+ marginal_kws .setdefault ("kde" , False )
891
+ marginal_kws .setdefault ("color" , color )
892
+ grid .plot_marginals (distplot , ** marginal_kws )
889
893
890
894
elif kind .startswith ("hex" ):
891
895
892
896
x_bins = _freedman_diaconis_bins (grid .x )
893
897
y_bins = _freedman_diaconis_bins (grid .y )
894
898
gridsize = int (np .mean ([x_bins , y_bins ]))
895
899
896
- grid .plot_joint (plt .hexbin , gridsize = gridsize , cmap = cmap , ** joint_kws )
897
- grid .plot_marginals (distplot , kde = False , color = color , ** marginal_kws )
900
+ joint_kws .setdefault ("gridsize" , gridsize )
901
+ joint_kws .setdefault ("cmap" , cmap )
902
+ grid .plot_joint (plt .hexbin , ** joint_kws )
903
+
904
+ marginal_kws .setdefault ("kde" , False )
905
+ marginal_kws .setdefault ("color" , color )
906
+ grid .plot_marginals (distplot , ** marginal_kws )
898
907
899
908
elif kind .startswith ("kde" ):
900
909
901
- grid .plot_joint (kdeplot , shade = True , cmap = cmap , ** joint_kws )
902
- grid .plot_marginals (kdeplot , shade = True , color = color , ** marginal_kws )
910
+ joint_kws .setdefault ("shade" , True )
911
+ joint_kws .setdefault ("cmap" , cmap )
912
+ grid .plot_joint (kdeplot , ** joint_kws )
913
+
914
+ marginal_kws .setdefault ("shade" , True )
915
+ marginal_kws .setdefault ("color" , color )
916
+ grid .plot_marginals (kdeplot , ** marginal_kws )
903
917
904
918
elif kind .startswith ("reg" ):
905
919
906
920
from .linearmodels import regplot
907
- grid .plot_marginals (distplot , color = color , ** marginal_kws )
908
- grid .plot_joint (regplot , color = color , ** joint_kws )
921
+
922
+ marginal_kws .setdefault ("color" , color )
923
+ grid .plot_marginals (distplot , ** marginal_kws )
924
+
925
+ joint_kws .setdefault ("color" , color )
926
+ grid .plot_joint (regplot , ** joint_kws )
909
927
910
928
elif kind .startswith ("resid" ):
911
929
912
930
from .linearmodels import residplot
913
- grid .plot_joint (residplot , color = color , ** joint_kws )
931
+
932
+ joint_kws .setdefault ("color" , color )
933
+ grid .plot_joint (residplot , ** joint_kws )
934
+
914
935
x , y = grid .ax_joint .collections [0 ].get_offsets ().T
915
- distplot (x , color = color , kde = False , ax = grid .ax_marg_x )
916
- distplot (y , color = color , kde = False , vertical = True ,
917
- fit = stats .norm , ax = grid .ax_marg_y )
936
+ marginal_kws .setdefault ("color" , color )
937
+ marginal_kws .setdefault ("kde" , False )
938
+ distplot (x , ax = grid .ax_marg_x , ** marginal_kws )
939
+ distplot (y , vertical = True , fit = stats .norm , ax = grid .ax_marg_y , ** marginal_kws )
918
940
stat_func = None
919
941
920
942
if stat_func is not None :
0 commit comments