์๋ก
์ค๋๋ ์ธ๊ณต์ง๋ฅ(AI)์ ์ธ์์์ ์ฌ์ฉ์ ํ๋ผ์ด๋ฒ์ ๋ณดํธ๋ ์ต์ฐ์ ๊ณผ์ ๊ฐ ๋์์ต๋๋ค. ํนํ ๋ฐ์ดํฐ ๊ด๋ จ ๋ํ ๋ฒ์ธ GDPR์์๋ โ์ํ์ง ๊ถ๋ฆฌโ๋ฅผ ๋ช ์ํ๊ณ ์์ผ๋ฉฐ, ๊ฐ์ธ์ด ์์ ์ ๊ฐ์ธ ๋ฐ์ดํฐ ์ญ์ ๋ฅผ ์์ฒญํ ์ ์๋ ๊ถ๋ฆฌ๋ฅผ ๋ถ์ฌํฉ๋๋ค.
๋จธ์ ์ธ๋ฌ๋(Machine Unlearning)์ โ์ฒ์๋ถํฐ ๋ค์ ํ๋ จํ์ง ์๊ณ ๋ AI ๋ชจ๋ธ์ด ํน์ ํ๋ จ ๋ฐ์ดํฐ๋ฅผ ์์ ์ ์์๊น?โ๋ผ๋ ์ง๋ฌธ์ ๋ตํฉ๋๋ค. ์ด๋ ํ๋ผ์ด๋ฒ์ ๊ท์ ์ค์๋ฟ๋ง ์๋๋ผ ์ค๋ ๋ ๋ฐ์ดํฐ ์ ๊ฑฐ, ํ๋ จ ์ค์ ์์ , ์ ๋ขฐํ ์ ์๋ AI ์์คํ ๊ตฌ์ถ์ ํ์์ ์ ๋๋ค.
โUnlearning-Aware Minimization (UAM)โ [Paper]์ ๋จธ์ ์ธ๋ฌ๋์ ์ํ ์๋ก์ด min-max ์ต์ ํ ํ๋ ์์ํฌ์ ๋๋ค. ์ด ๊ธ์์๋ ๋จธ์ ์ธ๋ฌ๋์ ๊ฐ๋ ๊ณผ ์ ์๋ UAM ๋ฐฉ๋ฒ์ ํจ๊ป ์ดํด๋ณด๊ฒ ์ต๋๋ค.
์ฌ์ ์ง์
๋จธ์ ์ธ๋ฌ๋ ๋ฌธ์
๋จธ์ ์ธ๋ฌ๋์ ์ดํดํ๊ธฐ ์ํด์๋ ๋จผ์ ๋ชจ๋ธ์ด โ์๊ธฐโ ํ์ ๋ฌ์ฑํด์ผ ํ ๋ชฉํ๋ฅผ ์ดํดํด์ผ ํฉ๋๋ค. ํ๋ จ ๋ฐ์ดํฐ์ \(\mathcal{D}\)๊ฐ ์ฃผ์ด์ก์ ๋, ์ด๋ฅผ ๋ ๊ฐ์ ์๋ก์ ์งํฉ์ผ๋ก ๋ถํ ํ ์ ์์ต๋๋ค:
- ์์ ๋ฐ์ดํฐ(Forget data) \(\mathcal{D}_f\): ๋ชจ๋ธ์ด ์์ด์ผ ํ ๋ฐ์ดํฐ
- ์ ์งํ ๋ฐ์ดํฐ(Retain data) \(\mathcal{D}_r\): ๋ชจ๋ธ์ด ๊ธฐ์ตํด์ผ ํ ๋ฐ์ดํฐ
์ ํํ ์ธ๋ฌ๋(Exact unlearning)์ด๋ผ๊ณ ๋ถ๋ฆฌ๋ ์ด์์ ์ธ ํด๊ฒฐ์ฑ ์ ์ ์งํ ๋ฐ์ดํฐ๋ง์ ์ฌ์ฉํ์ฌ ๋ชจ๋ธ์ ์ฒ์๋ถํฐ ๋ค์ ํ๋ จํ๋ ๊ฒ์ ๋๋ค: \(\begin{equation} w^* = \text{argmin}_{w} \mathcal{L}(w, \mathcal{D}_r), \label{eq:retrain} \end{equation}\)
์ฌ๊ธฐ์ \(\mathcal{L}(w, \mathcal{D})\)๋ ์์คํจ์๋ฅผ ๋ํ๋ด๊ณ \(w\)๋ ๋ชจ๋ธ ๋งค๊ฐ๋ณ์๋ฅผ ์๋ฏธํฉ๋๋ค.
ํ์ง๋ง ๋๊ท๋ชจ ๋ชจ๋ธ์ ๊ฒฝ์ฐ ์ฒ์๋ถํฐ ๋ค์ ํ๋ จํ๋ ๊ฒ์ ๊ณ์ฐ์ ์ผ๋ก ๋ถ๋ด์ด ํฝ๋๋ค. ์๋ฅผ ๋ค์ด, CIFAR-10์์ ResNet ๋ชจ๋ธ์ ๋ค์ ํ๋ จํ๋ ๋ฐ 30๋ถ ์ด์์ด ๊ฑธ๋ฆฌ๋ฉฐ, ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์ ๊ฒฝ์ฐ ๋ฉฐ์น ๋๋ ๋ช ์ฃผ๊ฐ ์์๋ ์ ์์ต๋๋ค. ์ด๋ก ์ธํด ์ฌ์ ํ๋ จ๋ ๋ชจ๋ธ ๋งค๊ฐ๋ณ์๋ฅผ ํจ์จ์ ์ผ๋ก ์ ๋ฐ์ดํธํ๋ ๊ทผ์ฌ ์ธ๋ฌ๋(Approximate unlearning) ๋ฐฉ๋ฒ๋ค์ด ๊ฐ๋ฐ๋์์ต๋๋ค.
๊ธฐ์กด ๊ทผ์ฌ ์ธ๋ฌ๋ ๋ฐฉ๋ฒ๋ค
๊ธฐ์กด ์ฐ๊ตฌ์์๋ ๋ ๊ฐ์ง ์ฃผ์ ์ ๊ทผ๋ฒ์ ์ ์ํ์ต๋๋ค:
1. Fine-Tuning (FT): ์ ์งํ ๋ฐ์ดํฐ์ ๋ํด์๋ง ํ๋ จ์ ๊ณ์ํฉ๋๋ค \(\begin{equation} \min_w \mathcal{L}(w, \mathcal{D}_r) \end{equation}\)
2. Negative Gradient (NG): ์์ ๋ฐ์ดํฐ์ ๋ํ ์์ค์ ์ต๋ํํ์ฌ ์ธ๋ฌ๋ํฉ๋๋ค \(\begin{equation} \max_w \mathcal{L}(w, \mathcal{D}_f) \end{equation}\)
FT๋ ์ ์งํ ๋ฐ์ดํฐ์ ๋ํ ์ข์ ์ฑ๋ฅ์ ์ ์งํ์ง๋ง, ์ข ์ข ์์ ๋ฐ์ดํฐ์ ์ํฅ์ ์ถฉ๋ถํ ์ ๊ฑฐํ์ง ๋ชปํฉ๋๋ค. ๋ฐ๋ฉด NG๋ ์์ ๋ฐ์ดํฐ๋ฅผ ์ฑ๊ณต์ ์ผ๋ก ์ ๊ฑฐํ์ง๋ง ์ ์งํ ๋ฐ์ดํฐ์ ๋ํ ์ฑ๋ฅ์ ์ฌ๊ฐํ๊ฒ ์ ํ์ํต๋๋ค. ๋ ๋ฐฉ๋ฒ ๋ชจ๋ ์ฐจ์ ์ฑ ์ ์๋ ดํฉ๋๋ค.
๋ณธ๋ก
๋ณธ ๋ ผ๋ฌธ์ โ์ ์งํ ๋ฐ์ดํฐ์ ์ฑ๋ฅ์ ์ ์งํ๋ฉด์ ํน์ ๋ฐ์ดํฐ๋ฅผ ํจ๊ณผ์ ์ผ๋ก ์์ ์ ์๋๊ฐ?โ๋ผ๋ ๊ทผ๋ณธ์ ์ธ ์ง๋ฌธ์ ๋ค๋ฃน๋๋ค. ์ฐ๋ฆฌ๋ ์๋ก์ด min-max ์ต์ ํ ํ๋ ์์ํฌ์ธ Unlearning-Aware Minimization (UAM)์ ์ ์ํฉ๋๋ค.
UAM์ ๋๊ธฐ๋ฅผ ์ดํดํ๊ธฐ ์ํด, ๋จผ์ ๊ธฐ์กด ๊ทผ์ฌ ์ธ๋ฌ๋ ๋ฐฉ๋ฒ๋ค์ ํน์ฑํํ๋ ํตํฉ ๋ชฉ์ ํจ์๋ฅผ ์ค์ ํฉ๋๋ค. ์ต์ ํด \(w^*\)๊ฐ ์ฌ์ ํ๋ จ๋ ๋ชจ๋ธ์ ์ ๊ณ ๊ทผ๋ฐฉ์ ์๋ค๊ณ ๊ฐ์ ํ ๋, ์ธ๋ฌ๋ ๋ฌธ์ ๋ฅผ ๋ค์๊ณผ ๊ฐ์ด ๊ณต์ํํ ์ ์์ต๋๋ค:
\[\begin{equation} \min_w \mathcal{L}(w, \mathcal{D}_r) + \beta \big[ \mathcal{L}(w^*, \mathcal{D}) - \mathcal{L}(w, \mathcal{D}) \big] \label{eq:unified} \end{equation}\]์ ๋ชฉ์ ํจ์๋ ๋ ๊ฐ์ง ๊ตฌ์ฑ์์๋ฅผ ๊ฐ์ง๋๋ค:
- ์ฑ๋ฅ ํญ: \(\mathcal{L}(w, \mathcal{D}_r)\)์ ๋ชจ๋ธ์ด ์ ์งํ ๋ฐ์ดํฐ์ ๋ํ ์ฑ๋ฅ์ ์ ์งํ๋๋ก ์ฅ๋ คํฉ๋๋ค
- ์ผ๊ด์ฑ ํญ: \(\mathcal{L}(w^*, \mathcal{D}) - \mathcal{L}(w, \mathcal{D})\)์ ์ต์ ํ๋ ๊ฐ์ค์น์ ์ต์ ํด ๊ฐ์ ์ ๋ ฌ์ ์ฅ๋ คํฉ๋๋ค
์๋ ํ๋ ์์ํฌ๋ ๊ทผ์ฌ ์ธ๋ฌ๋์ ๋ ๊ฐ์ง ์ฃผ์ ์ ๊ทผ๋ฒ์ ์ค๋ช ํฉ๋๋ค:
- FT: \(\beta = 0\)์ผ๋ก ์ค์ ํ์ฌ ์ผ๊ด์ฑ ํญ์ ๋ฌด์ํ๊ณ ๊ฒฐ๊ณผ์ ์ผ๋ก ์๊ธฐ ์ฑ๋ฅ์ด ์ ํ๋ฉ๋๋ค
- NG: \(w^*\)์ ๋ํ ์ง์์ด ์๋ค๊ณ ๊ฐ์ ํ๊ณ , ์๊ธฐ ์์ค ์ต๋ํ์๋ง ์ง์คํ์ฌ ์ ์ง ์ฑ๋ฅ์ ์ ์งํ๋ ๋ฐ ์ด๋ ค์์ ๊ฒช์ต๋๋ค
UAM์ ํต์ฌ ํต์ฐฐ์ \(w^*\)์ ํน์ง์ธ ๋์ ์๊ธฐ ์์ค์ ํ์ฉํ์ฌ ์ธ๋ฌ๋์ ๋ฌ์ฑํ๋ ๊ฒ์ ๋๋ค. ์ด๋ฅผ ๋ค์๊ณผ ๊ฐ์ด ๊ณต์ํํฉ๋๋ค:
\[\begin{equation} \min_w \mathcal{L}(\text{argmax}_{\|\delta\|_2 \leq \rho} \mathcal{L}(w + \delta, \mathcal{D}_f), \mathcal{D}_r) \end{equation}\]์ด min-max ์ต์ ํ๋ ๋ ๋จ๊ณ๋ก ๊ตฌ์ฑ๋ฉ๋๋ค:
- ๋ด๋ถ ์ต๋ํ: ์๊ธฐ ์์ค์ ์ต๋ํํ๋ ๊ต๋๋ ๋งค๊ฐ๋ณ์ \(\hat{w}=\text{argmax}_{\|\delta\|_2 \leq \rho} \mathcal{L}(w + \delta, \mathcal{D}_f)\)๋ฅผ ์ฐพ์ต๋๋ค
- ์ธ๋ถ ์ต์ํ: \(\hat{w}\)์์ ์ ์ง ์์ค์ ์ต์ํํ๋ ๊ธฐ์ธ๊ธฐ๋ก ๋งค๊ฐ๋ณ์๋ฅผ ์ ๋ฐ์ดํธํฉ๋๋ค
๋์ ์๊ธฐ ์์ค์ ๊ฐ์ง ๋งค๊ฐ๋ณ์๋ฅผ ์ฐธ์กฐ์ ์ผ๋ก ์ฌ์ฉํจ์ผ๋ก์จ, UAM์ ๋ฎ์ ์ ์ง ์์ค์ ์ ์งํ๋ฉด์ ๋์ ์๊ธฐ ์์ค ํน์ฑ์ ๋ํ๋ด๋ ์ ๋ฐ์ดํธ๋ ๋ชจ๋ธ์ ๋ณด์ฅํฉ๋๋ค.
1์ฐจ ๊ทผ์ฌ๋ฅผ ํตํ ํจ์จ์ ์ธ ์๊ณ ๋ฆฌ์ฆ
๋ด๋ถ ์ต๋ํ ๋ฌธ์ ์ ์ ํํ ํด๋ฅผ ๊ณ์ฐํ๋ ๊ฒ์ ๊ณ์ฐ์ ์ผ๋ก ๋น์ฉ์ด ๋ง์ด ๋ญ๋๋ค. ํจ์จ์ ์ธ ์๊ณ ๋ฆฌ์ฆ์ ๋์ถํ๊ธฐ ์ํด 1์ฐจ ํ ์ผ๋ฌ ๊ทผ์ฌ๋ฅผ ์ ์ฉํฉ๋๋ค:
\[\begin{equation} \min_w \mathcal{L}\left(w + \rho \frac{\nabla_w \mathcal{L}(w, \mathcal{D}_f)}{\|\nabla_w \mathcal{L}(w, \mathcal{D}_f)\|_2^2}, \mathcal{D}_r\right) \end{equation}\]์ด ๊ณต์ํ๋ PyTorch์ ๊ฐ์ ์๋ ๋ฏธ๋ถ ํ๋ ์์ํฌ๋ฅผ ์ฌ์ฉํ์ฌ ํจ์จ์ ์ผ๋ก ๊ตฌํํ ์ ์์ต๋๋ค. UAM์ ์ฃผ์ ์ฅ์ ์ ํ๋ ์์ํฌ ๋ ๋ฆฝ์ ํน์ฑ์ ๋๋ค. UAM์ ๋ค๋ฅธ ์ธ๋ฌ๋ ๋ฐฉ๋ฒ๋ค๊ณผ ์ฝ๊ฒ ํตํฉ๋ ์ ์์ต๋๋ค.
์ด๋ก ์ ๋ถ์
์ฐ๋ฆฌ์ ์ด๋ก ์ ๋ถ์์ UAM ๋ชฉ์ ํจ์์ ๊ธฐ์ธ๊ธฐ๊ฐ ๋ค์๊ณผ ๊ฐ์ด ํํ๋ ์ ์์์ ๋ณด์ฌ์ค๋๋ค:
\[\begin{equation} \nabla_{w} \mathcal{L}(w + \delta(w), \mathcal{D}_r) = \left[\mathbf{I} + \frac{\rho}{\|\nabla_w \mathcal{L}(w, \mathcal{D}_f)\|_2^2}(\mathbf{I} - 2\mathbf{P}_f) \mathbf{H}_f\right]\nabla_{w} \mathcal{L}(w, \mathcal{D}_r)|_{w+\delta(w)} \end{equation}\]์ฌ๊ธฐ์ \(\mathbf{P}_f\)๋ ์๊ธฐ ๊ธฐ์ธ๊ธฐ ๋ฐฉํฅ์ผ๋ก์ ์ง๊ต ํฌ์ ํ๋ ฌ์ด๊ณ , \(\mathbf{H}_f\)๋ ์๊ธฐ ์์ค์ ํค์์์ ๋๋ค.
ํต์ฌ ํต์ฐฐ์ \((\mathbf{I} - 2\mathbf{P}_f)\) ํญ์ผ๋ก, ์ด๋ ์๊ธฐ ๊ธฐ์ธ๊ธฐ์ ์ ๋ ฌ๋ ์ ์ง ๊ธฐ์ธ๊ธฐ์ ์ฑ๋ถ์ ๋ ๋ฒ ๋นผ๋ ์ํ์ ์ฐ์ฐ์ ๋๋ค. ์ด๋ ์ ๋ฐ์ดํธ ๋ฐฉํฅ์ด ์๊ธฐ ์์ค์ ๊ฐ์์ํฌ ๋ฐฉํฅ์์ ๋ฉ์ด์ง๋๋ก ๋ณด์ฅํฉ๋๋ค.
์ ํํ ํค์์ ํ๋ ฌ \(\mathbf{H}_f\)๋ฅผ ๊ณ์ฐํ๋ ๊ฒ์ ๊ณ์ฐ์ ์ผ๋ก ๋น์ฉ์ด ๋ง์ด ๋ญ๋๋ค. ํค์์์ ๋จ์ ํ๋ ฌ๋ก ๊ทผ์ฌํ๋ ๊ฒ์ด ๊ฐ๋จํ๋ฉด์๋ ํจ๊ณผ์ ์ธ ํด๊ฒฐ์ฑ ์์ ๋ฐ๊ฒฌํ์ต๋๋ค:
\[\begin{equation} \left[\mathbf{I}- \gamma\mathbf{P}_f \right]\nabla_{w} \mathcal{L}(w, \mathcal{D}_r)|_{w+\delta(w)} \end{equation}\]์ฌ๊ธฐ์ \(\gamma\)๋ ํ์ดํผํ๋ผ๋ฏธํฐ์ ๋๋ค. ์ค์ ๋ก \(\gamma=2\)๊ฐ ์ต์ํ์ ์ถ๊ฐ ๊ณ์ฐ ๋น์ฉ์ผ๋ก ๋ค์ํ ์์ ์์ ์ผ๊ด๋๊ฒ ๊ฐํ ์ฑ๋ฅ์ ๋ฌ์ฑํจ์ ๋ฐ๊ฒฌํ์ต๋๋ค.
๋ ๊น์ ๋ถ์์ UAM์ ๊ธฐํํ์ ์ง๊ด์ ๋ณด์ฌ์ค๋๋ค. ์ฐ๋ฆฌ์ ๋ชฉ์ ํจ์์ 1์ฐจ ๊ทผ์ฌ๋ฅผ ์ ์ฉํ๋ฉด:
\[\begin{equation} \min_w \mathcal{L}(w, \mathcal{D}_r) + \nabla \mathcal{L}(w, \mathcal{D}_r)^\top \rho \frac{\nabla\mathcal{L}(w, \mathcal{D}_f)}{\|\nabla\mathcal{L}(w, \mathcal{D}_f)\|_2^2} \end{equation}\]์ด๋ UAM์ด ์ ์ง ๊ธฐ์ธ๊ธฐ \(\nabla \mathcal{L}(w, \mathcal{D}_r)\)์ ์๊ธฐ ๊ธฐ์ธ๊ธฐ \(\nabla \mathcal{L}(w, \mathcal{D}_f)\) ๊ฐ์ ๋ด์ (์ฝ์ฌ์ธ ์ ์ฌ๋)์ ๋ช ์์ ์ผ๋ก ์ต์ํํจ์ ๋ณด์ฌ์ค๋๋ค.
์ด๋ฌํ ๊ธฐ์ธ๊ธฐ๋ค์ด ์์ ์ ๋ ฌ(์ฝ์ฌ์ธ ์ ์ฌ๋ โค 0)์ ๊ฐ์ง ๋, ์ ์ง ์์ค์ ์ต์ํํ๋ ๊ฒ์ด ๋ณธ์ง์ ์ผ๋ก ์๊ธฐ ์์ค์ ์ต๋ํ๋ก ์ด์ด์ง๋๋ค. ์ด๋ UAM์ด ๊ธฐ์กด ๋ฐฉ๋ฒ๋ค์ ๋ฅ๊ฐํ๋ ์ด์ ์ ๋ํ ๊ฐ๋ ฅํ ๊ธฐํํ์ ์ค๋ช ์ ์ ๊ณตํฉ๋๋ค.
์ด๋ฏธ์ง ๋ถ๋ฅ ์ธ๋ฌ๋์์์ ์ฑ๋ฅ
UAM์ ์ธ ๊ฐ์ง ๋ฒค์น๋งํฌ ๋ฐ์ดํฐ์ ์์ ํ๊ฐํ์ต๋๋ค: CIFAR-10, CIFAR-100, TinyImageNet. ๋ ๊ฐ์ง ์๋๋ฆฌ์ค์์ ์คํํ์ต๋๋ค:
- Class-wise forgetting: ํน์ ํด๋์ค์ ๋ชจ๋ ์ํ ์ ๊ฑฐ
- Random data forgetting: ๋ฌด์์๋ก ์ํ๋ง๋ ํ๋ จ ์์ ์ ๊ฑฐ
์ฃผ์ ๋ฐ๊ฒฌ์ฌํญ:
- UAM์ ํด๋์ค๋ณ ์๊ธฐ์์ zero-forget๋ฅผ ๋ฌ์ฑํ์ฌ ์ ํํ ์ฌํ๋ จ๊ณผ ์ผ์นํฉ๋๋ค
- NG๋ ๋ฐ์ฐํ์ง๋ง, UAM์ ์์ ์ ์ผ๋ก ์ ์ง๋ฉ๋๋ค
- UAM์ ๊ฐ์ฅ ๋ฎ์ \(\Delta\)Acc.์ ๋ฌ์ฑํ์ฌ ์ข์ ์ฑ๋ฅ์ ๋ณด์ฌ์ค๋๋ค.
LLM ์ธ๋ฌ๋์์์ ์ฑ๋ฅ
๋น์ ์์ ์ ๋์ด์, UAM์ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ(LLM)์ด ์ํํ ์ง์์ ์๋ ๋ฐ ์์ด์ ๋๋ผ์ด ํจ๊ณผ๋ฅผ ๋ณด์ฌ์ค๋๋ค. Zephyr-7B-ฮฒ์ ๋ํด UAM์ WMDP-Bio์ WMDP-Cyber์์ ๊ฐ์ฅ ๋ฎ์ ์ํ ์ง์ ์ ์๋ฅผ ๋ฌ์ฑํ์ต๋๋ค.
์ธํ๋ฃจ์์ A๋ฅผ ๋ ์น๋ช ์ ์ผ๋ก ๋ง๋ค๊ธฐ ์ํ ์ํํ ์ง๋ฌธ์ ๋ํด ํ๋กฌํํธ๋ฅผ ๋ฐ์์ ๋, ๊ธฐ๋ณธ ๋ชจ๋ธ์ ์์ธํ ์ํํ ์ ๋ณด๋ฅผ ์ ๊ณตํฉ๋๋ค. UAM์ผ๋ก ์ธ๋ฌ๋ํ ํ, ๋ชจ๋ธ์ ๊ทธ๋ฌํ ์ํํ ์ฝํ ์ธ ์ ๊ณต์ ์ ์ ํ ๊ฑฐ๋ถํ์ฌ ๋ ์์ ํ ํ๋์ ๋ณด์ฅํฉ๋๋ค.
๊ฒฐ๋ก
๋จธ์ ์ธ๋ฌ๋์ ์ฌ์ฉ์ ํ๋ผ์ด๋ฒ์๋ฅผ ์กด์คํ๊ณ ๋ฐ์ดํฐ ๋ณดํธ ๊ท์ ์ ์ค์ํ๋ ์ ๋ขฐํ ์ ์๋ AI ์์คํ ๊ตฌ์ถ์ ํ์์ ์ ๋๋ค. ๋์ ์ ๋ชจ๋ธ ์ฑ๋ฅ์ ์ ์งํ๋ฉด์ ํน์ ๋ฐ์ดํฐ์ ์ํฅ์ ํจ๊ณผ์ ์ผ๋ก ์ ๊ฑฐํ๋ ๊ฒ์ผ๋ก, ๊ธฐ์กด ๋ฐฉ๋ฒ๋ค์ด ์ด๋ ค์ํ๋ ์์ ์ ๋๋ค.
๋ณธ ๋ ผ๋ฌธ์์๋ Unlearning-Aware Minimization (UAM)์ ํจ๊ณผ์ ์ผ๋ก ์์ ๋ฐ์ดํฐ ์ ๊ฑฐํ๋ฉฐ, ์ด๋ฏธ์ง ๋ถ๋ฅ๋ถํฐ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ๊น์ง ํตํฉ์ ์ผ๋ก ์ ์ฉ์ด ๋ ์ ์์์ ๋ณด์์ต๋๋ค. ๋์๊ฐ, ๊ธฐ์ธ๊ธฐ ์ ๋ ฌ ๋ถ์ ๋ฑ์ ํตํด ํด๋น ๊ธฐ๋ฒ์ ์ํ์ ์๋ฆฌ๋ฅผ ๋ถ์ํ์ฌ ์ถํ ์ฐ๊ตฌ๋ก ์ด์ด์ง๋ ์ด๋ก ์ ํต์ฐฐ์ ์ ๊ณตํ์์ต๋๋ค.
๋จธ์ ์ธ๋ฌ๋ ๋ถ์ผ๋ ์ฌ์ ํ ์ด๋ก ์ ๋ณด์ฅ๋ถํฐ ๋ ํจ์จ์ ์ธ ์๊ณ ๋ฆฌ์ฆ๊น์ง ๋ง์ ์ด๋ฆฐ ๊ณผ์ ๋ฅผ ๊ฐ์ง๊ณ ์์ต๋๋ค. UAM์ด ๋์ ์ฑ๋ฅ์ ์ ์งํ๋ฉด์ ์งํํ๋ ๋ฐ์ดํฐ ์๊ตฌ์ฌํญ์ ์ ์ํ ์ ์๋ ํ๋ผ์ด๋ฒ์ ๋ณดํธ AI ์์คํ ๊ฐ๋ฐ์ ๊ธฐ์ฌํ๊ธฐ๋ฅผ ๋ฐ๋๋๋ค.
๊ตฌํ
๋ณธ ๋ ผ๋ฌธ์์๋ ๋จธ์ ์ธ๋ฌ๋(Machine Unlearning)์ ๋ฐ์ ์ ์ด๋ฐ์งํ๊ธฐ ์ํด ๋ค์์ ์คํ์์ค ํจํค์ง๋ฅผ ๊ณต๊ฐํ์์ต๋๋ค: machine-unlearning-pytorch [GitHub]
import torchunlearn
from torchunlearn.unlearn.trainers.uam import UAM
# ์ฌ์ ํ๋ จ๋ ๋ชจ๋ธ ๋ก๋
model = torchunlearn.utils.load_model(model_name="ResNet18", n_classes=10)
rmodel = torchunlearn.RobModel(model, n_classes=10)
# UAM ํธ๋ ์ด๋ ์ค์
trainer = UAM(rmodel, rho=0.01, gamma=2.0)
# ์ต์ ํ ๊ตฌ์ฑ
trainer.setup(
optimizer="SGD(lr=0.01, momentum=0.9, weight_decay=5e-4)",
scheduler=None,
n_epochs=5
)
# ์ธ๋ฌ๋์ผ๋ก ํ๋ จ
trainer.fit(
train_loaders=merged_loader, # ์ ์ง ๋ฐ ์์ ๋ฐ์ดํฐ๋ฅผ ๋ชจ๋ ํฌํจ
n_epochs=5,
save_path="./models/unlearned",
save_best={"Clean(R)": "HB", "Clean(F)": "LBO"}
)