【Tensorflow/u-net】Tensorflow 2.X系を使ってu-netを実装する RGB画像版

サンプルコード

u_net.py
import tensorflow as tf
from tensorflow.keras.layers import MaxPool2D, Conv2D, UpSampling2D, Concatenate
from tensorflow.keras import Model

import numpy as np
import image2tensor


# Build the training and test datasets from the image directories.
# Inputs and targets are both loaded as RGB images; pass "GRAY" as the
# color to train on single-channel images instead.
train_ds = image2tensor.read_dir_and_convert_tf_data_dataset(
    teach_dir="teach_dir",
    label_dir="label_dir",
    color="RGB",
)
test_ds = image2tensor.read_dir_and_convert_tf_data_dataset(
    teach_dir="test_dir",
    label_dir="accurate_dir",
    color="RGB",
)

# Shuffle the training set through a 32-element buffer, then batch
# both datasets in groups of 8.
train_ds = train_ds.shuffle(32)
train_ds = train_ds.batch(8)
test_ds = test_ds.batch(8)

# Create the model with the Keras model-subclassing API.
# The model defines the structure of the neural network (a small U-Net).
class MyModel(Model):
    """Small U-Net with four pooling stages and a sigmoid output layer.

    Encoder: at each of four resolutions, two 3x3 ReLU convolutions
    (4, 8, 16 and 32 filters) whose output is kept as a skip connection,
    followed by 2x2 max pooling. Bottleneck: two 3x3 ReLU convolutions with
    64 filters. Decoder: 2x upsampling, concatenation with the matching skip
    connection, then two 3x3 ReLU convolutions per stage. A final 1x1
    convolution with sigmoid activation produces ``out_channels`` maps.

    The input height and width must be divisible by 16. After four 2x2
    poolings and four 2x upsamplings, any other size leaves the upsampled
    tensor and its skip connection with different spatial sizes, and the
    concatenation fails.

    Args:
        out_channels: number of channels of the final convolution.
            Defaults to 1 (single-channel output). Pass 3 to predict RGB
            images when the label images are RGB.
        **kwargs: forwarded to ``tf.keras.Model`` (e.g. ``name``).
    """

    def __init__(self, out_channels=1, **kwargs):
        super().__init__(**kwargs)
        # Convolutions per resolution level: *_1/*_2 are used on the way
        # down (encoder), *_3/*_4 on the way back up (decoder).
        self.conv1_1 = Conv2D(4, 3, padding='same', activation='relu')
        self.conv1_2 = Conv2D(4, 3, padding='same', activation='relu')
        self.conv1_3 = Conv2D(4, 3, padding='same', activation='relu')
        self.conv1_4 = Conv2D(4, 3, padding='same', activation='relu')
        self.conv2_1 = Conv2D(8, 3, padding='same', activation='relu')
        self.conv2_2 = Conv2D(8, 3, padding='same', activation='relu')
        self.conv2_3 = Conv2D(8, 3, padding='same', activation='relu')
        self.conv2_4 = Conv2D(8, 3, padding='same', activation='relu')
        self.conv3_1 = Conv2D(16, 3, padding='same', activation='relu')
        self.conv3_2 = Conv2D(16, 3, padding='same', activation='relu')
        self.conv3_3 = Conv2D(16, 3, padding='same', activation='relu')
        self.conv3_4 = Conv2D(16, 3, padding='same', activation='relu')
        self.conv4_1 = Conv2D(32, 3, padding='same', activation='relu')
        self.conv4_2 = Conv2D(32, 3, padding='same', activation='relu')
        self.conv4_3 = Conv2D(32, 3, padding='same', activation='relu')
        self.conv4_4 = Conv2D(32, 3, padding='same', activation='relu')
        # Bottleneck at 1/16 of the input resolution.
        self.conv5_1 = Conv2D(64, 3, padding='same', activation='relu')
        self.conv5_2 = Conv2D(64, 3, padding='same', activation='relu')
        # 1x1 output projection; sigmoid keeps predictions in [0, 1],
        # matching label images scaled by 1/255.
        self.conv_o = Conv2D(
            out_channels, 1, padding="same", activation='sigmoid')
        self.cat_1 = Concatenate()
        self.cat_2 = Concatenate()
        self.cat_3 = Concatenate()
        self.cat_4 = Concatenate()
        self.maxpool_1 = MaxPool2D((2, 2))
        self.maxpool_2 = MaxPool2D((2, 2))
        self.maxpool_3 = MaxPool2D((2, 2))
        self.maxpool_4 = MaxPool2D((2, 2))
        self.upsampling_1 = UpSampling2D((2, 2))
        self.upsampling_2 = UpSampling2D((2, 2))
        self.upsampling_3 = UpSampling2D((2, 2))
        self.upsampling_4 = UpSampling2D((2, 2))

    def call(self, x):
        """Run the U-Net forward pass on a batch of images ``x``."""
        # Encoder: keep c1..c4 as skip connections before each pooling.
        x = self.conv1_1(x)
        c1 = self.conv1_2(x)
        x = self.maxpool_1(c1)

        x = self.conv2_1(x)
        c2 = self.conv2_2(x)
        x = self.maxpool_2(c2)

        x = self.conv3_1(x)
        c3 = self.conv3_2(x)
        x = self.maxpool_3(c3)

        x = self.conv4_1(x)
        c4 = self.conv4_2(x)
        x = self.maxpool_4(c4)

        # Bottleneck.
        x = self.conv5_1(x)
        x = self.conv5_2(x)

        # Decoder: upsample, concatenate the skip connection of the same
        # resolution along the channel axis, then convolve.
        x = self.upsampling_1(x)
        x = self.cat_1([x, c4])
        x = self.conv4_3(x)
        x = self.conv4_4(x)

        x = self.upsampling_2(x)
        x = self.cat_2([x, c3])
        x = self.conv3_3(x)
        x = self.conv3_4(x)

        x = self.upsampling_3(x)
        x = self.cat_3([x, c2])
        x = self.conv2_3(x)
        x = self.conv2_4(x)

        x = self.upsampling_4(x)
        x = self.cat_4([x, c1])
        x = self.conv1_3(x)
        x = self.conv1_4(x)

        return self.conv_o(x)

model = MyModel()

# Training objective: mean squared error between label and predicted
# images, minimised with Adam using its default hyper-parameters.
loss_object = tf.keras.losses.MeanSquaredError()
optimizer = tf.keras.optimizers.Adam()

# Running metrics. Losses are plain means of the per-batch loss values;
# the "accuracy" metrics are really mean absolute errors between labels
# and predictions.
train_loss = tf.keras.metrics.Mean(name='train_loss')
test_loss = tf.keras.metrics.Mean(name='test_loss')
train_accuracy = tf.keras.metrics.MeanAbsoluteError(name="train_accuracy")
test_accuracy = tf.keras.metrics.MeanAbsoluteError(name="test_accuracy")

# Runs eagerly as written; decorating it with @tf.function would trace it
# into a graph.
def train_step(images, labels):
    """Apply one optimizer update for a single batch and accumulate metrics.

    Args:
        images: batch of input images fed to the global ``model``.
        labels: batch of target images compared with the predictions.
    """
    with tf.GradientTape() as tape:
        # Operations inside the tape are recorded for differentiation.
        predictions = model(images)
        loss = loss_object(labels, predictions)

    # One gradient per trainable weight, in the same order as the weights.
    weights = model.trainable_variables
    grads = tape.gradient(loss, weights)
    optimizer.apply_gradients(zip(grads, weights))

    train_loss(loss)
    train_accuracy.update_state(labels, predictions)


# @tf.function
def test_step(images, labels, epoch):
    predictions = model(images)
    t_loss = loss_object(labels, predictions)

    test_loss(t_loss)
    _ = test_accuracy.update_state(labels, predictions)

    # Save predictions image.
    count = 0
    for img in predictions:
        if epoch < 60:
            break
        # binary
        max_tmp = np.amax(img)
        min_tmp = np.amin(img)
        img = 255*(img-min_tmp)/(max_tmp-min_tmp)
        img = tf.cast(img, tf.uint8)
        filename = "{0}_{1}.png".format(epoch, count)
        tf.keras.preprocessing.image.save_img(
            filename, img)
        count += 1


# Number of passes over the training data.
EPOCHS = 100

for epoch in range(EPOCHS):
    for images, labels in train_ds:
        train_step(images, labels)

    for test_images, test_labels in test_ds:
        test_step(test_images, test_labels, epoch)

    # "Accuracy" is reported as 100 - 100 * MAE, a percentage-style score
    # derived from the mean-absolute-error metrics. float() prints plain
    # numbers instead of tf.Tensor reprs.
    template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'
    print(
        template.format(
            epoch + 1,
            float(train_loss.result()),
            100 - float(train_accuracy.result()) * 100,
            float(test_loss.result()),
            100 - float(test_accuracy.result()) * 100,
        )
    )

    # Clear the running metrics so every epoch is reported on its own.
    # reset_state() is the current Keras name (TF >= 2.5); the older
    # reset_states() alias is deprecated.
    for metric in (train_loss, train_accuracy, test_loss, test_accuracy):
        metric.reset_state()

image2tensor.py
import os
import sys

import cv2
import numpy as np
import tensorflow as tf


def detect_images_pathes(images_dir):
    """Detect image files in a directory.

    Input: images_dir
        images_dir(str): images directory
    Output: images_pathes
        images_pathes(list of str): paths of the regular files in
            images_dir, sorted by file name.

    Sorting matters: teach and label images are later paired by position,
    and os.listdir returns entries in arbitrary order, so without sorting an
    input image could be paired with another image's label. Sub-directories
    are skipped because they cannot be read as images.
    """
    candidates = (
        os.path.join(images_dir, name)
        for name in sorted(os.listdir(images_dir))
    )
    return [path for path in candidates if os.path.isfile(path)]
    

def _to_normalized_array(images, color):
    """Stack same-sized uint8 images into one float64 array scaled to [0, 1].

    For GRAY the per-image arrays are (v, h) and a trailing channel axis is
    appended, giving (N, v, h, 1); for RGB they are (v, h, 3), giving
    (N, v, h, 3).
    """
    stacked = np.stack(images).astype(np.float64) / 255.0
    if color == "GRAY":
        stacked = stacked[..., np.newaxis]
    return stacked


def image_2_tf_data_dataset(teach_imgs, label_imgs, color):
    """Convert lists of image arrays to a dataset of (image, label) pairs.

    Input: teach_imgs, label_imgs, color
        teach_imgs(list of ndarray): input images, (v, h) or (v, h, 3)
        label_imgs(list of ndarray): label images, same layout
        color(str): color is GRAY or RGB
    Output: ds
        ds(tensorflow.data.Dataset): element i pairs teach_imgs[i] with
            label_imgs[i]; pixel values are float64 scaled to [0, 1].

    All images within a list must share one shape. Any other color value
    terminates the program via sys.exit.
    """
    if color not in ("GRAY", "RGB"):
        sys.exit("ERROR: Set color GRAY or RGB")

    # Build the arrays with np.stack: allocating with np.ndarray(shape) gives
    # uninitialised memory, which must not be used as a base to add onto.
    imgs = _to_normalized_array(teach_imgs, color)
    labels = _to_normalized_array(label_imgs, color)
    return tf.data.Dataset.from_tensor_slices((imgs, labels))


def read_image(pathes, color):
    """Read images from paths.

    Input: pathes, color
        pathes(list of str): paths of image files
        color(str): GRAY or RGB
    Output: imgs
        imgs(list of ndarray): one array per path, in the same order;
            single-channel for GRAY, converted to RGB channel order for RGB.

    Terminates the program via sys.exit for an unknown color or for a file
    that OpenCV cannot read.
    """
    # Validate once up front instead of on every iteration.
    if color not in ("GRAY", "RGB"):
        sys.exit("ERROR: Set color=(GRAY or RGB)")

    imgs = []
    for path in pathes:
        if color == "GRAY":
            img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        else:
            img = cv2.imread(path)
        # cv2.imread signals failure by returning None instead of raising.
        if img is None:
            sys.exit("ERROR: Could not read image file {0}".format(path))
        if color == "RGB":
            # OpenCV loads color images in BGR channel order.
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        imgs.append(img)
    return imgs


def read_dir_and_convert_tf_data_dataset(teach_dir, label_dir, color):
    """Load paired teach/label images from two directories as a dataset.

    Input: teach_dir, label_dir, color
        teach_dir(str): directory holding the input (teach) images
        label_dir(str): directory holding the label images
        color(str): GRAY or RGB
    Output: ds
        ds(tf.data.Dataset): (image, label) pairs built by
            image_2_tf_data_dataset

    Terminates the program via sys.exit when the two directories hold a
    different number of entries.
    """
    teach_pathes, label_pathes = (
        detect_images_pathes(directory) for directory in (teach_dir, label_dir)
    )

    n_teach = len(teach_pathes)
    n_label = len(label_pathes)
    if n_teach != n_label:
        sys.exit(
            "ERROR:input and label data num are not equal. {0} file is {1}, {2} file are {3}".format(
                teach_dir, n_teach, label_dir, n_label))

    # Teach images are read before label images, then paired by position.
    return image_2_tf_data_dataset(
        teach_imgs=read_image(teach_pathes, color),
        label_imgs=read_image(label_pathes, color),
        color=color,
    )
実行方法

u_net.py、image2tensor.py、teach_dir/、label_dir/、test_dir/、accurate_dir/ を同じ階層に置いて、

以下を実行する。

python u_net.py

コメント

タイトルとURLをコピーしました