Source code for cosmicfishpie.analysis.plot_comparison

"""
   :synopsis: Module for creating comparison plots of Fisher Matrix entries and related visualizations.
   :module author: Dida Markovic, Santiago Casas, and other contributors to the CosmicFishPie project.
"""

import os

import matplotlib
import matplotlib.patches as mpatches

# matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import seaborn as sns
from mpl_toolkits import axes_grid1

from cosmicfishpie.analysis import fisher_operations as fo
from cosmicfishpie.analysis import utilities as fu
from cosmicfishpie.utilities.utils import printing as upr

dprint = upr.debug_print

plt.style.use("tableau-colorblind10")

snscolors = sns.color_palette("colorblind")
marks = ["o", "s", "*", "X", "p", "^", "v", "8", "D", "P"]
# Colorblind list is 10 elements long, as well as markers list
# This will crash if more than 10 Fishers to compare are requested

matplotlib.rcParams["savefig.transparent"] = False
matplotlib.rcParams["figure.autolayout"] = True
# matplotlib.rcParams['xtick.major.pad']     = 4.0
# matplotlib.rcParams['ytick.major.pad']     = 4.0
# matplotlib.rcParams['axes.labelpad']       = 5.0
# matplotlib.rcParams['savefig.dpi']         = 300
# matplotlib.rcParams['savefig.pad_inches']  = 0.0
# use latex for all text handling. The following fonts
matplotlib.rcParams["text.usetex"] = False
# matplotlib.rcParams['text.latex.preamble'] =
# '\usepackage{amsmath},\usepackage{pgfplots},\usepackage[T1]{fontenc}'
# If True (default), the text will be antialiased.
matplotlib.rcParams["text.antialiased"] = True
# This only affects the Agg backend.


[docs] def calc_y_range(axis, yrang=None): yymin, yymax = axis.get_ylim() yymax = np.max(np.abs([yymin, yymax])) if yymax < 0.1: yymax = 0.1 locat = 0.01 if yymax < 0.3: yymax = 0.3 locat = 0.05 elif yymax < 1.0: yymax = 1.0 locat = 0.1 elif yymax < 5.0: yymax = 5.0 locat = 1.0 elif yymax < 10.0: yymax = 10.0 locat = 5.0 else: yymax = 1.05 * yymax locat = yymax // 5 yymin = -yymax if yrang is not None: locat = (np.abs(np.max(yrang)) + np.abs(np.min(yrang))) / 5 yymax = np.max(yrang) yymin = np.min(yrang) return (yymin, yymax, locat)
[docs] def og_plot_shades( ax, x_arr, x_names, lighty_arr=None, darky_arr=None, mats_labels=None, lightdark_names=["marg.", "unmarg."], cols=[], plotdark=True, plotlight=True, yrang=None, x_limpad=0.2, fish_leg_loc="upper left", LW=2, colordark="darkgrey", colorlight="lightgrey", alpha=0.7, light_hatch="/", patches_legend_loc="upper right", patches_legend_fontsize=16, dots_legend_fontsize=20, ylabelfontsize=20, ncol_legend=None, colors=None, color_palette="colorblind", legend_title_fontsize=None, legend_title=None, y_label="Differences", # r'% differences on ' +r'$\sigma_i$' yticklabsize=18, xticklabsize=18, xtickfontsize=15, xticksrotation=0, ): LW = LW colD = colordark # 'lightslategray' colL = colorlight aalpha = alpha darkgreypatch = mpatches.Patch(color=colD, alpha=aalpha) lightgreypatch = mpatches.Patch(color=colL, alpha=aalpha, hatch=light_hatch) if colors is None: colors = sns.color_palette(color_palette) if lighty_arr is not None: max_l = np.max(lighty_arr, 0) min_l = np.min(lighty_arr, 0) if darky_arr is not None: max_d = np.max(darky_arr, 0) min_d = np.min(darky_arr, 0) if plotlight and lighty_arr is None: print("Error: plotlight is True but lighty_arr is None. Aborting.") return None if plotdark and darky_arr is None: print("Error: plotdark is True but darky_arr is None. Aborting.") return None numarrs = len(mats_labels) nc = ncol_legend if ncol_legend is None: if numarrs < 6: nc = numarrs elif numarrs >= 6: nc = numarrs // 2 for ii, lbl in enumerate(mats_labels): if plotlight: ax.plot( x_arr, lighty_arr[ii, :], marks[ii], c=colors[ii], ms=LW * 8, alpha=aalpha, label=lbl, ) if plotdark: if not plotlight: ms = LW * 8 lbl = lbl else: ms = 0 lbl = None ax.plot( x_arr, darky_arr[ii, :], marks[ii], c=colors[ii], ms=ms, mew=2, alpha=aalpha, label=lbl, ) if plotlight: ax.fill_between( x_arr, min_l, max_l, interpolate=True, facecolor=colL, edgecolor=colL, alpha=aalpha, linewidth=0.0, hatch=light_hatch, ) if plotdark: ax.fill_between( x_arr, min_d, max_d, interpolate=True, facecolor=colD, edgecolor=colD, alpha=aalpha, linewidth=0.0, ) patchlist = [] legpatch = [] if plotlight: patchlist.append(lightgreypatch) legpatch.append(lightdark_names[0]) if plotdark: patchlist.append(darkgreypatch) legpatch.append(lightdark_names[1]) if plotlight or plotdark: leg2 = ax.legend( patchlist, legpatch, loc=patches_legend_loc, ncol=2, fontsize=patches_legend_fontsize ) ax.legend( loc=fish_leg_loc, ncol=nc, fontsize=dots_legend_fontsize, handlelength=2, numpoints=1, title=legend_title, title_fontsize=legend_title_fontsize, ) # bbox_to_anchor=(1.05, 1.05), ax.add_artist(leg2) # ax.axhline(y=0.0, ls=':', c='k', alpha=0.2) # we don't want the zero-line ax.set_ylabel(y_label, labelpad=1, fontsize=ylabelfontsize) ax.set_xlim([min(x_arr) - x_limpad, max(x_arr) + x_limpad]) ax.tick_params(axis="x", direction="in", pad=10, labelsize=xticklabsize) ax.tick_params(axis="y", direction="in", pad=10, labelsize=yticklabsize) ax.set_xticks(x_arr) ax.set_xticklabels(x_names, fontsize=xtickfontsize, rotation=xticksrotation) ax.yaxis.tick_left() ax.yaxis.set_ticks_position("both") ax.xaxis.tick_bottom() ax.xaxis.set_ticks_position("both") ymin, ymax, locaty = calc_y_range(ax, yrang) ax.set_ylim([ymin, ymax]) majorLocator = ticker.MultipleLocator(locaty) minorLocator = ticker.MultipleLocator(locaty) majorFormatter = ticker.FormatStrFormatter("%.2f") ax.yaxis.set_major_locator(majorLocator) ax.yaxis.set_major_formatter(majorFormatter) ax.yaxis.set_minor_locator(minorLocator) return ax
[docs] def plot_shades( ax, x_arr, x_names, lighty_arr=None, darky_arr=None, mats_labels=None, lightdark_names=["marg.", "unmarg."], plotdark=True, plotlight=True, yrang=None, x_limpad=0.2, fish_leg_loc="upper left", LW=2, colordark="darkgrey", colorlight="lightgrey", alpha=0.7, light_hatch="/", patches_legend_loc="upper right", patches_legend_fontsize=16, dots_legend_fontsize=20, ylabelfontsize=20, ncol_legend=None, colors=None, color_palette="colorblind", legend_title_fontsize=None, legend_title=None, y_label="Differences", # r'% differences on ' +r'$\sigma_i$' yticklabsize=18, xticklabsize=18, xtickfontsize=15, xticksrotation=0, ): LW = LW # plt.style.use('tableau-colorblind10') colD = colordark # 'lightslategray' colL = colorlight aalpha = alpha darkgreypatch = mpatches.Patch(color=colD, alpha=aalpha) lightgreypatch = mpatches.Patch(color=colL, alpha=aalpha, hatch=light_hatch) if colors is None: colors = sns.color_palette(color_palette) if lighty_arr is not None: max_l = np.max(lighty_arr, 0) min_l = np.min(lighty_arr, 0) if darky_arr is not None: max_d = np.max(darky_arr, 0) min_d = np.min(darky_arr, 0) if plotlight and lighty_arr is None: print("Error: plotlight is True but lighty_arr is None. Aborting.") return None if plotdark and darky_arr is None: print("Error: plotdark is True but darky_arr is None. Aborting.") return None numarrs = len(mats_labels) nc = ncol_legend if ncol_legend is None: if numarrs < 6: nc = numarrs elif numarrs >= 6: nc = numarrs // 2 if plotlight: dprint("plotting light") ax.bar(x_arr, max_l, color=colL, width=0.8, alpha=0.9, zorder=1) ax.bar(x_arr, min_l, color=colL, width=0.8, alpha=0.9, zorder=1) # ax.fill_between(x_arr, min_l, max_l, interpolate=True, facecolor=colL, # edgecolor=colL, alpha=aalpha, linewidth=0.0, hatch=light_hatch) if plotdark: dprint("plotting dark") ax.bar(x_arr, max_d, color=colD, width=0.5, alpha=0.95, zorder=2) ax.bar(x_arr, min_d, color=colD, width=0.5, alpha=0.95, zorder=2) # ax.fill_between(x_arr, min_d, max_d, interpolate=True, facecolor=colD, # edgecolor=colD, alpha=aalpha, linewidth=0.0) for ii, lbl in enumerate(mats_labels): if plotlight: ax.scatter( x_arr, lighty_arr[ii, :], color=colors[ii], marker=marks[ii], s=(LW * 8) ** 2, label=lbl, alpha=aalpha, zorder=3 + ii, ) if plotdark: if not plotlight: ax.scatter( x_arr, darky_arr[ii, :], color=colors[ii], marker=marks[ii], s=(LW * 8) ** 2, label=lbl, alpha=aalpha - 0.1, zorder=3 + ii, ) else: lbl = None patchlist = [] legpatch = [] if plotlight: patchlist.append(lightgreypatch) legpatch.append(lightdark_names[0]) if plotdark: patchlist.append(darkgreypatch) legpatch.append(lightdark_names[1]) if plotlight or plotdark: leg2 = ax.legend( patchlist, legpatch, loc=patches_legend_loc, ncol=2, fontsize=patches_legend_fontsize ) leg1 = ax.legend( loc=fish_leg_loc, ncol=nc, fontsize=dots_legend_fontsize, handlelength=2, numpoints=1, title=legend_title, title_fontsize=legend_title_fontsize, ) # bbox_to_anchor=(1.05, 1.05), ax.add_artist(leg2) leg1.set_zorder(10) leg2.set_zorder(10) # ax.axhline(y=0.0, ls=':', c='k', alpha=0.2) # we don't want the zero-line ax.set_ylabel(y_label, labelpad=1, fontsize=ylabelfontsize) ax.set_xlim([min(x_arr) - x_limpad, max(x_arr) + x_limpad]) ax.tick_params(axis="x", direction="in", pad=10, labelsize=xticklabsize) ax.tick_params(axis="y", direction="in", pad=10, labelsize=yticklabsize) ax.set_xticks(x_arr) ax.set_xticklabels(x_names, fontsize=xtickfontsize, rotation=xticksrotation) ax.yaxis.tick_left() ax.yaxis.set_ticks_position("both") ax.xaxis.tick_bottom() ax.xaxis.set_ticks_position("both") ymin, ymax, locaty = calc_y_range(ax, yrang) ax.set_ylim([ymin, ymax]) majorLocator = ticker.MultipleLocator(locaty) minorLocator = ticker.MultipleLocator(locaty) majorFormatter = ticker.FormatStrFormatter("%.2f") ax.yaxis.set_major_locator(majorLocator) ax.yaxis.set_major_formatter(majorFormatter) ax.yaxis.set_minor_locator(minorLocator) return ax
[docs] def process_fish_errs( fishers_list, fishers_name, parstoplot=None, parsnames_latex=None, marginalize_pars=True, print_errors=True, compare_to_index=False, transform_latex_dict=dict(), ): # Cycle through files and get the errors and the present parameters print(("Fishers names: ", fishers_name)) for nn, ff in zip(fishers_name, fishers_list): ff.name = nn if parstoplot is None: parstoplot = fishers_list[0].get_param_names() print(("parameters to plot: ", parstoplot)) n_pars = len(parstoplot) x_pars = np.arange(1, n_pars + 1) # marginalize Fishers over parameters not plotted if marginalize_pars: processed_fishers = [fo.marginalise(ff, parstoplot) for ff in fishers_list] else: processed_fishers = [fo.reshuffle(ff, parstoplot) for ff in fishers_list] if parsnames_latex is None: parsnames_latex = processed_fishers[0].get_param_names_latex() # print(parsnames_latex) parsnames_latex_transf = [transform_latex_dict.get(pp, pp) for pp in parsnames_latex] print("X tick labels ---> : ", parsnames_latex_transf) parsnames_latex = ["$" + pp + "$" for pp in parsnames_latex_transf] errMargs = np.array([mm.get_confidence_bounds(marginal=True) for mm in processed_fishers]) errUnmargs = np.array([mm.get_confidence_bounds(marginal=False) for mm in processed_fishers]) if print_errors: for ii, fishy in enumerate(processed_fishers): dprint(("Fisher name: ", fishy.name)) dprint(("Parameter names latex: ", parsnames_latex)) dprint(("Marginalized 1-sigma errors :", errMargs[ii])) dprint(("Unmarginalized 1-sigma errors :", errUnmargs[ii])) # Plot differences, not absolute values np.abs, np.median default if not compare_to_index: eurel = fu.rel_median_error(errUnmargs) emrel = fu.rel_median_error(errMargs) else: if isinstance(compare_to_index, int) and compare_to_index >= 0: eurel = fu.rel_error_to_index(compare_to_index, errUnmargs) emrel = fu.rel_error_to_index(compare_to_index, errMargs) return eurel, emrel, x_pars, parsnames_latex
[docs] def ploterrs( fishers_list, fishers_name, parstoplot=None, parsnames_latex=None, marginalize_pars=True, plot_style="original", outpathfile=os.getcwd(), plot_marg=True, plot_unmarg=True, yrang=None, figsize=(10, 6), fish_leg_loc="lower left", dpi=400, savefig=True, y_label="Errors", ncol_legend=None, colors=None, legend_title_fontsize=None, legend_title=None, yticklabsize=20, xticklabsize=15, patches_legend_fontsize=20, dots_legend_fontsize=20, xtickfontsize=18, ylabelfontsize=20, compare_to_index=False, xticksrotation=0, save_error=False, transform_latex_dict=dict(), figure_title="", ): fig, ax1 = plt.subplots(1, 1, sharey=True, figsize=figsize, facecolor="white") """ Plot the error comparison between different Fisher matrices""" ax1.set_title(figure_title, loc="center") eurel, emrel, x_pars, parsnames_latex = process_fish_errs( fishers_list, fishers_name, parstoplot=parstoplot, parsnames_latex=parsnames_latex, marginalize_pars=marginalize_pars, transform_latex_dict=transform_latex_dict, compare_to_index=compare_to_index, ) # fishnamesjoined=("-").join(fishers_name) if save_error: np.savetxt(outpathfile.replace(".pdf", ".txt"), np.concatenate((eurel, emrel), axis=0)) if plot_style == "original": og_plot_shades( ax1, x_pars, parsnames_latex, mats_labels=fishers_name, lighty_arr=emrel, darky_arr=eurel, lightdark_names=["marg.", "unmarg."], plotlight=plot_marg, plotdark=plot_unmarg, fish_leg_loc=fish_leg_loc, yrang=yrang, y_label=y_label, ncol_legend=ncol_legend, legend_title_fontsize=legend_title_fontsize, legend_title=legend_title, yticklabsize=yticklabsize, xticklabsize=xticklabsize, xtickfontsize=xtickfontsize, ylabelfontsize=ylabelfontsize, xticksrotation=xticksrotation, colors=colors, patches_legend_fontsize=patches_legend_fontsize, dots_legend_fontsize=dots_legend_fontsize, ) elif plot_style == "bars": plot_shades( ax1, x_pars, parsnames_latex, mats_labels=fishers_name, lighty_arr=emrel, darky_arr=eurel, lightdark_names=["marg.", "unmarg."], plotlight=plot_marg, plotdark=plot_unmarg, fish_leg_loc=fish_leg_loc, yrang=yrang, y_label=y_label, ncol_legend=ncol_legend, legend_title_fontsize=legend_title_fontsize, legend_title=legend_title, yticklabsize=yticklabsize, xticklabsize=xticklabsize, xtickfontsize=xtickfontsize, ylabelfontsize=ylabelfontsize, xticksrotation=xticksrotation, colors=colors, patches_legend_fontsize=patches_legend_fontsize, dots_legend_fontsize=dots_legend_fontsize, ) # fig.tight_layout(pad=10.0, w_pad=10.0, h_pad=10.0) plotfile = outpathfile if savefig: fig.savefig(plotfile, dpi=dpi, bbox_inches="tight")
# fig.show()
[docs] def add_colorbar(im, aspect=30, pad_fraction=0.5, **kwargs): """Add a vertical color bar to an image plot.""" divider = axes_grid1.make_axes_locatable(im.axes) width = axes_grid1.axes_size.AxesY(im.axes, aspect=1.0 / aspect) pad = axes_grid1.axes_size.Fraction(pad_fraction, width) current_ax = plt.gca() cax = divider.append_axes("right", size=width, pad=pad) plt.sca(current_ax) return im.axes.figure.colorbar(im, cax=cax, **kwargs)
[docs] def matrix_plot( matrix, xlabel="Ratio", ticklabels=None, filename="matrixplot.png", figsize=(9, 9), colormap=plt.cm.viridis, savefig=True, dpi=200, ): fig, ax = plt.subplots(1, figsize=(9, 9), facecolor="white") intermat = matrix lenmat = intermat.shape[0] if ticklabels is None: ticklabels = ["{:d}".format(ii) for ii in range(lenmat)] for i in range(len(intermat)): for j in range(len(intermat)): c = intermat[j, i] ax.text(i, j, "{:.2f}".format(c), va="center", ha="center", fontsize=11) im = ax.matshow(intermat, cmap=plt.cm.viridis) # cax = fig.add_axes([ax.get_position().x1+0.06, ax.get_position().y0, 0.02, ax.get_position().height]) # plt.colorbar(im, cax=cax) add_colorbar(im) ax.set_xlabel(xlabel) ax.xaxis.set_label_position("top") matplotlib.rcParams["xtick.labelsize"] = 14 matplotlib.rcParams["ytick.labelsize"] = 14 ax.tick_params( axis="both", which="both", labelsize=14, labelbottom=True, bottom=True, top=True, labeltop=False, direction="in", ) ax.set_xticks(np.arange(lenmat), ticklabels) ax.set_yticks(np.arange(lenmat), ticklabels) if savefig: fig.savefig(filename, dpi=dpi, bbox_inches="tight")