提问者:小点点

将分组的 seaborn FacetGrid 热图保存到指定目录时出现问题


我一直在努力用某种特定的外观将我的图表保存到特定的目录中。

以下是示例数据以及我目前的尝试:

import pandas as pd
import numpy as np
import itertools
import seaborn as sns
from matplotlib.colors import ListedColormap

# Log which seaborn release is installed; useful when reproducing plot issues.
print(
    "seaborn version {}".format(sns.__version__)
)
# Python stand-in for R's expand.grid().
# https://stackoverflow.com/a/12131385/1135316
def expandgrid(*itrs):
    """Return every combination of the given iterables as a dict of columns.

    Combinations are generated by ``itertools.product``, so the *last*
    iterable varies fastest (R's ``expand.grid`` varies the first fastest).
    Column ``'Var<k>'`` (k counted from 1) holds the k-th element of each
    combination. With no iterables the result is ``{}``; if any iterable is
    empty, every column is an empty list.
    """
    combos = list(itertools.product(*itrs))
    columns = {}
    for position in range(len(itrs)):
        columns['Var{}'.format(position + 1)] = [combo[position] for combo in combos]
    return columns




# Toy data: one row per combination of ltt x method x label x dtsi x rtsi
# (2 * 4 * 2 * 10 * 10 = 1600 rows), each with a random 0/1 score.
ltt = ['lt1', 'lt2']

methods = ['method 1', 'method 2', 'method 3', 'method 4']
labels = ['label1', 'label2']
times = range(0, 100, 10)
data = pd.DataFrame(expandgrid(ltt, methods, labels, times, times))
data.columns = ['ltt', 'method', 'labels', 'dtsi', 'rtsi']
# For continuous scores in [0, 1) use np.random.sample(len(data)) instead.
data['nw_score'] = np.random.choice([0, 1], len(data))

data
Out[25]: 
      ltt    method  labels  dtsi  rtsi  nw_score
0     lt1  method 1  label1     0     0         0
1     lt1  method 1  label1     0    10         1
2     lt1  method 1  label1     0    20         1
3     lt1  method 1  label1     0    30         1
4     lt1  method 1  label1     0    40         1
  ...       ...     ...   ...   ...       ...
1595  lt2  method 4  label2    90    50         0
1596  lt2  method 4  label2    90    60         0
1597  lt2  method 4  label2    90    70         0
1598  lt2  method 4  label2    90    80         0
1599  lt2  method 4  label2    90    90         0




# Cell colour for each nw_score value.
labels_fill = {0: 'red', 1: 'blue'}


def facet(data, color, **kwargs):
    """Draw one FacetGrid panel as an annotated dtsi x rtsi heatmap of nw_score.

    Intended for ``FacetGrid.map_dataframe``, which passes the panel's rows as
    ``data`` plus a ``color`` keyword. ``color`` and any extra keywords are
    accepted but ignored; cell colours come from ``labels_fill``.

    Parameters
    ----------
    data : pandas.DataFrame
        Rows for a single panel, with columns ``dtsi``, ``rtsi`` and
        ``nw_score``. Each (dtsi, rtsi) pair must appear at most once,
        otherwise ``pivot`` raises ``ValueError``.
    color : object
        Unused.

    Returns
    -------
    matplotlib.axes.Axes
        The axes the heatmap was drawn on.
    """
    grid = data.pivot(index="dtsi", columns='rtsi', values='nw_score')
    scores = sorted(labels_fill)
    return sns.heatmap(
        grid,
        cmap=ListedColormap([labels_fill[score] for score in scores]),
        # Pin the colour range to the score range. With auto-scaling, a panel
        # holding a single score value gets vmin == vmax, and every cell is
        # drawn in the first colour regardless of its actual score.
        vmin=scores[0],
        vmax=scores[-1],
        cbar=False,
        annot=True,
    )


# Original attempt from the question: draw one FacetGrid of heatmaps per ltt
# value and save each figure as a PNG. The NOTE(review) comments mark the
# problems visible in this snippet.
for l in data.ltt.unique():

#    print(l)

    # NOTE(review): seaborn appears to apply font_scale only when a named
    # context ("paper", "notebook", "talk", "poster") is passed; with the
    # default context this likely leaves font sizes unchanged.
    with sns.plotting_context(font_scale=5.5):
        # NOTE(review): `data` is not filtered to the current `l`, so each
        # facet receives rows for both ltt values. Every (dtsi, rtsi) pair then
        # appears twice, and the pivot in facet() cannot reshape it (the
        # "duplicate entries" ValueError reported in the question).
        # NOTE(review): "method+l" is a literal column name, not "method"
        # combined with `l`; the DataFrame built above has no such column.
        # NOTE(review): newer seaborn releases name the `size` argument `height`.
        g = sns.FacetGrid(data,row="labels", col="method+l", size=2, aspect=1,margin_titles=False)
        g = g.map_dataframe(facet)
        # NOTE(review): no hue is set, so this legend likely has no entries.
        g.add_legend()
       # g.set(xlabel='common xlabel', ylabel='common ylabel')
        #g.set_titles(col_template="{col_name}", fontweight='bold', fontsize=18)
        # Blank the automatic facet titles; they are replaced below.
        g.set_titles(template="")

        # Bold method names above the top row of facets.
        for ax,m in zip(g.axes[0,:],methods):
            ax.set_title(m, fontweight='bold', fontsize=12)
        # Bold label names to the left of the first column.
        # NOTE(review): this loop rebinds `l`, so afterwards `l` holds the last
        # label it visited, not the ltt value, when the file name is built.
        for ax,l in zip(g.axes[:,0],labels):
            ax.set_ylabel(l, fontweight='bold', fontsize=12, rotation=0, ha='right', va='center')

     #   g.fig.tight_layout() 

    save_results_to = 'D:/plots'

    # NOTE(review): `os` is not among the imports at the top of this snippet.
    if not os.path.exists(save_results_to):
        os.makedirs(save_results_to)


    # NOTE(review): no separator is inserted between directory and file name
    # (e.g. 'D:/plots' + 'label2' + '.png' -> 'D:/plotslabel2.png'), and
    # because `l` was rebound above, every iteration writes the same file.
    g.savefig(save_results_to + l+  '.png', dpi = 300)

当我运行上面的代码时,出现如下错误:

ValueError: Index contains duplicate entries, cannot reshape

所需的图形格式


共1个答案

匿名用户

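问题在于:您在两个ltt取值之间循环,却没有按当前取值筛选DataFrame。这样每个(labels, method)分面中同时包含lt1和lt2的数据,同一组(dtsi, rtsi)会出现两次,pivot无法把它重塑成唯一的网格,于是抛出上面的ValueError。应当在每次循环中只传入当前ltt对应的数据: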

for l in data.ltt.unique():
    g = sns.FacetGrid(data[data.ltt==l], ....)

此外,变量l存在冲突:它先被用作ltt取值的循环变量,随后又在设置行标签的循环中被重新赋值。请在代码中使用更具描述性的变量名称。

以下是完整的工作代码:

import pandas as pd
import numpy as np
import itertools
import seaborn as sns
from matplotlib.colors import ListedColormap

# Log which seaborn release is installed; useful when reproducing plot issues.
print(
    "seaborn version {}".format(sns.__version__)
)
# Python stand-in for R's expand.grid().
# https://stackoverflow.com/a/12131385/1135316
def expandgrid(*itrs):
    """Return every combination of the given iterables as a dict of columns.

    Combinations are generated by ``itertools.product``, so the *last*
    iterable varies fastest (R's ``expand.grid`` varies the first fastest).
    Column ``'Var<k>'`` (k counted from 1) holds the k-th element of each
    combination. With no iterables the result is ``{}``; if any iterable is
    empty, every column is an empty list.
    """
    combos = list(itertools.product(*itrs))
    columns = {}
    for position in range(len(itrs)):
        columns['Var{}'.format(position + 1)] = [combo[position] for combo in combos]
    return columns




# Toy data: one row per combination of ltt x method x label x dtsi x rtsi
# (2 * 4 * 2 * 10 * 10 = 1600 rows), each with a random 0/1 score.
ltt = ['lt1', 'lt2']

methods = ['method 1', 'method 2', 'method 3', 'method 4']
labels = ['label1', 'label2']
times = range(0, 100, 10)
data = pd.DataFrame(expandgrid(ltt, methods, labels, times, times))
data.columns = ['ltt', 'method', 'labels', 'dtsi', 'rtsi']
# For continuous scores in [0, 1) use np.random.sample(len(data)) instead.
data['nw_score'] = np.random.choice([0, 1], len(data))

# Cell colour for each nw_score value.
labels_fill = {0: 'red', 1: 'blue'}


def facet(data, color, **kwargs):
    """Draw one FacetGrid panel as an annotated dtsi x rtsi heatmap of nw_score.

    Intended for ``FacetGrid.map_dataframe``, which passes the panel's rows as
    ``data`` plus a ``color`` keyword. ``color`` and any extra keywords are
    accepted but ignored; cell colours come from ``labels_fill``.

    Parameters
    ----------
    data : pandas.DataFrame
        Rows for a single panel, with columns ``dtsi``, ``rtsi`` and
        ``nw_score``. Each (dtsi, rtsi) pair must appear at most once,
        otherwise ``pivot`` raises ``ValueError``.
    color : object
        Unused.

    Returns
    -------
    matplotlib.axes.Axes
        The axes the heatmap was drawn on.
    """
    grid = data.pivot(index="dtsi", columns='rtsi', values='nw_score')
    scores = sorted(labels_fill)
    return sns.heatmap(
        grid,
        cmap=ListedColormap([labels_fill[score] for score in scores]),
        # Pin the colour range to the score range. With auto-scaling, a panel
        # holding a single score value gets vmin == vmax, and every cell is
        # drawn in the first colour regardless of its actual score.
        vmin=scores[0],
        vmax=scores[-1],
        cbar=False,
        annot=True,
    )


# One figure per ltt value: a grid of heatmaps (rows = labels,
# columns = methods), saved as <save_results_to>/<ltt>.png.
import os

# Output directory requested in the question; change as needed. It is created
# up front so savefig() does not fail on a missing folder.
save_results_to = 'D:/plots'
os.makedirs(save_results_to, exist_ok=True)

for lt in data.ltt.unique():
    # Pass only the current ltt value's rows: each (labels, method) facet
    # then has exactly one row per (dtsi, rtsi), which facet()'s pivot needs.
    # `height` is seaborn's current name for the old `size` argument.
    # No plotting_context(font_scale=...) wrapper: without a named context
    # seaborn ignores font_scale, and the titles/labels below set their own
    # fontsize. No add_legend(): there is no hue variable to list.
    g = sns.FacetGrid(data[data.ltt == lt], row="labels", col="method",
                      height=2, aspect=1, margin_titles=False)
    g = g.map_dataframe(facet)
    g.set_titles(template="")

    # Replace the blanked facet titles with bold method names (top row)
    # and label names (left column).
    for ax, method in zip(g.axes[0, :], methods):
        ax.set_title(method, fontweight='bold', fontsize=12)
    for ax, label in zip(g.axes[:, 0], labels):
        ax.set_ylabel(label, fontweight='bold', fontsize=12, rotation=0,
                      ha='right', va='center')

    g.fig.suptitle(lt, fontweight='bold', fontsize=12)
    g.fig.tight_layout()
    g.fig.subplots_adjust(top=0.8)  # make some room for the title

    # os.path.join supplies the separator between directory and file name.
    g.savefig(os.path.join(save_results_to, lt + '.png'), dpi=300)

lt1.png

lt2.png