Tutorial de aprendizaxe profunda con KERAS para cargar un modelo adestrado coa base de datos Mnist.
O código completo desta publicación está dispoñible no libro de Jupyter
Introdución
No tutorial anterior, adestramos unha rede neuronal con KERAS para clasificar a base de datos de Mnist.
en Este tutorial imos aprender a cargar un modelo e realizar clasificacións sen ter que realizar todo o proceso de formación.
Importación de bibliotecas
from keras.models import load_modelimport matplotlib.pyplot as pltimport numpy as npimport tensorflow as tfimport mathfrom random import sample
Descargar a base de datos MNIST
Aínda que o modelo xa está adestrado, non almacena o conxunto de datos de proba ou a formación, polo tanto, é necesario descargar de novo.
# Para descargar la base de datos MNISTimport tensorflow as tfold_v = tf.logging.get_verbosity()tf.logging.set_verbosity(tf.logging.ERROR)from tensorflow.examples.tutorials.mnist import input_datamnist = input_data.read_data_sets("data/MNIST/", one_hot=True)
Extracción de datos / mnist / tren -Images-idx3-ubyte.gzextraction Data / Mnist / Train-LA Bels-IDX1-Ubyte.gzextracting Data / Mnist / T10k-images-IDX3-UBYTE.gzextracting Data / Mnnt / T10K-Etiquetas-IDX1-UBYTE.GZ
Garda os datos MNNT en diferentes variables
importación de modelo adestrado
O modelo adestrado no tutorial anterior ten el foi nomeado model.keras. É necesario incluír a ruta completa para que o modelo se cargue correctamente.
O modelo contén a estrutura da rede neuronal, o conxunto de matrices de peso adestrado. A continuación, o camiño está incluído para o modelo e cargado.
Se non adestrou a rede neuronal do tutorial anterior e gardou o modelo, non é posible cargalo neste tutorial.
path_model = 'modelo.h5'new_model = load_model(path_model)
Mostrar a estrutura do modelo
Unha vez que o modelo está cargado, é Non é necesario definir que toda a estrutura da capa xa está listo para realizar clasificacións.
O resumo do modelo móstrase a continuación:
new_model.summary()
Predición co modelo cargado
y_pred2 = new_model.predict(x=images_test)# Obtencion de las etiquetas predichascls_pred2 = np.argmax(y_pred2,axis=1)
Comprobación de precisión do modelo manual
# Obtencion de las etiquetas verdaderastrue_labels2 = np.argmax(labels_test,axis=1)# Obtencion vector booleano para ver que posiciones coincidenpossitions = cls_pred2 == true_labels2# Numero de prediciones correctas dividido entre todas las predicionesprecision = sum(possitions)/len(true_labels2)print("Precision : {0}%".format(precision))
Precisión: 0.9771%
función auxiliar Para trazar imaxes
A seguinte función úsase para trazar 9 exemplos da base de datos de Mnist e indicar o que é tratado un número. No caso de que teña previsto que a rede trazará as imaxes previstas correctamente cun marco verde e os malos predicae cun rectángulo vermello.
def plot_imagenes(imagenes, verdaderas, predichas=None): # Seleccionar 9 indices aleatorios para elegir las imagenes ind = sample(range(len(imagenes)),9) # Tomar las imagenes img = imagenes color = 'green' # Tomar las etiquetas verdaderas y predichas si las hay if predichas is None: etiq = verdaderas else: etiq = verdaderas pred = predichas # Crear la figura con 3x3 sub-plots fig, axes = plt.subplots(3, 3) fig.subplots_adjust(hspace=0.3, wspace=0.3) for i, ax in enumerate(axes.flat): # Plotear imagen. ax.imshow(img.reshape(img_shape), cmap='binary') # Mostrar los numeros verdaderos y predichos if predichas is None: xlabel = "Numero: {0}".format(etiq) else: xlabel = "Numero: {0}, Predicho: {1}".format(etiq, pred) if etiq != pred: color = 'red' ax.spines.set_color(color) ax.spines.set_color(color) ax.spines.set_color(color) ax.spines.set_color(color) color = 'green' # Mostrar los numeros en el eje x ax.set_xlabel(xlabel) # Borrar los ticks del plot ax.set_xticks() ax.set_yticks() plt.show()
Mostrar algúns exemplos previstos
plot_imagenes(imagenes=images_test, verdaderas=true_labels2, predichas=cls_pred2)
por contorna Pedro Fernando Pérez / Github