연구를 하다가 데이터 불균형에 적용할 수 있는 방법들을 고민해보았다. 그러다 Focal Loss 방법을 알게되었으며, Focal Loss는 이미지 데이터를 기반으로 하는 것이기에 난 텍스트 데이터를 다루고 있어 주제가 조금 다를 수 있지만, 그래도 적용하는데에 큰 문제는 없기에 써보았다. 그러면서 어떤 효과가 있는지 등을 알고 싶었고 이에 대해 좀 작성해보려고 한다.
참고 논문
https://arxiv.org/abs/1708.02002
Focal Loss란 무엇인가?
Focal loss는 RetinaNet 논문에서 제안된 손실 함수로, Object Detection Model에서 발생하는 극심한 클래스 불균형(쉽게 맞히는 것 vs 쉽게 맞히지 못하는 소수 객체) 관련해 Easy negative가 전체 손실을 지배해 모델이 학습되지 않는 문제를 완화하기 위해 고안된 함수이다. 기존 Cross Entropy(CE)에서 잘 분류된 샘플의 억제해 hard example에 학습을 집중하는 것이 핵심 아이디어입니다. 또한, 과정은 다음과 같이 생각하면 됩니다.
- 순수 Cross Entropy : 하이퍼파라미터 없이 모든 샘플을 똑같이 학습.
- Cross Entropy + weight : 클래스별 고정 가중치로 희귀 클래스 손실 비중 ↑.
- Focal Loss : weight로 빈도 보정 + (1−p)γ(1-p)^{\gamma} 로 쉬운 샘플 손실 ↓, hard 샘플 집중
$$ \text{Cross-Entropy (CE)} : -\text{log}(p_t)$$
- 파라미터가 없으면 순수 Cross Entropy
- 파라미터를 제공하면 클래스별 고정 가중치(class-balancing)
$$ \text{ Focal Loss (FL)} : -\alpha_t(1-p_t)^{\gamma} \text{log}(p_t) $$
파라미터는 각각 조건은 다음과 같습니다.
- $\gamma ≥ 0 $ (focusing)
- $ \alpha_t ∈ (0,1] $ (class-balancing)
- $ \gamma = 0, \alpha_t = 1$인 경우는 CE와 동일합니다.
Focal Loss 하이퍼 파라미터의 역할
여기서 Focal loss의 하이퍼 파라미터가 값이 커지거나 작아지면서 어떤 역할을 하는지 알아보려고 합니다. 위에서 설명한데로 하이퍼파라미터는 2가지가 있고 하나씩 설명해보겠습니다.
$\gamma$의 역할
- $\gamma=0$이면 CE와 동일
- $\gamma$가 커질수록 잘 맞힌 샘플($p_t$ → 1)의 손실이 빠르게 0에 수렴 → 그 대신 어려운 샘플에 가중치 부여
$\alpha$의 역할
- $\alpha$ 클래스 불균형이 심할 때 희소 클래스에 더 큰 값을 부여해 균형 조정
논문에서는 $\alpha, \gamma$를 각각 0.25와 2를 기본으로 했습니다.
직관적으로 샘플 난이도에 기반해 동적 가중치를 부여해 손실 함수 내부에서 자연스럽게 수행합니다.
내 데이터에서는 어떻게 설정하면 좋을까?
기본적으로 데이터의 수가 각각 양성 데이터 300개와 음성 데이터 100개가 있다고 해봅시다.
- $\alpha$ : 데이터의 수에 대한 비율에 맞춰 0.3과 0.1을 제공하여 loss의 비율을 맞출 수 있습니다.
- $\gamma$ : 이것은 optuna와 같은 라이브러리를 이용해 가장 최적의 값을 찾아야 합니다.
코드로 하면 다음과 같습니다.
아래 코드는 양성 300 : 음성 100 데이터 분포를 감안해
- class-weight를 [1.0, 3.0] (음성이 3배 희귀)
- **class-wise γ(focusing)**를 [0.3, 0.1] (다수 클래스인 양성에 더 큰 γ 적용)
으로 설정한 세 가지 손실 함수를 보여 줍니다.
import torch
import torch.nn as nn
import torch.nn.functional as F
# 공통 설정 -------------------------------------------------------------
class_weights = torch.tensor([1.0, 3.0]) # [positive, negative]
gamma_vector = torch.tensor([0.3, 0.1]) # class-wise γ
# 1️⃣ 순수 Cross-Entropy -------------------------------------------------
def ce_loss(logits, targets):
"""모든 샘플을 동일 가중치로 학습"""
return F.cross_entropy(logits, targets)
# 2️⃣ Cross-Entropy + class-weight --------------------------------------
def ce_weighted_loss(logits, targets):
"""빈도 불균형 보정(음성 클래스 손실 ×3)"""
return F.cross_entropy(logits, targets, weight=class_weights)
# 3️⃣ Focal Loss (class-weight + class-wise γ) ---------------------------
class FocalLoss(nn.Module):
def __init__(self, alpha=class_weights, gamma=gamma_vector, reduction="mean"):
super().__init__()
self.alpha = alpha.float()
self.gamma = gamma.float()
self.reduction = reduction
def forward(self, logits, targets):
logp = F.log_softmax(logits, dim=-1) # (N, 2)
p = logp.exp()
y_one = F.one_hot(targets, 2).float() # (N, 2)
p_t = (p * y_one).sum(dim=1) # 정답 확률
alpha = self.alpha[targets] # 클래스별 α
gamma = self.gamma[targets] # 클래스별 γ
loss = -alpha * (1 - p_t).pow(gamma) * torch.log(p_t + 1e-12)
return loss.mean() if self.reduction == "mean" else loss.sum()
- 순수 CE : 불균형을 전혀 고려하지 않음.
- 가중 CE : 음성 손실을 3배 키워 빈도 불균형만 보정.
- Focal Loss : 위 가중치에 더해 γ 벡터로 다수 클래스(양성)의 easy 샘플 손실을 추가로 억제해 ‘희귀 & 어려운’ 음성 샘플에 학습을 집중
결론
- 데이터가 대체로 균형일 때
- 두 클래스(또는 여러 클래스) 빈도가 비슷해 모델이 어느 한쪽에 치우칠 염려가 거의 없다면, 순수 Cross Entropy가 가장 간단하고 안정적입니다. 파라미터를 따로 조정할 필요도 없고, 수렴도 빠르며 확률 해석도 직관적입니다.
- 불균형이 눈에 띄지만 hard / easy 난이도 차이가 크게 문제 되지 않을 때
- 예를 들어 한쪽 클래스가 5 ~ 10배 정도 적은 상황에서는 Cross Entropy + 클래스 가중치(weight)가 효과적입니다.
- 손실 계산 시 희귀 클래스에 고정 스케일을 곱해 주기만 하면 되므로 구현이 쉽고, 소수 클래스의 손실 비중이 늘어나 Recall·F1이 눈에 띄게 개선됩니다.
- 다만 여기서는 여전히 “쉽게 맞히는 샘플”과 “어려운 샘플”을 동일하게 다룬다는 한계가 있습니다.
- 극단적 불균형이면서 ‘easy negative’가 데이터의 대다수를 차지할 때
- 객체 탐지의 배경 vs 객체처럼 다수 클래스가 너무 쉽거나, 멀티라벨에서 드문 라벨이 거의 안 나오는 상황이라면 Focal Loss가 필요합니다.
- 이 손실은 클래스 가중치(α)로 빈도 편향을 먼저 보정한 뒤, 잘 맞힌(easy) 샘플의 손실을 기하급수적으로 줄여 하드 샘플(희귀 클래스 + 낮은 확신)에 학습을 집중시킵니다.
- γ 값을 조절해 “얼마나 집중할지”를 결정할 수 있어, 소수·어려운 샘플의 재현율을 극대화하고 다수 클래스 과적합을 방지하는 데 특히 강력합니다.