【Tensorflow/u-net】Tensorflow 2.X系を使ってu-netを実装する

田中太郎
田中太郎

Tensorflowの チュートリアルを参考に u-net の実装を行いました!

概要

Deep-Learning の勉強でTensorflow u-netを実装しました。TenorFlow の公式サイトにMNINSTを使ったチュートリアルがあります。そのコードを基にu-netを実装しました。

環境

・Tensorflow 2.0
・Python 3.7
・CentOS
・入力画像サイズ 128×128
・出力画像サイズ 128×128

入力

訓練用の入力画像(teach_dir)と正解画像(label_dir)、テスト用の入力画像(test_dir)と正解画像(accurate_dir)の4種類の画像を格納するディレクトリを作成します。

ファイル構造を以下になります。

・ー実行ファイル1 <u_net.py>
・ー実行ファイル2 <image2tensor.py>

・ーー訓練用の入力画像(teach_dir)
| |ー画像ファイルが n 個

・ーー訓練用の正解画像(label_dir)
| |ー画像ファイルが n 個

・ーーテスト用の入力画像(test_dir)
| |ー画像ファイルが m 個

・ーーテスト用の正解画像(accurate_dir)
  |ー画像ファイルが m 個

訓練用、テスト用の入力画像の一例は以下になります。

図1:入力画像

訓練用、テスト用の正解画像の一例は以下になります。

図2:正解画像

出力

TensorFlow のチュートリアルにのっとり、Epoch数、損失関数、正解率が出力されます。また、訓練したu-net の出力画像が保存されます。

図3:出力(コンソールの画面)

コード

スクリプトファイルは <u_net.py> と <image2tensor.py> の2つあります。

コード

<u_net.py>

import tensorflow as tf
from tensorflow.keras.layers import MaxPool2D, Conv2D, UpSampling2D, Concatenate
from tensorflow.keras import Model

import numpy as np
import image2tensor


# Read images
train_ds = image2tensor.read_dir_and_convert_tf_data_dataset(
    "teach_dir", "label_dir"
)    
test_ds = image2tensor.read_dir_and_convert_tf_data_dataset(
    "test_dir", "accurate_dir"
)    

# Add shuffle buffer_size to tf.data.Dataset object,
# and batch size.
train_ds = train_ds.shuffle(32).batch(8)
test_ds = test_ds.batch(8)

# Create model using model subclassing API in Keras.
# Model is constructure of NN.
class MyModel(Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1_1 = Conv2D(4, 3, padding='same', activation='relu')
        self.conv1_2 = Conv2D(4, 3, padding='same', activation='relu')
        self.conv1_3 = Conv2D(4, 3, padding='same', activation='relu')
        self.conv1_4 = Conv2D(4, 3, padding='same', activation='relu')
        self.conv2_1 = Conv2D(8, 3, padding='same', activation='relu')
        self.conv2_2 = Conv2D(8, 3, padding='same', activation='relu')
        self.conv2_3 = Conv2D(8, 3, padding='same', activation='relu')
        self.conv2_4 = Conv2D(8, 3, padding='same', activation='relu')
        self.conv3_1 = Conv2D(16, 3, padding='same', activation='relu')
        self.conv3_2 = Conv2D(16, 3, padding='same', activation='relu')
        self.conv3_3 = Conv2D(16, 3, padding='same', activation='relu')
        self.conv3_4 = Conv2D(16, 3, padding='same', activation='relu')
        self.conv4_1 = Conv2D(32, 3, padding='same', activation='relu')
        self.conv4_2 = Conv2D(32, 3, padding='same', activation='relu')
        self.conv4_3 = Conv2D(32, 3, padding='same', activation='relu')
        self.conv4_4 = Conv2D(32, 3, padding='same', activation='relu')
        self.conv5_1 = Conv2D(64, 3, padding='same', activation='relu')
        self.conv5_2 = Conv2D(64, 3, padding='same', activation='relu')
        self.conv_o = Conv2D(1, 1, padding="same", activation='sigmoid')
        self.cat_1 = Concatenate()
        self.cat_2 = Concatenate()
        self.cat_3 = Concatenate()
        self.cat_4 = Concatenate()
        self.maxpool_1 = MaxPool2D((2, 2))
        self.maxpool_2 = MaxPool2D((2, 2))
        self.maxpool_3 = MaxPool2D((2, 2))
        self.maxpool_4 = MaxPool2D((2, 2))
        self.upsampling_1 = UpSampling2D((2, 2))
        self.upsampling_2 = UpSampling2D((2, 2))
        self.upsampling_3 = UpSampling2D((2, 2))
        self.upsampling_4 = UpSampling2D((2, 2))

    def call(self, x):
        x = self.conv1_1(x)
        c1 = self.conv1_2(x)
        x = self.maxpool_1(c1)

        x = self.conv2_1(x)
        c2 = self.conv2_2(x)
        x = self.maxpool_2(c2)

        x = self.conv3_1(x)
        c3 = self.conv3_2(x)
        x = self.maxpool_3(c3)

        x = self.conv4_1(x)
        c4 = self.conv4_2(x)
        x = self.maxpool_4(c4)

        x = self.conv5_1(x)
        x = self.conv5_2(x)

        x = self.upsampling_1(x)
        x = self.cat_1([x, c4])
        x = self.conv4_3(x)
        x = self.conv4_4(x)

        x = self.upsampling_2(x)
        x = self.cat_2([x, c3])
        x = self.conv3_3(x)
        x = self.conv3_4(x)

        x = self.upsampling_3(x)
        x = self.cat_3([x, c2])
        x = self.conv2_3(x)
        x = self.conv2_4(x)

        x = self.upsampling_4(x)
        x = self.cat_4([x, c1])
        x = self.conv1_3(x)
        x = self.conv1_4(x)

        return self.conv_o(x)

model = MyModel()

# Select loss function
loss_object = tf.keras.losses.MeanSquaredError()

# Select optimizer
optimizer = tf.keras.optimizers.Adam()

# Select loss and accurate
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.MeanAbsoluteError(name="train_accuracy")

# Select loss and accurate
test_loss = tf.keras.metrics.Mean(name='test_loss')
test_accuracy = tf.keras.metrics.MeanAbsoluteError(name="test_accuracy")

# @tf.function
def train_step(images, labels):
    # Recode formula by tape
    with tf.GradientTape() as tape:
        predictions = model(images)
        loss = loss_object(labels, predictions)
    # Calculate gradient
    # model.trainable_variables give list of weight
    gradients = tape.gradient(loss, model.trainable_variables)
    
    # Update weight value
    optimizer.apply_gradients(
        zip(gradients, model.trainable_variables))

    train_loss(loss)
    _ = train_accuracy.update_state(labels, predictions)


# @tf.function
def test_step(images, labels, epoch):
    predictions = model(images)
    t_loss = loss_object(labels, predictions)

    test_loss(t_loss)
    _ = test_accuracy.update_state(labels, predictions)

    # Save predictions image.
    count = 0
    for img in predictions:
        img = tf.cast(255*img, tf.uint8)
        filename = "{0}_{1}.png".format(epoch, count)
        tf.keras.preprocessing.image.save_img(
            filename, img)
        count += 1


# Set EPOCHS number.
EPOCHS = 100

for epoch in range(EPOCHS):
    for images, labels in train_ds:
        train_step(images, labels)

    for test_images, test_labels in test_ds:
        test_step(test_images, test_labels, epoch)

    template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'
    print (
        template.format(
            epoch+1,
            train_loss.result(),
            100 - train_accuracy.result().numpy()*100,
            test_loss.result(),
            100 - test_accuracy.result().numpy()*100
        )
    )

    train_loss.reset_states()
    train_accuracy.reset_states()
    test_loss.reset_states()
    test_accuracy.reset_states()

<image2tensor.py>

import os
import sys

import cv2
import numpy as np
import tensorflow as tf


def detect_images_pathes(images_dir):
    image_names = os.listdir(images_dir)
    images_pathes = [
        os.path.join(images_dir, i) for i in image_names]
    return images_pathes
    

def image_2_tf_data_dataset(teach_imgs, label_imgs):
    # Initaialize ndarray of imgs and labels
    imgs = np.ndarray((
        len(teach_imgs),
        teach_imgs[0].shape[0],
        teach_imgs[0].shape[1],
    ))
    labels = np.ndarray((
        len(label_imgs),
        label_imgs[0].shape[0],
        label_imgs[0].shape[1],
    ))

    for i in range(len(teach_imgs)):
        imgs[i] = imgs[i] + teach_imgs[i]/255.0
        labels[i] = labels[i] + label_imgs[i]/255.0

    imgs = imgs[..., tf.newaxis]
    labels = labels[..., tf.newaxis]
    return tf.data.Dataset.from_tensor_slices((imgs, labels))


def read_image(pathes):
    imgs = []
    for path in pathes:
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        imgs.append(img)
    return imgs


def read_dir_and_convert_tf_data_dataset(teach_dir, label_dir):
    # Detect file pathes
    teach_pathes = detect_images_pathes(teach_dir)
    label_pathes = detect_images_pathes(label_dir)

    # Error check
    if len(teach_pathes) != len(label_pathes):
        sys.exit(
"ERROR:input and label data num are not equal. {0} file is {1}, {2} file are {3}".format(teach_dir, len(teach_pathes), label_dir, len(label_pathes)))

    # Read images. Argument is file pathes
    teach_imgs = read_image(teach_pathes)
    label_imgs = read_image(label_pathes)

    # Convert image of numpy to tf.data.dataset
    ds = image_2_tf_data_dataset(
        teach_imgs=teach_imgs,
        label_imgs=label_imgs,
    )

    return ds
コードの解説

<u_net.py> の頭から説明していきます。

import tensorflow as tf  # 
from tensorflow.keras.layers import MaxPool2D, Conv2D, UpSampling2D, Concatenate
from tensorflow.keras import Model

import numpy as np
import image2tensor

必要なモジュールをインポートします。

import image2tensor

は自作モジュール<image2tensor.py> をインポートしています。

# Read images
train_ds = image2tensor.read_dir_and_convert_tf_data_dataset(
    "teach_dir", "label_dir"
)    
test_ds = image2tensor.read_dir_and_convert_tf_data_dataset(
    "test_dir", "accurate_dir"
)    

# Add shuffle buffer_size to tf.data.Dataset object,
# and batch size.
train_ds = train_ds.shuffle(32).batch(8)
test_ds = test_ds.batch(8)

画像を読み込んで、TensorFlow で処理する Tensor 型に変換しています。

image2tensor.read_dir_and_convert_tf_data_dataset

は自作モジュールの関数です。画像ファイルが格納されているディレクトリを引数にすると、ディレクトリ内のファイルをすべて読み込んで、TensorFlowのDataset型で返します。TesorFlowのDataset型は詳しくは説明しませんが、入力画像と正解画像をセットで扱えるものです。

注)引数に、入力データと正解データのディレクトリを指定しますが、画像ファイルの数は合わせてください。
ファイル名はなんでも良いですが、入力データと正解データでそろえてください。

train_ds = train_ds.shuffle(32).batch(8)
test_ds = test_ds.batch(8)

.shuffle を設定すると、読み込んだ画像を使うときに、ランダムに返してくれます。
.batch は、例えば8の場合、8セット画像を入力した後に重みを更新させます。バッチ学習と呼ばれるやつです。

# Create model using model subclassing API in Keras.
# Model is constructure of NN.
class MyModel(Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1_1 = Conv2D(4, 3, padding='same', activation='relu')
        self.conv1_2 = Conv2D(4, 3, padding='same', activation='relu')
        self.conv1_3 = Conv2D(4, 3, padding='same', activation='relu')
        self.conv1_4 = Conv2D(4, 3, padding='same', activation='relu')
        self.conv2_1 = Conv2D(8, 3, padding='same', activation='relu')
        self.conv2_2 = Conv2D(8, 3, padding='same', activation='relu')
        self.conv2_3 = Conv2D(8, 3, padding='same', activation='relu')
        self.conv2_4 = Conv2D(8, 3, padding='same', activation='relu')
        self.conv3_1 = Conv2D(16, 3, padding='same', activation='relu')
        self.conv3_2 = Conv2D(16, 3, padding='same', activation='relu')
        self.conv3_3 = Conv2D(16, 3, padding='same', activation='relu')
        self.conv3_4 = Conv2D(16, 3, padding='same', activation='relu')
        self.conv4_1 = Conv2D(32, 3, padding='same', activation='relu')
        self.conv4_2 = Conv2D(32, 3, padding='same', activation='relu')
        self.conv4_3 = Conv2D(32, 3, padding='same', activation='relu')
        self.conv4_4 = Conv2D(32, 3, padding='same', activation='relu')
        self.conv5_1 = Conv2D(64, 3, padding='same', activation='relu')
        self.conv5_2 = Conv2D(64, 3, padding='same', activation='relu')
        self.conv_o = Conv2D(1, 1, padding="same", activation='sigmoid')
        self.cat_1 = Concatenate()
        self.cat_2 = Concatenate()
        self.cat_3 = Concatenate()
        self.cat_4 = Concatenate()
        self.maxpool_1 = MaxPool2D((2, 2))
        self.maxpool_2 = MaxPool2D((2, 2))
        self.maxpool_3 = MaxPool2D((2, 2))
        self.maxpool_4 = MaxPool2D((2, 2))
        self.upsampling_1 = UpSampling2D((2, 2))
        self.upsampling_2 = UpSampling2D((2, 2))
        self.upsampling_3 = UpSampling2D((2, 2))
        self.upsampling_4 = UpSampling2D((2, 2))

    def call(self, x):
        x = self.conv1_1(x)
        c1 = self.conv1_2(x)
        x = self.maxpool_1(c1)

        x = self.conv2_1(x)
        c2 = self.conv2_2(x)
        x = self.maxpool_2(c2)

        x = self.conv3_1(x)
        c3 = self.conv3_2(x)
        x = self.maxpool_3(c3)

        x = self.conv4_1(x)
        c4 = self.conv4_2(x)
        x = self.maxpool_4(c4)

        x = self.conv5_1(x)
        x = self.conv5_2(x)

        x = self.upsampling_1(x)
        x = self.cat_1([x, c4])
        x = self.conv4_3(x)
        x = self.conv4_4(x)

        x = self.upsampling_2(x)
        x = self.cat_2([x, c3])
        x = self.conv3_3(x)
        x = self.conv3_4(x)

        x = self.upsampling_3(x)
        x = self.cat_3([x, c2])
        x = self.conv2_3(x)
        x = self.conv2_4(x)

        x = self.upsampling_4(x)
        x = self.cat_4([x, c1])
        x = self.conv1_3(x)
        x = self.conv1_4(x)

        return self.conv_o(x)

model = MyModel()

u-net のニューラルネットワークをクラスとして作成しています。

model = MyModel()

上記でモデルを構築し、model(image) のように引数を与えることで、imageがネットワークを伝播していきます。

# Select loss function
loss_object = tf.keras.losses.MeanSquaredError()

# Select optimizer
optimizer = tf.keras.optimizers.Adam()

# Select loss and accurate
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.MeanAbsoluteError(name="train_accuracy")

# Select loss and accurate
test_loss = tf.keras.metrics.Mean(name='test_loss')
test_accuracy = tf.keras.metrics.MeanAbsoluteError(name="test_accuracy")

損失関数を設定します。チュートリアルでは、損失関数はtf.keras.losses.SparseCategoricalCrossentropyを使っています。しかし、画像を出力するときは、画素ごとに正解不正解を調べる必要があるため、tf.keras.lossesMeanSquareErrorを使います。
また、チュートリアルでは正解率にtf.keras.metrics.SparseCategoricalAccuracy を使っています。しかし、こちらも画像を出力するため、tf.keras.metrics.MeanAbsoluteError を使います。

# @tf.function
def train_step(images, labels):
    # Recode formula by tape
    with tf.GradientTape() as tape:
        predictions = model(images)
        loss = loss_object(labels, predictions)
    # Calculate gradient
    # model.trainable_variables give list of weight
    gradients = tape.gradient(loss, model.trainable_variables)
    
    # Update weight value
    optimizer.apply_gradients(
        zip(gradients, model.trainable_variables))

    train_loss(loss)
    _ = train_accuracy.update_state(labels, predictions)


# @tf.function
def test_step(images, labels, epoch):
    predictions = model(images)
    t_loss = loss_object(labels, predictions)

    test_loss(t_loss)
    _ = test_accuracy.update_state(labels, predictions)

    # Save predictions image.
    count = 0
    for img in predictions:
        img = tf.cast(255*img, tf.uint8)
        filename = "{0}_{1}.png".format(epoch, count)
        tf.keras.preprocessing.image.save_img(
            filename, img)
        count += 1

@tf.functionをコメントアウトしています。これはイテレータと呼ばれ、定義した関数に機能を追加する(装飾:イタレート)ことができます。これを付けるとと、Tensorflow のバージョン 1.x系として動作します。1.x 系では、Classの中身を見ることができませんが、2.x 系だとprint で中身を簡単に見れます。今回は2.x 系で動かすため、コメントアウトしました。

# Set EPOCHS number.
EPOCHS = 100

for epoch in range(EPOCHS):
    for images, labels in train_ds:
        train_step(images, labels)

    for test_images, test_labels in test_ds:
        test_step(test_images, test_labels, epoch)

    template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'
    print (
        template.format(
            epoch+1,
            train_loss.result(),
            100 - train_accuracy.result().numpy()*100,
            test_loss.result(),
            100 - test_accuracy.result().numpy()*100
        )
    )

    train_loss.reset_states()
    train_accuracy.reset_states()
    test_loss.reset_states()
    test_accuracy.reset_states()

EPOCHS で指定した回数だけ学習します。そしてEPOCH 毎に途中結果を出力します。

蛇足

今回は、TensorFlow の公式チュートリアルを参考にu-netを実装しました。
勉強になりました!

コメント

タイトルとURLをコピーしました