章敏敏,徐和平,王曉潔,周夢昀,洪淑月
(浙江師范大學 數理與信息工程學院,浙江 金華 321004)
摘要:TensorFlow是谷歌的第二代開源的人工智能學習系統,是用來實現神經網絡的內置框架學習軟件庫。目前,TensorFlow機器學習已經成為了一個研究熱點。由基本的機器學習算法入手,簡析機器學習算法與TensorFlow框架,并通過在Linux系統下搭建環境,仿真手寫字符識別的TensorFlow模型,實現手寫字符的識別,從而實現TensorFlow機器學習框架的學習與應用。
關鍵詞:TensorFlow;機器學習;應用
中圖分類號:TP181文獻標識碼:ADOI: 10.19358/j.issn.1674-7720.2017.10.017
引用格式:章敏敏,徐和平,王曉潔,等.谷歌TensorFlow機器學習框架及應用[J].微型機與應用,2017,36(10):58-60.
0引言
機器學習是一門多領域交叉的學科,能夠實現計算機模擬或者實現人類的學習行為,重構自己的知識結構從而改善自身的性能。2016年初,AlphaGo以大比分戰勝李世石,AI的概念從此進入人們的視野,而機器學習就是AI的核心,是使計算機具有智能的根本途徑。TensorFlow是谷歌的第二代人工智能學習系統,是用來制作AlphaGo的一個開源的深度學習系統。
1機器學習
可以舉一個簡單的例子來說明機器學習的概念,使用k近鄰算法改進交友網站的配對效果[1]。比如說你現在想要在交友網站上認識一個朋友,而交友網站上擁有每個注冊用戶的兩個信息(玩視頻游戲所耗時間的百分比和每年獲取的飛行常客里程數),你想知道你會對哪些人比較感興趣,這時候就可以使用機器學習算法建立一個簡單的模型。可以將一些自己認為有魅力的人、魅力一般的人、不喜歡的人的這兩個信息(玩視頻游戲所耗時間的百分比和每年獲取的飛行常客里程數)輸入機器學習算法建立一個模型,如圖1所示。當你想知道一個用戶是不是你感興趣交友的人時,輸入信息,計算機通過這個模型進行計算,可以給你一個預測答案,這就是一種經典的監督學習算法。
機器學習算法有很多種類,上述例子說明的監督學習算法只是其中的一類。如果換種方式去實現這個結果,你有一堆如上的數據,但是并不對這些數據進行分類,讓算法按照數據的分散方式來觀察這些數據,發現數據形成了一些聚類,如圖2所示,而通過這種方法,能夠把這些數據自動地分類,這就是一種無監督學習算法。
機器學習的算法有很多,再比如用學習型算法來判斷你需要多少訓練信息,用什么樣的更好的近似函數能夠反映數據之間的關系,使得用最少的訓練信息獲得更準確的判斷。
機器學習就是當機器想要完成一個任務,通過它不斷地積累經驗,來逐漸更好、差錯減少地完成一個任務。
2TensorFlow的框架
2.1TensorFlow輸入張量
TensorFlow的命名來源于本身的運行原理。Tensor(張量)意味著N維數組,Flow(流)意味著基于數據流圖的計算。用MNIST機器學習[23]這個例子來解釋一個用于預測圖片里面的數字的模型。
首先要先獲得一個MNIST數據集,如圖3所示,這個數據集能夠在TensorFlow官網上進行下載。每一個MNIST數據單元由一張包含手寫數字的圖片和一個對應的標簽兩部分組成。把這些圖片設為“xs”,把這些標簽設為“ys”。MNIST數據集擁有60 000行的訓練數據集(mnist.train)和10 000行的測試數據集(mnist.test)。
每一張圖片包含28×28個像素點。可以用一個數字數組來表示這張圖片:把這個數組展開成一個向量,長度是784。在MNIST訓練數據集中,mnist.train.images(訓練數集中的圖片)是一個 [60 000, 784] 的張量,如圖4所示,第一個維度數字用來對應每張圖片,第二個維度數字用來索引每張圖片中的像素點。在此張量里的每一個元素,都表示為某張圖片里的某個像素的介于0和1之間的強度值。
相對應的標簽是從0到9的數字,用來描述給定圖片里表示的數字。每個數字對應著相應位置1,如標簽0表示為[1,0,0,0,0,0,0,0,0,0],因此mnist.train.labels是一個 [60 000, 10] 的數字矩陣,如圖5所示。
如上述的這兩個數組都是二維數組,都是TensorFlow中的張量數據[4],而這些數據就以流的形式進入數據運算的各個節點。而以機器算法為核心所構造的模型就是數據流動的場所。TensorFlow就是一個是文件庫,研究人員和計算機科學家能夠借助這個文件庫打造分析圖像和語音等數據的系統,計算機在此類系統的幫助下,將能夠自行作出決定,從而變得更加智能。
2.2TensorFlow代碼框架
TensorFlow是一個非常靈活的框架,它能夠運行在個人計算機或者服務器的單個或多個CPU和GPU上,甚至是移動設備上。
可以從上面舉例的MNIST機器學習來分析TensorFlow的框架。首先,要構建一個計算的過程。MNIST所用到的算法核心就是softmax回歸算法,這個算法就是通過對已知訓練數據同個標簽的像素加權平均,來構建出每個標簽在不同像素點上的權值,若是這個像素點具有有利的證據說明這張圖片不屬于這類,那么相應的權值為負數,相反若是這個像素擁有有利的證據支持這張圖片屬于這個類,那么權值是正數。
因為輸入往往會帶有一些無關的干擾量,于是加入一個額外的偏置量(bias)。因此對于給定的輸入圖片x它代表的是數字i的證據,可以表示為:
evidencei=∑jWi,jxj+bi(1)
其中Wi,j表示權值的矩陣,xj為給定圖片的像素點,bi代表數字i類的偏置量。
在這里不給出詳細的推導過程,但是可以得到一個計算出一個圖片對應每個標簽的概率大小的計算方式,可以通過如下的代碼來得到一個概率分布:
y=softmax(Wx+b)(2)
建立好一個算法模型之后,算法內輸入的所有可操作的交互單元就像式(2)中的圖片輸入x,為了適應所有的圖片輸入,將其設置為變量占位符placeholder。而像權重W和偏置值b這兩個通過學習不斷修改值的單元設置為變量Variable。
train_step=tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
TensorFlow在這一步就是在后臺給描述計算的那張圖里面增添一系列新的計算操作單元用來實現反向傳播算法和梯度下降算法。它返回一個單一的操作,當運行這個操作時,可以用梯度下降算法來訓練模型,微調變量,不斷減少成本,從而建立好一個基本模型。
建立好模型之后,創建一個會話(Session),循環1 000次,每次批處理100個數據,開始數據訓練,代碼如下:
sess= tf.InteractiveSession()
for i in range(1000):
batch_xs,batch_ys=mnist.train.next_batch(100)
sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
TensorFlow通過數據輸入(Feeds)將張量數據輸入至模型中,而張量Tensor就像數據流一樣流過每個計算節點,微調變量,使得模型更加準確。
通過這個例子,可以管中窺豹了解TensorFlow的框架結構,TensorFlow對于輸入的計算過程在后臺描述成計算圖,計算圖建立好之后,創建會話Session來提交計算圖,用Feed輸入訓練的張量數據,TensorFlow通過在后臺增加計算操作單元用于訓練模型,微調數據,從而完成一個機器的學習任務[5]。
3TensorFlow的應用
TensorFlow的支持列表里沒有Windows,而人們使用的計算機大都是安裝的Windows系統,雖然可以用Docker來實現在Windows上運行,但小問題很多,它支持得最好的還是基于UNIX內核的系統[6],例如Linux,因此選擇Ubuntu 15.10。
安裝成功之后,可以測試一下上述MNIST_sotfmax的模型。在程序中加入可以判斷其預測概率的代碼:
correct_prediction=tf.equal(tf.argmax(y,1), tf.argmax(y_, 1))
當tf.argmax(y, 1)預測值與tf.argmax(y_, 1)正確值相等的時候判斷其為正確的預測:
accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
accuracy用來計算預測與完全錯誤判斷之間的距離,也就是正確率,最后將它打印在顯示屏上。
在導入代碼之前,要先給予終端最高權限,不然在導入代碼的時候會顯示權限限制。成功導入代碼后,命令行打印出測試結果的正確率,如圖6所示為0.919 1。當然
這只是最簡單的一個模型,有許多算法模型的正確率可以達到0.997左右。
4結論
TensorFlow是一個很好的利用機器學習算法的框架,而它的優勢在于深度學習系統的構建,雖然在本文中沒有涉及,但是從實驗仿真中可以看到TensorFlow的模型構建簡便,訓練速度快。
參考文獻
[1] HARRINGTON P.機器學習實戰[M].李銳,李鵬,曲亞東,等,譯.北京:人民郵電出版社,2013.
[2] TensorFlow官方文檔中文版[EB/OL].(2015-11-18)[2016-11-25]http://wiki.jikexueyuan.com/project/tensorflowzh/.
[3] TensorFlow官方網站[EB/OL].[2016-11-25]https://www.tensorflow.org/.
[4] TensorFlow架構[EB/OL].(2016-06-12)[2016-11-25]http://blog.csdn.net/stdcoutzyx/article/details/51645396.
[5] Google TensorFlow機器學習框架介紹和使用[EB/OL].(2015-12-15)[2016-11-25]http://blog.csdn.net/sinat_31628525/article/details/50320817.
[6] 張俊,李鑫.TensorFlow平臺下的手寫字符識別[J].電腦知識及技術,2016,12(16):199-201.