【Tensorflow/u-net】Tensorflow 2.X系を使ってu-netを実装する

田中太郎
田中太郎

Tensorflowの チュートリアルを参考に u-net の実装を行いました!

概要

Deep-Learning の勉強でTensorflow u-netを実装しました。TenorFlow の公式サイトにMNINSTを使ったチュートリアルがあります。そのコードを基にu-netを実装しました。

環境

・Tensorflow 2.0
・Python 3.7
・CentOS
・入力画像サイズ 128×128
・出力画像サイズ 128×128

入力

訓練用の入力画像(teach_dir)と正解画像(label_dir)、テスト用の入力画像(test_dir)と正解画像(accurate_dir)の4種類の画像を格納するディレクトリを作成します。

ファイル構造を以下になります。

・ー実行ファイル1 <u_net.py>
・ー実行ファイル2 <image2tensor.py>

・ーー訓練用の入力画像(teach_dir)
| |ー画像ファイルが n 個

・ーー訓練用の正解画像(label_dir)
| |ー画像ファイルが n 個

・ーーテスト用の入力画像(test_dir)
| |ー画像ファイルが m 個

・ーーテスト用の正解画像(accurate_dir)
  |ー画像ファイルが m 個

訓練用、テスト用の入力画像の一例は以下になります。

図1:入力画像

訓練用、テスト用の正解画像の一例は以下になります。

図2:正解画像

出力

TensorFlow のチュートリアルにのっとり、Epoch数、損失関数、正解率が出力されます。また、訓練したu-net の出力画像が保存されます。

図3:出力(コンソールの画面)

コード

スクリプトファイルは <u_net.py> と <image2tensor.py> の2つあります。

コード

<u_net.py>

import tensorflow as tf
from tensorflow.keras.layers import MaxPool2D, Conv2D, UpSampling2D, Concatenate
from tensorflow.keras import Model

import numpy as np
import image2tensor


# Read images
train_ds = image2tensor.read_dir_and_convert_tf_data_dataset(
    "teach_dir", "label_dir"
)    
test_ds = image2tensor.read_dir_and_convert_tf_data_dataset(
    "test_dir", "accurate_dir"
)    

# Add shuffle buffer_size to tf.data.Dataset object,
# and batch size.
train_ds = train_ds.shuffle(32).batch(8)
test_ds = test_ds.batch(8)

# Create model using model subclassing API in Keras.
# Model is constructure of NN.
class MyModel(Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1_1 = Conv2D(4, 3, padding='same', activation='relu')
        self.conv1_2 = Conv2D(4, 3, padding='same', activation='relu')
        self.conv1_3 = Conv2D(4, 3, padding='same', activation='relu')
        self.conv1_4 = Conv2D(4, 3, padding='same', activation='relu')
        self.conv2_1 = Conv2D(8, 3, padding='same', activation='relu')
        self.conv2_2 = Conv2D(8, 3, padding='same', activation='relu')
        self.conv2_3 = Conv2D(8, 3, padding='same', activation='relu')
        self.conv2_4 = Conv2D(8, 3, padding='same', activation='relu')
        self.conv3_1 = Conv2D(16, 3, padding='same', activation='relu')
        self.conv3_2 = Conv2D(16, 3, padding='same', activation='relu')
        self.conv3_3 = Conv2D(16, 3, padding='same', activation='relu')
        self.conv3_4 = Conv2D(16, 3, padding='same', activation='relu')
        self.conv4_1 = Conv2D(32, 3, padding='same', activation='relu')
        self.conv4_2 = Conv2D(32, 3, padding='same', activation='relu')
        self.conv4_3 = Conv2D(32, 3, padding='same', activation='relu')
        self.conv4_4 = Conv2D(32, 3, padding='same', activation='relu')
        self.conv5_1 = Conv2D(64, 3, padding='same', activation='relu')
        self.conv5_2 = Conv2D(64, 3, padding='same', activation='relu')
        self.conv_o = Conv2D(1, 1, padding="same", activation='sigmoid')
        self.cat_1 = Concatenate()
        self.cat_2 = Concatenate()
        self.cat_3 = Concatenate()
        self.cat_4 = Concatenate()
        self.maxpool_1 = MaxPool2D((2, 2))
        self.maxpool_2 = MaxPool2D((2, 2))
        self.maxpool_3 = MaxPool2D((2, 2))
        self.maxpool_4 = MaxPool2D((2, 2))
        self.upsampling_1 = UpSampling2D((2, 2))
        self.upsampling_2 = UpSampling2D((2, 2))
        self.upsampling_3 = UpSampling2D((2, 2))
        self.upsampling_4 = UpSampling2D((2, 2))

    def call(self, x):
        x = self.conv1_1(x)
        c1 = self.conv1_2(x)
        x = self.maxpool_1(c1)

        x = self.conv2_1(x)
        c2 = self.conv2_2(x)
        x = self.maxpool_2(c2)

        x = self.conv3_1(x)
        c3 = self.conv3_2(x)
        x = self.maxpool_3(c3)

        x = self.conv4_1(x)
        c4 = self.conv4_2(x)
        x = self.maxpool_4(c4)

        x = self.conv5_1(x)
        x = self.conv5_2(x)

        x = self.upsampling_1(x)
        x = self.cat_1([x, c4])
        x = self.conv4_3(x)
        x = self.conv4_4(x)

        x = self.upsampling_2(x)
        x = self.cat_2([x, c3])
        x = self.conv3_3(x)
        x = self.conv3_4(x)

        x = self.upsampling_3(x)
        x = self.cat_3([x, c2])
        x = self.conv2_3(x)
        x = self.conv2_4(x)

        x = self.upsampling_4(x)
        x = self.cat_4([x, c1])
        x = self.conv1_3(x)
        x = self.conv1_4(x)

        return self.conv_o(x)

model = MyModel()

# Select loss function
loss_object = tf.keras.losses.MeanSquaredError()

# Select optimizer
optimizer = tf.keras.optimizers.Adam()

# Select loss and accurate
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.MeanAbsoluteError(name="train_accuracy")

# Select loss and accurate
test_loss = tf.keras.metrics.Mean(name='test_loss')
test_accuracy = tf.keras.metrics.MeanAbsoluteError(name="test_accuracy")

# @tf.function
def train_step(images, labels):
    # Recode formula by tape
    with tf.GradientTape() as tape:
        predictions = model(images)
        loss = loss_object(labels, predictions)
    # Calculate gradient
    # model.trainable_variables give list of weight
    gradients = tape.gradient(loss, model.trainable_variables)
    
    # Update weight value
    optimizer.apply_gradients(
        zip(gradients, model.trainable_variables))

    train_loss(loss)
    _ = train_accuracy.update_state(labels, predictions)


# @tf.function
def test_step(images, labels, epoch):
    predictions = model(images)
    t_loss = loss_object(labels, predictions)

    test_loss(t_loss)
    _ = test_accuracy.update_state(labels, predictions)

    # Save predictions image.
    count = 0
    for img in predictions:
        img = tf.cast(255*img, tf.uint8)
        filename = "{0}_{1}.png".format(epoch, count)
        tf.keras.preprocessing.image.save_img(
            filename, img)
        count += 1


# Set EPOCHS number.
EPOCHS = 100

for epoch in range(EPOCHS):
    for images, labels in train_ds:
        train_step(images, labels)

    for test_images, test_labels in test_ds:
        test_step(test_images, test_labels, epoch)

    template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'
    print (
        template.format(
            epoch+1,
            train_loss.result(),
            100 - train_accuracy.result().numpy()*100,
            test_loss.result(),
            100 - test_accuracy.result().numpy()*100
        )
    )

    train_loss.reset_states()
    train_accuracy.reset_states()
    test_loss.reset_states()
    test_accuracy.reset_states()

<image2tensor.py>

import os
import sys

import cv2
import numpy as np
import tensorflow as tf


def detect_images_pathes(images_dir):
    image_names = os.listdir(images_dir)
    images_pathes = [
        os.path.join(images_dir, i) for i in image_names]
    return images_pathes
    

def image_2_tf_data_dataset(teach_imgs, label_imgs):
    # Initaialize ndarray of imgs and labels
    imgs = np.ndarray((
        len(teach_imgs),
        teach_imgs[0].shape[0],
        teach_imgs[0].shape[1],
    ))
    labels = np.ndarray((
        len(label_imgs),
        label_imgs[0].shape[0],
        label_imgs[0].shape[1],
    ))

    for i in range(len(teach_imgs)):
        imgs[i] = imgs[i] + teach_imgs[i]/255.0
        labels[i] = labels[i] + label_imgs[i]/255.0

    imgs = imgs[..., tf.newaxis]
    labels = labels[..., tf.newaxis]
    return tf.data.Dataset.from_tensor_slices((imgs, labels))


def read_image(pathes):
    imgs = []
    for path in pathes:
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        imgs.append(img)
    return imgs


def read_dir_and_convert_tf_data_dataset(teach_dir, label_dir):
    # Detect file pathes
    teach_pathes = detect_images_pathes(teach_dir)
    label_pathes = detect_images_pathes(label_dir)

    # Error check
    if len(teach_pathes) != len(label_pathes):
        sys.exit(
"ERROR:input and label data num are not equal. {0} file is {1}, {2} file are {3}".format(teach_dir, len(teach_pathes), label_dir, len(label_pathes)))

    # Read images. Argument is file pathes
    teach_imgs = read_image(teach_pathes)
    label_imgs = read_image(label_pathes)

    # Convert image of numpy to tf.data.dataset
    ds = image_2_tf_data_dataset(
        teach_imgs=teach_imgs,
        label_imgs=label_imgs,
    )

    return ds
コードの解説

<u_net.py> の頭から説明していきます。

import tensorflow as tf  # 
from tensorflow.keras.layers import MaxPool2D, Conv2D, UpSampling2D, Concatenate
from tensorflow.keras import Model

import numpy as np
import image2tensor

必要なモジュールをインポートします。

import image2tensor

は自作モジュール<image2tensor.py> をインポートしています。

# Read images
train_ds = image2tensor.read_dir_and_convert_tf_data_dataset(
    "teach_dir", "label_dir"
)    
test_ds = image2tensor.read_dir_and_convert_tf_data_dataset(
    "test_dir", "accurate_dir"
)    

# Add shuffle buffer_size to tf.data.Dataset object,
# and batch size.
train_ds = train_ds.shuffle(32).batch(8)
test_ds = test_ds.batch(8)

画像を読み込んで、TensorFlow で処理する Tensor 型に変換しています。

image2tensor.read_dir_and_convert_tf_data_dataset

は自作モジュールの関数です。画像ファイルが格納されているディレクトリを引数にすると、ディレクトリ内のファイルをすべて読み込んで、TensorFlowのDataset型で返します。TesorFlowのDataset型は詳しくは説明しませんが、入力画像と正解画像をセットで扱えるものです。

注)引数に、入力データと正解データのディレクトリを指定しますが、画像ファイルの数は合わせてください。
ファイル名はなんでも良いですが、入力データと正解データでそろえてください。

train_ds = train_ds.shuffle(32).batch(8)
test_ds = test_ds.batch(8)

.shuffle を設定すると、読み込んだ画像を使うときに、ランダムに返してくれます。
.batch は、例えば8の場合、8セット画像を入力した後に重みを更新させます。バッチ学習と呼ばれるやつです。

# Create model using model subclassing API in Keras.
# Model is constructure of NN.
class MyModel(Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1_1 = Conv2D(4, 3, padding='same', activation='relu')
        self.conv1_2 = Conv2D(4, 3, padding='same', activation='relu')
        self.conv1_3 = Conv2D(4, 3, padding='same', activation='relu')
        self.conv1_4 = Conv2D(4, 3, padding='same', activation='relu')
        self.conv2_1 = Conv2D(8, 3, padding='same', activation='relu')
        self.conv2_2 = Conv2D(8, 3, padding='same', activation='relu')
        self.conv2_3 = Conv2D(8, 3, padding='same', activation='relu')
        self.conv2_4 = Conv2D(8, 3, padding='same', activation='relu')
        self.conv3_1 = Conv2D(16, 3, padding='same', activation='relu')
        self.conv3_2 = Conv2D(16, 3, padding='same', activation='relu')
        self.conv3_3 = Conv2D(16, 3, padding='same', activation='relu')
        self.conv3_4 = Conv2D(16, 3, padding='same', activation='relu')
        self.conv4_1 = Conv2D(32, 3, padding='same', activation='relu')
        self.conv4_2 = Conv2D(32, 3, padding='same', activation='relu')
        self.conv4_3 = Conv2D(32, 3, padding='same', activation='relu')
        self.conv4_4 = Conv2D(32, 3, padding='same', activation='relu')
        self.conv5_1 = Conv2D(64, 3, padding='same', activation='relu')
        self.conv5_2 = Conv2D(64, 3, padding='same', activation='relu')
        self.conv_o = Conv2D(1, 1, padding="same", activation='sigmoid')
        self.cat_1 = Concatenate()
        self.cat_2 = Concatenate()
        self.cat_3 = Concatenate()
        self.cat_4 = Concatenate()
        self.maxpool_1 = MaxPool2D((2, 2))
        self.maxpool_2 = MaxPool2D((2, 2))
        self.maxpool_3 = MaxPool2D((2, 2))
        self.maxpool_4 = MaxPool2D((2, 2))
        self.upsampling_1 = UpSampling2D((2, 2))
        self.upsampling_2 = UpSampling2D((2, 2))
        self.upsampling_3 = UpSampling2D((2, 2))
        self.upsampling_4 = UpSampling2D((2, 2))

    def call(self, x):
        x = self.conv1_1(x)
        c1 = self.conv1_2(x)
        x = self.maxpool_1(c1)

        x = self.conv2_1(x)
        c2 = self.conv2_2(x)
        x = self.maxpool_2(c2)

        x = self.conv3_1(x)
        c3 = self.conv3_2(x)
        x = self.maxpool_3(c3)

        x = self.conv4_1(x)
        c4 = self.conv4_2(x)
        x = self.maxpool_4(c4)

        x = self.conv5_1(x)
        x = self.conv5_2(x)

        x = self.upsampling_1(x)
        x = self.cat_1([x, c4])
        x = self.conv4_3(x)
        x = self.conv4_4(x)

        x = self.upsampling_2(x)
        x = self.cat_2([x, c3])
        x = self.conv3_3(x)
        x = self.conv3_4(x)

        x = self.upsampling_3(x)
        x = self.cat_3([x, c2])
        x = self.conv2_3(x)
        x = self.conv2_4(x)

        x = self.upsampling_4(x)
        x = self.cat_4([x, c1])
        x = self.conv1_3(x)
        x = self.conv1_4(x)

        return self.conv_o(x)

model = MyModel()

u-net のニューラルネットワークをクラスとして作成しています。

model = MyModel()

上記でモデルを構築し、model(image) のように引数を与えることで、imageがネットワークを伝播していきます。

# Select loss function
loss_object = tf.keras.losses.MeanSquaredError()

# Select optimizer
optimizer = tf.keras.optimizers.Adam()

# Select loss and accurate
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.MeanAbsoluteError(name="train_accuracy")

# Select loss and accurate
test_loss = tf.keras.metrics.Mean(name='test_loss')
test_accuracy = tf.keras.metrics.MeanAbsoluteError(name="test_accuracy")

損失関数を設定します。チュートリアルでは、損失関数はtf.keras.losses.SparseCategoricalCrossentropyを使っています。しかし、画像を出力するときは、画素ごとに正解不正解を調べる必要があるため、tf.keras.lossesMeanSquareErrorを使います。
また、チュートリアルでは正解率にtf.keras.metrics.SparseCategoricalAccuracy を使っています。しかし、こちらも画像を出力するため、tf.keras.metrics.MeanAbsoluteError を使います。

# @tf.function
def train_step(images, labels):
    # Recode formula by tape
    with tf.GradientTape() as tape:
        predictions = model(images)
        loss = loss_object(labels, predictions)
    # Calculate gradient
    # model.trainable_variables give list of weight
    gradients = tape.gradient(loss, model.trainable_variables)
    
    # Update weight value
    optimizer.apply_gradients(
        zip(gradients, model.trainable_variables))

    train_loss(loss)
    _ = train_accuracy.update_state(labels, predictions)


# @tf.function
def test_step(images, labels, epoch):
    predictions = model(images)
    t_loss = loss_object(labels, predictions)

    test_loss(t_loss)
    _ = test_accuracy.update_state(labels, predictions)

    # Save predictions image.
    count = 0
    for img in predictions:
        img = tf.cast(255*img, tf.uint8)
        filename = "{0}_{1}.png".format(epoch, count)
        tf.keras.preprocessing.image.save_img(
            filename, img)
        count += 1

@tf.functionをコメントアウトしています。これはイテレータと呼ばれ、定義した関数に機能を追加する(装飾:イタレート)ことができます。これを付けるとと、Tensorflow のバージョン 1.x系として動作します。1.x 系では、Classの中身を見ることができませんが、2.x 系だとprint で中身を簡単に見れます。今回は2.x 系で動かすため、コメントアウトしました。

# Set EPOCHS number.
EPOCHS = 100

for epoch in range(EPOCHS):
    for images, labels in train_ds:
        train_step(images, labels)

    for test_images, test_labels in test_ds:
        test_step(test_images, test_labels, epoch)

    template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'
    print (
        template.format(
            epoch+1,
            train_loss.result(),
            100 - train_accuracy.result().numpy()*100,
            test_loss.result(),
            100 - test_accuracy.result().numpy()*100
        )
    )

    train_loss.reset_states()
    train_accuracy.reset_states()
    test_loss.reset_states()
    test_accuracy.reset_states()

EPOCHS で指定した回数だけ学習します。そしてEPOCH 毎に途中結果を出力します。

蛇足

今回は、TensorFlow の公式チュートリアルを参考にu-netを実装しました。
勉強になりました!

コメント

  1. ひろし より:

    大変参考になりました。
    ご質問があります。
    上記の方法で学習したモデルを保存するにはどのようにしたらいいでしょうか?
    model.fit()などのライブラリを使用していないので、学習済みモデルを保存して活用する方法を教えていただけると幸いです。

    • 田中太郎 より:

      コメントありがとうございます。
      .save_weights()、.load_weight()を使用せずにということでしょうか?

  2. ひろし より:

    ご返信いただきありがとうございます。
    すみません、TensorFlow勉強中です。。
    ご指摘いただいたとおり、.save_weights()、.load_weight()でできました。
    ありがとうございました。

    ちなみに、モデルの構築や学習に、model.compile()や model.fit()を使っていないのは、理由があるのでしょうか?

    • 田中太郎 より:

      学習した値が保存できたようで良かったです!
      model.fit,model.compileを使用していないのは、当初参考にしていた公式のチュートリアルで
      使用していなかっただけ、だったと思います。。

  3. ゆき より:

    最近深層学習を勉強しています、大変参考になりました、ありがとうございます!
    質問をしたいですが、
    今回の訓練用、テスト用の正解画像は白黒画像になっているのですが、
    正解画像をカラー画像にしたい場合にはどういう風に書き換えしたらいいのでしょうか?
    (ラベルの分類を実装したい)

    • 田中太郎 より:

      コメントありがとうございます!
      こんな感じでしょうか?
      image2tensor.pyをRGB画像対応版にしてみました。

      • ゆき より:

        ご返信ありがとうございます。
        変更点は
        img = cv2.imread(path, 0)

        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        ですか?

        • 田中太郎 より:

          画像の読み込みはそちらですね!
          読み込んだ画像をテンソルフロー用に変換する
          image_2_tf_data_dataset
          も変わっているのでご注意ください!

タイトルとURLをコピーしました