淺談GAN生成對抗網路基本原理

Watson Wang
Jul 2, 2021

--

Photo by Ståle Grut on Unsplash

那今天要來探索的部分就是近幾年一個很有趣的機器學習模型 — GAN(Generative Adversarial Network),那我第一次聽到這個名詞是在以前實習時需要做自動化登入,需要訓練簡單辨識驗證碼模型時看到網路上有人使用GAN來做訓練,但當然現在隨意搜尋看到的應用都更有趣,像之前在NVDIA所發佈的GauGAN: Changing Sketches into Photorealistic Masterpieces,就能讓使用者隨手繪圖生成一張漂亮的風景照。

那其他還有文本到圖像的合成,人臉老化,圖像到圖像啊,或是視頻合成等等,有興趣可以到GAN ZOO看一些其他案例,或是玩一些別人的demo網站

風格轉換

https://junyanz.github.io/CycleGAN/

圖像上色

https://twitter.com/quasimondo/status/867023499214413830

那GAN到底是如何做到這件事的我們看一下GAN的內部,主要由兩個model組成:Generator(生成) 跟 Discriminator(辨識),那Generator主要負責生成新的資料,然後交給Discriminator去辨識。常見的一個舉例是做偽鈔,Generator就是做偽鈔的人,而Discriminator就像是驗鈔機,做偽鈔的人如果被discriminator駁回,就回去改進技術做出更加逼真的假鈔(名稱裡對抗的由來),那驗鈔機也會不斷改良變得更嚴格,一段時間後當驗鈔機已經分不出是假鈔還是真鈔,那我們的假鈔生成技術:Genenrator也就訓練完畢了。

這兩個模型呢基本上可以是任何一種神經網路,比如說CNN(convolution neural network),RNN(Recurrent neural network),或是LSTM(long short term memory)。Discriminator很常見,可以參考之前的🔗VGG16或是🔗ResNET都是其中一種。

我們希望的Generator是可以在隨意給入數值後隨機生成不同的圖片,那這就會牽扯到反卷積,我們已經知道Convolution會使圖片特徵縮小,那反卷積的功用恰好相反:使圖片擴增大小,下圖是搭配不同參數(stride/padding)所呈現的擴增方式。

https://github.com/vdumoulin/conv_arithmetic

那看光這樣看可能還是沒感覺,我們拿GAN入門基本上都會做的手寫數字MNIST來做簡單的測試,先看看模型建立,那其中Conv2DTranspose(反卷積)大小的方式是(N-1)*S -2*P +F。(原先圖片邊長:N,Stride大小:S,邊界擴充padding的值:P,kernel大小:F),隨機數列輸入假定是100個數字。

那要如何去定義量兩個model各自的loss,Discriminator的做法是這樣子的,拿真實圖片進去辨識後的數據(得到一個機率),跟1(label)做差值比較,也拿生成圖片進去後的機率跟0(label)比較,拿這兩種加總就是Discriminator的loss。 而Generator的loss則是純粹就自己生產的圖片丟進去discriminator跟0做差值。兩者都是越小越好,今天時間充裕,我們這邊「怒train 3000 epoch」看看效果。

但訓練過程相當坎坷,可以看到loss其實不斷震盪。

100 epoch 印出一次

讓我們來看看3000epoch的成果~

明顯是一個相當漂亮的數字8!那今天就先到這裡,後面有附colab連結,謝謝各位耐心觀看,有興趣可以追蹤看我繼續摸索這不歸路🥲,也可以底下鼓掌~

程式碼(務必記得開GPU)

--

--

Watson Wang
Watson Wang

Written by Watson Wang

AI engineer | M.S. CS student @ NTU (Gmail: john4026191@gmail.com)

No responses yet