top of page

[딥러닝] Gumbel Softmax

  • 작성자 사진: Minwu Kim
    Minwu Kim
  • 2024년 9월 12일
  • 2분 분량

최종 수정일: 2024년 9월 17일

하나. 문제의식

VAE 같은 생성형 모델에선 종종 샘플링 절차가 필요하다. 즉, 특정 분포에서 하나의 값을 랜덤으로 추출하는 것이다. 이 경우 문제가 하나 생기는데, 바로 backpropagation이 불가하다는 점이다. 그래서 이러한 discrete sampling을 differentiable한 방식으로 바꿔주는 reparametrization trick이 필요한다. 그 방법 중 하나가 바로 Gumbel Softmax이다.



둘. Gumbel Max

Gumbel Softmax를 알기 전해 Gumbel Max에 대해 짚고 넘어가야 한다. Gumbel Max는 이산분포 샘플링을 argmax 형식으로 바꾸는 reparameterization trick이다.

ree

위와 같은 이산분포를 따르는 random variable z가 있다고 해보자. 그리고 Gumbel distribution의 모양새도 위와 같다. 여기서 Gumbel distribution의 특성이 뭔지는 몰라도 무방하다. 그냥 일종의 확률분포이다.


ree

Gumbel-max trick은 위와 같다. 좌측의 해당 argmax 수식이 우측의 discrete sampling과 equivalent하다는 것이다. 이에 대해 일련의 수학적 증명이 있는데, 일단은 그냥 넘어가도록 한다.



셋. Gumbel Softmax

ree

Gumbel softmax는 Gumbel-max처럼 argmax를 취하는 대신 softmax를 취하는 방식이다. 이를 통해 discrete한 수식을 differentiable하게 바꾸는 것이다. 그 외에도 tau라는 변수가 있는데, 이는 하이퍼파라미터로서, 일종의 temperature와 같다고 보면 될 것 같다. Tau 값이 높을수록 분포가 smoothen 된다.



넷. Straight-through Gumbel Softmax

간단하다. Feedforward pass시는 Gumbel max를 사용해 샘플링하고, backpropagation시에는 gumbel softmax를 사용해 미분하는 것이다.


다섯. Why stochastic?

여기까지 보면서 든 생각은, 왜 그냥 softmax를 취하지 않느냐는 것이다. Gumbel softmax는 랜덤 샘플링의 reparameterization으로서, stochasticity를 품고있다. 하지만 만약 RAG처럼 그냥 greedy하게 top-k를 뽑는 경우라면, 이런 stochasticity가 필요 없는 것 아닌가 싶었다.

하지만 잘 생각해보면 문제가 있다. 바로 exploration이다. 네트워크 트레이닝 시 exploration 없이 greedy하게만 한다면 높은 확률로 saddle point에 갇히거나, 아예 학습이 불가능할 수도 있다. 고로 학습시 이러한 랜덤성은 꼭 포함해줘야 한다. 다만 학습 완료 후 추론을 할 때는 그냥 deterministic한 argmax를 취해도 될 듯 하다.


여섯. Retriever에 적용

첫번째 옵션:

  • Input: Concatenate the query vector and document vectors.

  • Neural Network: Trainable neural network processes these inputs to output relevance scores.

  • Gumbel-Softmax: Apply Gumbel-Softmax to make the sampling differentiable and allow for end-to-end optimization.

두번째 옵션:

  • Input: The query vector is fed into a cross-attention mechanism where it attends over all the document vectors. This allows the model to weigh different parts of the documents depending on the query.

  • Neural Network (Attention-Based): The attention mechanism would output attention scores for each document, which indicate the relevance of each document in relation to the query.

  • Gumbel-Softmax: Apply the Gumbel-Softmax trick to the attention scores to select documents stochastically.

 
 
 

댓글


bottom of page