LKY 只有原創內容的 Blog

今之能者,謂能轉貼,至於魯蛇,皆能轉貼。不原創,何以別乎?

用 NumPy 向量化加速 Python:條件式卷積,變快 250 倍

Lin, Kao-Yuan's Avatar 2023-03-09

  1. 1. 入門 for loop 寫法,速度定義為 1x
  2. 2. 用 NumPy 向量化計算,速度 250x

如果你能看到這篇文章,應該知道什麼叫做「卷積(Convolution)」,我就不解釋了。

難的是什麼?如果我卷積的演算法,是有條件的、不能用 NumPy 的convolve函式、不能用 OpenCV 的filter2D函式、不能用一個 MxN 的卷積核描述,那我該怎麼辦?

要是用迴圈,大家都會寫啊。可是效能很差,怎麼辦?

這篇文章,我要介紹一個技巧,讓你自定義有條件的卷積演算法,變快 250 倍。


需求:

  • 訪問每一個像素點,根據周圍的像素點的中位數,來決定自己的值
  • 如果周圍的中位數大於自己就+1,小於則就-1,自己就是中位數的話不變

為了說明簡潔,以下說明不會放完整程式碼,完整程式碼會放在最後連結。

入門 for loop 寫法,速度定義為 1x

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# 方法1: 使用 for 迴圈 (最慢)
def for_loop():
map_prev = get_map()
map_prev = np.pad(map_prev, 1, 'constant', constant_values=0) # 處理邊界問題,在原本的陣列外面包一圈 0
my_map_forloop = get_map()
for y in range(h):
for x in range(w):
# 如果周圍的中位數大於自己就+1,小於則就-1,自己就是中位數的話不變
around9 = map_prev[y:y+3, x:x+3].flatten()
median = np.median(around9)
if median > my_map_forloop[y, x]:
my_map_forloop[y, x] += 1
elif median < my_map_forloop[y, x]:
my_map_forloop[y, x] -= 1
return my_map_forloop

上面這個程式碼,就是一般初學者最直覺,用 for loop 寫出來的程式碼,但是效能很差,因為 for 迴圈一次只能處理一個數,數據頻繁的進出往來 RAM 與 CPU,效能也很差。


用 NumPy 向量化計算,速度 250x

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# 方法2: 使用 NumPy 向量化
def numpy_vectorization():
my_map_vectorlize = get_map()
around9_of_my_map = np.zeros((h+2, w+2, 9), dtype=np.uint8) #+2是為了避免邊界問題
# 從左上角開始,順時針方向,旋轉梯式整片賦值
# 作者: Lin Kao-Yuan 林高遠
# 知乎: www.zhihu.com/people/lin-kao-yuan
# 網站: web.ntnu.edu.tw/~60132057A
# around9_of_my_map[Y位置, X位置, 第幾個方向]
around9_of_my_map[0:-2, 0:-2, 0] = my_map_vectorlize
around9_of_my_map[0:-2, 1:-1, 1] = my_map_vectorlize
around9_of_my_map[0:-2, 2: , 2] = my_map_vectorlize
around9_of_my_map[1:-1, 0:-2, 3] = my_map_vectorlize
around9_of_my_map[1:-1, 1:-1, 4] = my_map_vectorlize # 自己
around9_of_my_map[1:-1, 2: , 5] = my_map_vectorlize
around9_of_my_map[2: , 0:-2, 6] = my_map_vectorlize
around9_of_my_map[2: , 1:-1, 7] = my_map_vectorlize
around9_of_my_map[2: , 2: , 8] = my_map_vectorlize

# 裁掉邊界
around9_of_my_map = around9_of_my_map[1:-1, 1:-1, :]

# 計算中位數
median_map = np.median(around9_of_my_map, axis=2).astype(np.uint8)

# 比較中位數和自己的大小
# 如果周圍的中位數大於自己就+1,小於則就-1,自己就是中位數的話不變
my_map_vectorlize = np.where(median_map > my_map_vectorlize, my_map_vectorlize + 1, my_map_vectorlize)
my_map_vectorlize = np.where(median_map < my_map_vectorlize, my_map_vectorlize - 1, my_map_vectorlize)
return my_map_vectorlize

上面這個程式碼,的概念是什麼?

我個人對卷積的理解,就是「以我自己為中心,對周圍的 MxN 個點做處理」,通常 M 與 N 皆為奇數。

那為了向量化,我就這樣做:

  1. 生成一個 MxN 層與原本相同大小的陣列。
  2. 把原本的陣列,旋轉梯式的shift,整片賦值。裡面的每一層,都有「原本陣列的副本+空的邊界」。
  3. 裁掉邊界。
  4. 不論是要計算中位數或是其他演算法,就像一把槍,一次射穿所有的 MxN 層,就可以得到結果。

旋轉梯局部

如何在向量化的同時,實現條件式演算法?:

  1. 對於條件式的演算法(就是你寫成for loop時會有if的那種),就用 np.where() 來寫,每種 case 都是一個向量化的運算。
  2. 不滿足條件的設定,要比較小心,看是要不變、或給 0。反正向量化計算,不管是否符合條件,一定要給值!不能留空。
  3. 每種 case 的結果再組合起來,可能是 +、and、or、max、min、sum 等等,看情況決定。

這篇文章主要是介紹了 NumPy 的向量化計算,用一個…很不生活化(暫時想不到卷積如何生活化)的簡單案例,來說明 NumPy 向量化計算的優點。

這篇文章的完整程式碼,可以在這裡找到:https://gist.github.com/mosdeo/32230317309bab30727c0c76f09b47f0

旋轉梯圖片:该图片由Wolfgang EckertPixabay上发布

本文最后更新于 天前,文中所描述的信息可能已发生改变