批量显示图像（Fashion-MNIST 分类数据集）

def show_imags(n_rows, n_cols, x_data, y_data, class_names):
    """Display the first ``n_rows * n_cols`` images of a dataset in a grid.

    Images are laid out row by row: grid cell ``(row, col)`` shows
    ``x_data[n_cols * row + col]``, rendered with the ``"binary"`` colormap,
    with its axes hidden and its title set to ``class_names[y_data[...]]``.
    The figure is shown with ``plt.show()`` when the grid is complete.

    Args:
        n_rows: Number of grid rows.
        n_cols: Number of grid columns.
        x_data: Indexable collection of images accepted by ``plt.imshow``.
        y_data: Labels aligned with ``x_data``; each label is used as an
            index into ``class_names``.
        class_names: Display names, indexed by the values in ``y_data``.

    Raises:
        ValueError: If ``x_data`` and ``y_data`` differ in length, or the
            grid has more cells than there are images.
    """
    if len(x_data) != len(y_data):
        raise ValueError(
            f"x_data and y_data must have the same length "
            f"({len(x_data)} != {len(y_data)})"
        )
    n_images = n_rows * n_cols
    # A grid that exactly fills the dataset is valid: the largest index
    # used below is n_images - 1.
    if n_images > len(x_data):
        raise ValueError(
            f"grid of {n_rows}x{n_cols} needs {n_images} images, "
            f"but only {len(x_data)} were given"
        )
    # Scale the figure with the grid so each thumbnail keeps a fixed size.
    plt.figure(figsize=(n_cols * 1.4, n_rows * 1.6))
    for index in range(n_images):
        plt.subplot(n_rows, n_cols, index + 1)  # subplot positions are 1-based
        plt.imshow(x_data[index], cmap="binary", interpolation="nearest")
        plt.axis("off")  # hide axis ticks and labels
        plt.title(class_names[y_data[index]])
    plt.show()


class_names = ['T-shirt', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt',
               'Sneaker', 'Bag', 'Ankle boot']
show_imags(3, 5, x_train, y_train, class_names)
相关文章
相关标签/搜索