Dataset $\mathcal{D}=\{(\mathbf{x}_i, \mathbf{y}i)\}{i=1}^{n}$가 주어지고, 목표로 하는 sparsity level이 $k$인 NN Pruning을 constrained optimization 문제로 표현할 수 있다.
$$ \begin{aligned} \min_{\mathbf{c}, \mathbf{w}}L( \mathbf{w};\mathcal{D})&=\min \frac{1}{n}\sum_{i=1}^{n}\mathcal{l}(\mathbf{w};(\mathbf{x}_i,\mathbf{y}_i)) \\ \text{s.t.} \, \mathbf{w} &\in \mathbb{R}^m, ||\mathbf{\mathbf{w}}||_0 \leq \mathcal{k},
\end{aligned} $$
여기서 $\mathcal{l}(\cdot)$은 cross entropy와 같은 일반적인 loss function이며, $\mathbf{w}$는 NN의 set of paramrter, $m$은 총 parameter 수이다.
정의된 최적화 문제는 보통 미리 학습한 network를 pruning하고 fine-tuning하는 과정을 반복하는 iterative하고 휴리스틱한 방법이 주로 사용됨.
대부분의 pruning 기법이 FCN, CNN 등 architecture에 의존성이 있으며, 또한 pruning 과정에 사용되는 hyper-parameter가 많이 사용하는데 이를 구하는 과정이 휴리스틱한 경우가 많다.
$$ \begin{aligned} \min_{\mathbf{c}, \mathbf{w}}L(\mathbf{c}\odot \mathbf{w};\mathcal{D})&=\min \frac{1}{n}\sum_{i=1}^{n}\mathcal{l}(\mathbf{c}\odot \mathbf{w};(\mathbf{x}_i,\mathbf{y}_i)) \\ \text{s.t.} \, \mathbf{w} &\in \mathbb{R}^m, \\ \mathbf{c} &\in \{0,1\}^m, \, ||\mathbf{c}||_0 \leq \mathcal{k},
\end{aligned} $$
$$ \begin{aligned}\Delta L_j(\mathbf{w};\mathcal{D})&=L(1\odot \mathbf{w};\mathcal{D})-L((1-\mathbf{e}_j)\odot \mathbf{w};\mathcal{D})\end{aligned} $$
weight가 아닌 c에 대한 효과로 다시 표현 가능함. (index $j$의 효과를 제거)
$$ \begin{aligned} \Delta L_j(\mathbf{w};\mathcal{D}) &\approx g_i(\mathbf{w},\mathcal{D}) \\ &=\left. \frac{\partial L(\mathbf{c}\odot \mathbf{w};\mathcal{D})}{\partial c_j} \right|{\mathbf{c}=1} \\ &=\left. \lim{\delta \rightarrow0 }\frac{L(\mathbf{c}\odot \mathbf{w};\mathcal{D})-L((\mathbf{c}-\delta \mathbf{e}j)\odot \mathbf{w};\mathcal{D})}{\delta}\right|{\mathbf{c}=1} \end{aligned} $$
여기서 $c\in\{0,1\}^m$는 미분 불가하므로, 극소 변화에 대한 변화량으로 근사화 함.
weight에 dependency가 적고 한번의 forward pass로 모든 connection을 평가할 수 있는 "connection sensitivity"를 정의하고, 한번의 forward pass를 통해 모든 connection의 sensitivity를 계산함.