Transfer Learning + Grad-CAM with Flask

Watson Wang
14 min readJun 20, 2021

--

在玩過一些神經網路model訓練圖片後,我開始思考,我該如何呈現這個模型,當然用數據說明很足夠,但我想要的是能用更視覺化的方式呈現,比如說一個小型的網站!最好是能夠順便呈現模型學到的特徵。

那決定好目標後,先來分成三大階段再一一實作:

  1. 訓練模型
  2. 呈現Grad-CAM
  3. 架設Flask

訓練模型

很明顯,第一步是決定要訓練什麼,那根據現在每天的新聞,我很直覺地想到可以來簡單訓練一個判斷有無COVID-19的model,那就從這一塊開始吧!

(1) 資料收集

那資料這邊使用kaggle 的COVID-19 chest Xray作為有確診的圖片集,並使用Chest-Xray Images下載普通肺部X光照(我是放置到./dataset/normal/),那如果你點進去資料集,你會發現其實其中圖片不多,而在這樣的情況下,我們就很適合使用 transfer learning。

covid_dataset_path = './'
# read "metadata.csv" and find the target image
csvPath = os.path.sep.join([covid_dataset_path, "metadata.csv"])
df = pd.read_csv(csvPath)
for (i, row) in df.iterrows():
# check the current case is not COVID-19
if row["finding"] != "COVID-19" or row["view"] != "PA":
continue
# set the picture path for reading.
imagePath = os.path.sep.join([covid_dataset_path, "images", row["filename"]])
# if the input image file does not exist
if not os.path.exists(imagePath):
continue
# get the name and set the new location for train
filename = row["filename"].split(os.path.sep)[-1]
outputPath = os.path.sep.join([f"{dataset_path}/covid", filename])
# copy the image to location
shutil.copy2(imagePath, outputPath)

藉上述程式碼,即可將染疫圖片放進名為covid的資料夾中

(2) transfer learning

遷移學習,顧名思義就是將以訓練好的模型,已經學好的某些看待資料的方式,搬移過來協助新模型的訓練。藉由這種方式,我們可以加速模型的收斂,也能降低對電腦的負荷。那這邊使用VGG16作為遷移對象,並在後面添加CNN 加以訓練。

baseModel = VGG16(weights="./vgg16_weight/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5", include_top=False, input_tensor=Input(shape=(224, 224, 3)))
# construct the head of the model that will be placed on top of the
# the base model
headModel = baseModel.output
headModel = AveragePooling2D(pool_size=(4, 4))(headModel)
headModel = Flatten(name="flatten")(headModel)
headModel = Dense(128, activation="relu")(headModel)
headModel = Dropout(0.6)(headModel)
headModel = Dense(2, activation="softmax")(headModel)

model = Model(inputs=baseModel.input, outputs=headModel)
# if set false means the parameters in the layer will not change
for layer in baseModel.layers:
layer.trainable = False

建議可以下載Vgg16本身的h5檔使用,有時候直接呼叫 vgg16(weights=”imagenet”)會出錯。

下載網址:https://github.com/fchollet/deep-learning-models/releases/download/v0.1/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5

並在model fit 時加入ImageDataGenerator 協助訓練

imageAug = ImageDataGenerator(rotation_range=15, fill_mode=”nearest”)opt = Adam(lr=INIT_LR, decay=INIT_LR / EPOCHS)
model.compile(loss="categorical_crossentropy", optimizer=opt, metrics=["accuracy"])
record = model.fit_generator(
imageAug.flow(trainX, trainY, batch_size=25),
steps_per_epoch=len(trainX) // 25,
validation_data=(testX, testY),
validation_steps=len(testX) // 25,
epochs=10)

那可以看到模型訓練速度很快,當然很大部份可能是因為資料集少,我們來檢視一下其中的數據。

看起來還可以,那訓練模型到這邊告一段落:)

呈現Grad-CAM

CNN通過多次卷積層和池化層以后,它的最后一層卷積层包含了最多的空間信息,再往下就是 FCN 和 softmax了,其中所包含的信息都是人类难以理解的。所以我們會著重在最後一層CNN來掌握圖片特徵。

(1) 什麼是Grad-CAM

在談論 Grad-CAM 前,我們需要先理解CAM,CAM(Class Activation Mapping)的想法是在最後一層GAP(global average pooling)後的每一格pixel會有自己的權重,代表對於該分類的影響度。那權重的求法主要就是使用反向傳播(back propagation),將 softmax 前一層的結果對feature map 做偏微分再將其加總。

back propagation 可參考:

ML Lecture 7: Backpropagation — YouTubewww.youtube.com › watch

但CAM存在一個缺點就是其必須有GAP層,所以2016年Grad-CAM出現了,其主要目的就是希望不用修改模型就能實現CAM

(2) 程式碼

給予圖像array,並輸入最後一層CNN名字(可由model.summary()中看出)

def make_gradcam_heatmap(img_array, model, last_conv_layer_name, pred_index=None):
# First, we create a model that maps the input image to the activations
# of the last conv layer as well as the output predictions
grad_model = tf.keras.models.Model(
[model.inputs], [model.get_layer(last_conv_layer_name).output, model.output]
)
# Then, we compute the gradient of the top predicted class for our input image
# with respect to the activations of the last conv layer
with tf.GradientTape() as tape:
last_conv_layer_output, preds = grad_model(img_array)
print(last_conv_layer_output, last_conv_layer_output.shape )
if pred_index is None:
pred_index = tf.argmax(preds[0])
class_channel = preds[:, pred_index]
print(class_channel)
# This is the gradient of the output neuron (top predicted)
# with regard to the output feature map of the last conv layer
grads = tape.gradient(class_channel, last_conv_layer_output)

# This is a vector where each entry is the mean intensity of the gradient
# over a specific feature map channel
pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
# We multiply each channel in the feature map array
# by "how important this channel is" with regard to the top predicted class
# then sum all the channels to obtain the heatmap class activation
last_conv_layer_output = last_conv_layer_output[0]
heatmap = last_conv_layer_output @ pooled_grads[..., tf.newaxis]
heatmap = tf.squeeze(heatmap)
# For visualization purpose, we will also normalize the heatmap between 0 & 1
heatmap = tf.maximum(heatmap, 0) / tf.math.reduce_max(heatmap)
return heatmap.numpy()
原始heatmap

然後我們將heatmap轉換色調到RGB,並將其放大到跟原圖一樣做疊圖。

def plot_heatmap(heatmap, img_path):
heatmap = np.maximum(heatmap, 0)
heatmap /= np.max(heatmap)
img = cv2.imread(img_path)
fig, ax = plt.subplots()
im = cv2.resize(cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB), (img.shape[1], img.shape[0]))
heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))
heatmap = np.uint8(255 * heatmap)

# 以 0.6 透明度繪製原始影像
ax.imshow(im, alpha=0.6)

# 以 0.4 透明度繪製熱力圖
ax.imshow(heatmap, cmap='jet', alpha=0.4)
plt.savefig('1.png',bbox_inches="tight")

可以看到我們的模型是如何抓取特徵的,但我也不確定這樣的特徵是否正確 :)

Grad-CAM程式碼參考:https://keras.io/examples/vision/grad_cam/

結合Flask

(1) 介紹

其實選用flask是一個很直覺的反應,因為我們模型是使用python訓練,故使用python的網頁框架在程式語言上感覺最相容。

但講到python網頁框架,一般人都會想到兩個:Django跟Flask,那因為兩個我大概都只碰過基礎,我這邊選用Flask是因為其較為靈活,Django太過齊全了,較適用於大型商業專案,在這一次實作中可能較為不適合。

Flask 這個套件提供了不少架設網站需要的基本工具,包括路由(Routes)、網頁模板(templates)、權限(authorization)等等的,從架設網站中最簡單的元素到最複雜的應用,Flask 和其衍生而來的套件幾乎都能幫你實現。

(2) 實作

那因為程式碼有點長,所以我這邊會主要部分function有興趣可以再看我貼在下方的github,或參照底下的youtube連結照著做。

程式碼

那我們從主程式開始,可以看到我們呼叫Model().returnModel(),這邊會去呼叫我們之前訓練的模型並載入回傳,然後app.run() 運作,可以在本地瀏覽器輸入 localhost: 7700 開啟網頁,debug = True 代表如果你更動程式碼,網頁會立馬跟著變動。

再來我們看上面的upload_predictions(),在function上面有一行,那決定這個function會在哪裡運作,”/”代表原地( localhost: 7700),如果輸入”/main”,則只有輸入(localhost: 7700/main)才會運作,後面的method代表呼叫方法。

那可以看到function裡面呼叫兩個function,一個是Predict(),一個是showModelDetail()。第一個會去用model訓練,儲存傳入圖片在 ./static/ 並給出機率,第二個則是會去算 heatmap,並將其儲存至 ./static/details/ ,然後我們呼叫 templates資料夾(規定的名字)裡的 index.html 將這些機率跟圖片位置傳過去呈現。

index.html

其中會用到一些css做美編,那我們使用bootstrap裡面已經有的範例模板,按右鍵顯示網頁原始碼後,將其複製過來。那最後的成品就會像這樣。

網頁初始(左),選取圖片並提交後(右)

那今天就先到這裡,有興趣可以多試試其他css模板,或是對網頁擴增其他功能~謝謝您用心看完文章,如果喜歡我的內容,請點個拍手或留言吧!也順手按下follow鍵,隨時追蹤新文章。歡迎不吝賜教🙂~

Github

參考網址

--

--