Tutorial de Deep Learning amb Keras per carregar un model entrenat amb la base de dades MNIST.
el codi complet d’aquest post està disponible en llibreta d’Jupyter
Introducció
en l’anterior tutorial hem entrenat una xarxa neuronal amb Keras per classificar la base de dades MNIST .
En aquest tutorial aprendrem com carregar un model i realitzar classificacions amb el sense haver de realitzar tot el procés d’entrenament.
Importació de llibreries
from keras.models import load_modelimport matplotlib.pyplot as pltimport numpy as npimport tensorflow as tfimport mathfrom random import sample
Descàrrega de la base de dades MNIST
Tot i que el model aquest entrenat ja, no emmagatzema el conjunt de dades de test o entrenament, per tant, cal descarregar-los de nou.
# Para descargar la base de datos MNISTimport tensorflow as tfold_v = tf.logging.get_verbosity()tf.logging.set_verbosity(tf.logging.ERROR)from tensorflow.examples.tutorials.mnist import input_datamnist = input_data.read_data_sets("data/MNIST/", one_hot=True)
extracting data / MNIST / train -images-idx3-ubyte.gzExtracting data / MNIST / train-la bels-idx1-ubyte.gzExtracting data / MNIST / t10k-images-idx3-ubyte.gzExtracting data / MNIST / t10k-labels-idx1-ubyte.gz
Guardar les dades de MNIST en diferents variables
# Guardar las imagenes y etiquetas de entrenoimages_train = mnist.train.imageslabels_train = mnist.train.labels# Guardar las imagenes y etiquetas de testimages_test = mnist.test.imageslabels_test = mnist.test.labelsimg_shape = (28, 28)
Importació de model entrenat
el model entrenat en l’anterior tutorial ha estat nomenat modelo.keras. Cal incloure la ruta completa a el model per ser carregat correctament.
El model conté emmagatzemat l’estructura de la xarxa neuronal, el conjunt de matrius de pesos entrenades. A continuació s’inclou la ruta a el model i es carrega.
Si no has entrenat la xarxa neuronal de l’tutorial anterior i guardat el model no és possible carregar en aquest tutorial.
path_model = 'modelo.h5'new_model = load_model(path_model)
Visualització de l’estructura de el model
Un cop carregat el model, no cal definir de nou tota l’estructura de capes, ja està preparat per a realitzar classificacions.
A continuació es mostra el sumari de el model:
new_model.summary()
Predicció amb el model carregat
y_pred2 = new_model.predict(x=images_test)# Obtencion de las etiquetas predichascls_pred2 = np.argmax(y_pred2,axis=1)
Comprovació de la precisió de l’ model de manera manual
# Obtencion de las etiquetas verdaderastrue_labels2 = np.argmax(labels_test,axis=1)# Obtencion vector booleano para ver que posiciones coincidenpossitions = cls_pred2 == true_labels2# Numero de prediciones correctas dividido entre todas las predicionesprecision = sum(possitions)/len(true_labels2)print("Precision : {0}%".format(precision))
Precision: 0,9771%
Funció auxiliar per plotejar imatges
la següent funció serveix per plotejar setembre exemples de la base de dades MNIST, i indicar que nombre es tracta. En el cas que se li passi el que ha predit la xarxa ploteará les imatges predites correctament amb un marc verd i les mal predites amb un rectangle vermell.
def plot_imagenes(imagenes, verdaderas, predichas=None): # Seleccionar 9 indices aleatorios para elegir las imagenes ind = sample(range(len(imagenes)),9) # Tomar las imagenes img = imagenes color = 'green' # Tomar las etiquetas verdaderas y predichas si las hay if predichas is None: etiq = verdaderas else: etiq = verdaderas pred = predichas # Crear la figura con 3x3 sub-plots fig, axes = plt.subplots(3, 3) fig.subplots_adjust(hspace=0.3, wspace=0.3) for i, ax in enumerate(axes.flat): # Plotear imagen. ax.imshow(img.reshape(img_shape), cmap='binary') # Mostrar los numeros verdaderos y predichos if predichas is None: xlabel = "Numero: {0}".format(etiq) else: xlabel = "Numero: {0}, Predicho: {1}".format(etiq, pred) if etiq != pred: color = 'red' ax.spines.set_color(color) ax.spines.set_color(color) ax.spines.set_color(color) ax.spines.set_color(color) color = 'green' # Mostrar los numeros en el eje x ax.set_xlabel(xlabel) # Borrar los ticks del plot ax.set_xticks() ax.set_yticks() plt.show()
Mostra alguns exemples predits
plot_imagenes(imagenes=images_test, verdaderas=true_labels2, predichas=cls_pred2)
by Pedro Fernando Rodenas Perez / GitHub