AI/공부

[공부] Focal Loss란 무엇인가?

대학원생의 그저 그런 기록 2025. 6. 25. 14:38

연구를 하다가 데이터 불균형에 적용할 수 있는 방법들을 고민해보았다. 그러다 Focal Loss 방법을 알게되었으며, Focal Loss는 이미지 데이터를 기반으로 하는 것이기에 난 텍스트 데이터를 다루고 있어 주제가 조금 다를 수 있지만, 그래도 적용하는데에 큰 문제는 없기에 써보았다. 그러면서 어떤 효과가 있는지 등을 알고 싶었고 이에 대해 좀 작성해보려고 한다. 

참고 논문
https://arxiv.org/abs/1708.02002

 

Focal Loss란 무엇인가?


Focal loss는 RetinaNet 논문에서 제안된 손실 함수로, Object Detection Model에서 발생하는 극심한 클래스 불균형(쉽게 맞히는 것 vs 쉽게 맞히지 못하는 소수 객체) 관련해 Easy negative가 전체 손실을 지배해 모델이 학습되지 않는 문제를 완화하기 위해 고안된 함수이다. 기존 Cross Entropy(CE)에서 잘 분류된 샘플의 억제해 hard example에 학습을 집중하는 것이 핵심 아이디어입니다. 또한, 과정은 다음과 같이 생각하면 됩니다. 

  1. 순수 Cross Entropy : 하이퍼파라미터 없이 모든 샘플을 똑같이 학습.
  2. Cross Entropy + weight : 클래스별 고정 가중치로 희귀 클래스 손실 비중 ↑.
  3. Focal Loss : weight로 빈도 보정 + (1−p)γ(1-p)^{\gamma}로 쉬운 샘플 손실 ↓, hard 샘플 집중

 

$$  \text{Cross-Entropy (CE)} : -\text{log}(p_t)$$

  • 파라미터가 없으면 순수 Cross Entropy
  • 파라미터를 제공하면 클래스별 고정 가중치(class-balancing)

 

$$ \text{ Focal Loss (FL)} : -\alpha_t(1-p_t)^{\gamma} \text{log}(p_t) $$

파라미터는 각각 조건은 다음과 같습니다.

  • $\gamma ≥ 0 $ (focusing)
  • $ \alpha_t ∈ (0,1] $ (class-balancing)
  • $ \gamma = 0,  \alpha_t = 1$인 경우는 CE와 동일합니다.

 

Focal Loss 하이퍼 파라미터의 역할


여기서 Focal loss의 하이퍼 파라미터가 값이 커지거나 작아지면서 어떤 역할을 하는지 알아보려고 합니다. 위에서 설명한데로 하이퍼파라미터는 2가지가 있고 하나씩 설명해보겠습니다.

$\gamma$의 역할

  • $\gamma=0$이면 CE와 동일
  • $\gamma$가 커질수록 잘 맞힌 샘플($p_t$ → 1)의 손실이 빠르게 0에 수렴 → 그 대신 어려운 샘플에 가중치 부여

$\alpha$의 역할

  • $\alpha$ 클래스 불균형이 심할 때 희소 클래스에 더 큰 값을 부여해 균형 조정 

논문에서는 $\alpha, \gamma$를 각각 0.25와 2를 기본으로 했습니다.

직관적으로 샘플 난이도에 기반해 동적 가중치를 부여해 손실 함수 내부에서 자연스럽게 수행합니다.

 

 

내 데이터에서는 어떻게 설정하면 좋을까?


기본적으로 데이터의 수가 각각 양성 데이터 300개와 음성 데이터 100개가 있다고 해봅시다.

  • $\alpha$ : 데이터의 수에 대한 비율에 맞춰 0.3과 0.1을 제공하여 loss의 비율을 맞출 수 있습니다. 
  • $\gamma$ : 이것은 optuna와 같은 라이브러리를 이용해 가장 최적의 값을 찾아야 합니다.

코드로 하면 다음과 같습니다.

 

아래 코드는 양성 300 : 음성 100 데이터 분포를 감안해

  • class-weight를 [1.0, 3.0] (음성이 3배 희귀)
  • **class-wise γ(focusing)**를 [0.3, 0.1] (다수 클래스인 양성에 더 큰 γ 적용)

으로 설정한 세 가지 손실 함수를 보여 줍니다.

import torch
import torch.nn as nn
import torch.nn.functional as F

# 공통 설정 -------------------------------------------------------------
class_weights = torch.tensor([1.0, 3.0])   # [positive, negative]
gamma_vector  = torch.tensor([0.3, 0.1])   # class-wise γ

# 1️⃣ 순수 Cross-Entropy -------------------------------------------------
def ce_loss(logits, targets):
    """모든 샘플을 동일 가중치로 학습"""
    return F.cross_entropy(logits, targets)

# 2️⃣ Cross-Entropy + class-weight --------------------------------------
def ce_weighted_loss(logits, targets):
    """빈도 불균형 보정(음성 클래스 손실 ×3)"""
    return F.cross_entropy(logits, targets, weight=class_weights)

# 3️⃣ Focal Loss (class-weight + class-wise γ) ---------------------------
class FocalLoss(nn.Module):
    def __init__(self, alpha=class_weights, gamma=gamma_vector, reduction="mean"):
        super().__init__()
        self.alpha  = alpha.float()
        self.gamma  = gamma.float()
        self.reduction = reduction

    def forward(self, logits, targets):
        logp   = F.log_softmax(logits, dim=-1)          # (N, 2)
        p      = logp.exp()
        y_one  = F.one_hot(targets, 2).float()          # (N, 2)

        p_t    = (p * y_one).sum(dim=1)                 # 정답 확률
        alpha  = self.alpha[targets]                    # 클래스별 α
        gamma  = self.gamma[targets]                    # 클래스별 γ

        loss   = -alpha * (1 - p_t).pow(gamma) * torch.log(p_t + 1e-12)
        return loss.mean() if self.reduction == "mean" else loss.sum()

 

  • 순수 CE : 불균형을 전혀 고려하지 않음.
  • 가중 CE : 음성 손실을 3배 키워 빈도 불균형만 보정.
  • Focal Loss : 위 가중치에 더해 γ 벡터로 다수 클래스(양성)의 easy 샘플 손실을 추가로 억제해 ‘희귀 & 어려운’ 음성 샘플에 학습을 집중

 

결론


 

  • 데이터가 대체로 균형일 때
    • 두 클래스(또는 여러 클래스) 빈도가 비슷해 모델이 어느 한쪽에 치우칠 염려가 거의 없다면, 순수 Cross Entropy가 가장 간단하고 안정적입니다. 파라미터를 따로 조정할 필요도 없고, 수렴도 빠르며 확률 해석도 직관적입니다.
  • 불균형이 눈에 띄지만 hard / easy 난이도 차이가 크게 문제 되지 않을 때
    • 예를 들어 한쪽 클래스가 5 ~ 10배 정도 적은 상황에서는 Cross Entropy + 클래스 가중치(weight)가 효과적입니다.
    • 손실 계산 시 희귀 클래스에 고정 스케일을 곱해 주기만 하면 되므로 구현이 쉽고, 소수 클래스의 손실 비중이 늘어나 Recall·F1이 눈에 띄게 개선됩니다.
    • 다만 여기서는 여전히 “쉽게 맞히는 샘플”과 “어려운 샘플”을 동일하게 다룬다는 한계가 있습니다.
  • 극단적 불균형이면서 ‘easy negative’가 데이터의 대다수를 차지할 때
    • 객체 탐지의 배경 vs 객체처럼 다수 클래스가 너무 쉽거나, 멀티라벨에서 드문 라벨이 거의 안 나오는 상황이라면 Focal Loss가 필요합니다.
    • 이 손실은 클래스 가중치(α)로 빈도 편향을 먼저 보정한 뒤, 잘 맞힌(easy) 샘플의 손실을 기하급수적으로 줄여 하드 샘플(희귀 클래스 + 낮은 확신)에 학습을 집중시킵니다.
    • γ 값을 조절해 “얼마나 집중할지”를 결정할 수 있어, 소수·어려운 샘플의 재현율을 극대화하고 다수 클래스 과적합을 방지하는 데 특히 강력합니다.