Tutorial: Run a TensorFlow model in Python

This tutorial shows you how to use an exported TensorFlow model locally to classify images.


This tutorial applies only to models exported from "General (compact)" image classification projects. If you exported a different kind of model, visit our sample code repository.


  • Install Python 2.7+ or Python 3.6+.
  • Install pip.

Next, install the following packages:

pip install tensorflow
pip install pillow
pip install numpy
pip install opencv-python
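
Two of these packages are imported under a different name than the one passed to pip: pillow is imported as PIL and opencv-python as cv2. As a quick sanity check before continuing, a minimal sketch like the following reports which modules cannot be found in the current environment. The helper name missing_packages is hypothetical and not part of the exported sample.

```python
import importlib.util

def missing_packages(module_names):
    """Return the module names that cannot be found in this environment."""
    return [name for name in module_names if importlib.util.find_spec(name) is None]

# Import names for tensorflow, pillow, numpy and opencv-python, respectively.
required = ["tensorflow", "PIL", "numpy", "cv2"]
missing = missing_packages(required)
if missing:
    print("Missing modules: " + ", ".join(missing))
else:
    print("All required modules are installed.")
```

The check only looks the modules up and does not import them, so it runs quickly even when TensorFlow is installed.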

Load your model and tags

The .zip file downloaded in the export step contains a model.pb file and a labels.txt file. These files represent the trained model and the classification labels. The first step is to load the model into your project. Add the following code to a new Python script.

import tensorflow as tf
import os

graph_def = tf.compat.v1.GraphDef()
labels = []

# These are set to the default names from exported models, update as needed.
filename = "model.pb"
labels_filename = "labels.txt"

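# Import the TF graph
with tf.io.gfile.GFile(filename, 'rb') as f:
    # Parse the serialized graph before importing it.
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')

# Create a list of labels.
with open(labels_filename, 'rt') as lf:
    for l in lf:
        labels.append(l.strip())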
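
A labels file may end with a blank line, which would add an empty entry to the list. A minimal sketch of a more defensive loader follows. The load_labels helper is hypothetical; it assumes one label per line, UTF-8 encoding, and that no label is itself empty.

```python
def load_labels(path):
    """Read one label per line, ignoring blank lines and surrounding whitespace."""
    with open(path, "rt", encoding="utf-8") as lf:
        return [line.strip() for line in lf if line.strip()]
```

Calling load_labels(labels_filename) would then take the place of the label-reading loop in the script.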

Prepare an image for prediction

You need to take a few steps to prepare an image for prediction. These steps mimic the image manipulation performed during training.

  1. Open the file and create an image in the BGR color space

    from PIL import Image
    import numpy as np
    import cv2
    # Load from a file
    imageFile = "<path to your image file>"
    image = Image.open(imageFile)
    # Update orientation based on EXIF tags, if the file has orientation info.
    image = update_orientation(image)
    # Convert to OpenCV format
    image = convert_to_opencv(image)
  2. If the image has a dimension greater than 1600 pixels, call this method (defined later).

    image = resize_down_to_1600_max_dim(image)
  3. Crop the largest center square

    h, w = image.shape[:2]
    min_dim = min(w,h)
    max_square_image = crop_center(image, min_dim, min_dim)
  4. Resize that square down to 256x256

    augmented_image = resize_to_256_square(max_square_image)
  5. Crop the center to the input size the model expects

    # Get the input size of the model
    with tf.compat.v1.Session() as sess:
        input_tensor_shape = sess.graph.get_tensor_by_name('Placeholder:0').shape.as_list()
    network_input_size = input_tensor_shape[1]
    # Crop the center for the specified network_input_Size
    augmented_image = crop_center(augmented_image, network_input_size, network_input_size)
  6. Define helper functions. The steps above use the following helper functions:

    def convert_to_opencv(image):
        # RGB -> BGR conversion is performed as well.
        image = image.convert('RGB')
        r,g,b = np.array(image).T
        opencv_image = np.array([b,g,r]).transpose()
        return opencv_image
    def crop_center(img,cropx,cropy):
        h, w = img.shape[:2]
        startx = w//2-(cropx//2)
        starty = h//2-(cropy//2)
        return img[starty:starty+cropy, startx:startx+cropx]
    def resize_down_to_1600_max_dim(image):
        h, w = image.shape[:2]
        if (h < 1600 and w < 1600):
            return image
        new_size = (1600 * w // h, 1600) if (h > w) else (1600, 1600 * h // w)
        return cv2.resize(image, new_size, interpolation = cv2.INTER_LINEAR)
    def resize_to_256_square(image):
        h, w = image.shape[:2]
        return cv2.resize(image, (256, 256), interpolation = cv2.INTER_LINEAR)
    def update_orientation(image):
        exif_orientation_tag = 0x0112
        if hasattr(image, '_getexif'):
            exif = image._getexif()
            if (exif != None and exif_orientation_tag in exif):
                orientation = exif.get(exif_orientation_tag, 1)
                # orientation is 1 based, shift to zero based and flip/transpose based on 0-based values
                orientation -= 1
                if orientation >= 4:
                    image = image.transpose(Image.TRANSPOSE)
                if orientation == 2 or orientation == 3 or orientation == 6 or orientation == 7:
                    image = image.transpose(Image.FLIP_TOP_BOTTOM)
                if orientation == 1 or orientation == 2 or orientation == 5 or orientation == 6:
                    image = image.transpose(Image.FLIP_LEFT_RIGHT)
        return image
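
To see how the steps above change the image dimensions, the following sketch traces only the width and height through the pipeline, without touching pixel data. It mirrors the arithmetic of resize_down_to_1600_max_dim and crop_center. The trace_sizes helper is hypothetical, and the default network input size of 224 is an assumption for illustration; the real value is read from the model's input tensor.

```python
def trace_sizes(w, h, network_input_size=224, max_dim=1600):
    """Return the (width, height) of the image after each preparation step."""
    sizes = [("original", (w, h))]
    # Step 2: shrink so the longer side equals max_dim, keeping the aspect ratio.
    if not (h < max_dim and w < max_dim):
        if h > w:
            w, h = max_dim * w // h, max_dim
        else:
            w, h = max_dim, max_dim * h // w
    sizes.append(("resized", (w, h)))
    # Step 3: the largest center square has the shorter side as its edge.
    side = min(w, h)
    sizes.append(("center square", (side, side)))
    # Step 4: the square is always resized to 256x256.
    sizes.append(("256 square", (256, 256)))
    # Step 5: the final center crop matches the model's input size.
    sizes.append(("network input", (network_input_size, network_input_size)))
    return sizes

for step, size in trace_sizes(4000, 3000):
    print(step, size)
```

Under these assumptions, a 4000x3000 image is reduced to 1600x1200, cropped to 1200x1200, resized to 256x256, and finally cropped to 224x224.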

Classify an image

Once the image is prepared as a tensor, we can send it through the model for a prediction.

# These names are part of the model and cannot be changed.
output_layer = 'loss:0'
input_node = 'Placeholder:0'

with tf.compat.v1.Session() as sess:
    try:
        prob_tensor = sess.graph.get_tensor_by_name(output_layer)
        # The image is wrapped in a list to form a batch of one.
        predictions = sess.run(prob_tensor, {input_node: [augmented_image] })
    except KeyError:
        print ("Couldn't find classification output layer: " + output_layer + ".")
        print ("Verify this is a model exported from an image classification project.")
        exit(-1)
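
A mismatch between the prepared image and the shape of the input placeholder can cause sess.run to fail. The sketch below compares the two before the model is run. The helper matches_input_shape is hypothetical, and the example shape [None, 224, 224, 3] is an assumption; the actual shape is the input_tensor_shape read from the model earlier.

```python
def matches_input_shape(image_shape, input_tensor_shape):
    """Check that one image fits a placeholder shaped [batch, height, width, channels].

    Dimensions reported as None by the model accept any value.
    """
    expected = list(input_tensor_shape)[1:]  # drop the batch dimension
    if len(image_shape) != len(expected):
        return False
    return all(e is None or e == actual for e, actual in zip(expected, image_shape))

# Assumed shapes for illustration only.
print(matches_input_shape((224, 224, 3), [None, 224, 224, 3]))
print(matches_input_shape((256, 256, 3), [None, 224, 224, 3]))
```

In the script, the same check could be made with matches_input_shape(augmented_image.shape, input_tensor_shape).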

Display the results

The results of running the image tensor through the model then need to be mapped back to the labels.

# Print the highest probability label
highest_probability_index = np.argmax(predictions)
print('Classified as: ' + labels[highest_probability_index])

# Or you can print out all of the results mapping labels to probabilities.
# predictions has a leading batch dimension, so iterate over its first row.
label_index = 0
for p in predictions[0]:
    truncated_probability = np.float64(np.round(p,8))
    print (labels[label_index], truncated_probability)
    label_index += 1
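
When a model has many labels, it is often more useful to show only the most likely ones. The following sketch sorts label and probability pairs and keeps the top few. The top_k helper and the example values are hypothetical; with real output, pass the first row of predictions, since the result has a leading batch dimension.

```python
def top_k(labels, probabilities, k=3):
    """Return the k (label, probability) pairs with the highest probability."""
    pairs = zip(labels, probabilities)
    return sorted(pairs, key=lambda pair: pair[1], reverse=True)[:k]

# Hypothetical labels and scores for illustration only.
example_labels = ["cat", "dog", "bird"]
example_probabilities = [0.1, 0.7, 0.2]
for label, probability in top_k(example_labels, example_probabilities, k=2):
    print(label, probability)
```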

Next steps

