pytorch에서 sampler 이용하여 data distirubtion 조절

가끔 딥러닝을 돌리다보면 label에 관한 편향된 distribution으로 인해 classification이 편향되어 나오는 걸 확인할 수 있다.

이를 예방하기 위해선 batch를 뽑을 때 distribution을 고려해서 뽑아줘야 한다.

이를 위해선 sampler를 사용해주면 되고 다음과 같이 이용가능하다.


1
2
3
sampler = WeightedRandomSampler(weights, 1, replacement=True# << create the sampler
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=False, sampler=sampler) 
cs

만약 n-gram 형태의 language model을 사용해서 text classification을 진행한다고 하면, label이 text length에 영향을 받는 경우가 있다. 그 경우 다음과 같이 dataset을 만들어주고 weight list를 만들어주면 된다.

주의점은 shuffle이 False여야 한다는 점.,


1
2
3
4
5
6
7
8
9
10
def create_dataset(tok_docs, vocab, n):
  n_grams = []
  document_ids = []
  weights = []  # << list of weights for sampling
  for i, doc in enumerate(tok_docs):
    for n_gram in [doc[0][i:i+n] for i in range(len(doc[0]) - 1)]:
       n_grams.append(n_gram)
       document_ids.append(i)
       weights.append(1/len(doc[0]))  # << ngrams of long documents are sampled less often
  return n_grams, document_ids, weights
cs

댓글

가장 많이 본 글