# Source listing metadata: 21 lines, 697 B, Python
import torch
|
|
|
|
from timm.data import Mixup
|
|
from timm.data.mixup import mixup_target
|
|
|
|
|
|
class TimmMixup(Mixup):
    """Wrapper around ``timm.data.Mixup`` that only enforces an even batch in 'pair' mode.

    Upstream ``Mixup.__call__`` asserts an even batch size up front, before
    dispatching on ``self.mode``. This subclass keeps that check only for the
    ``'pair'`` mode. ``'elem'`` mode and the fallback (batch) mode are
    dispatched without checking the batch size here.
    """

    def __call__(self, x, target):
        """Mix ``x`` according to ``self.mode`` and build the matching targets.

        Args:
            x: Input batch handed to the parent's ``_mix_*`` helper. It is
                returned as-is afterwards, so the helper presumably mixes it
                in place (TODO confirm against the installed timm version).
            target: Labels forwarded to ``timm.data.mixup.mixup_target``.

        Returns:
            Tuple ``(x, soft_target)``, where ``soft_target`` is what
            ``mixup_target`` produces for the mixing value(s) ``lam`` returned
            by the selected ``_mix_*`` helper.

        Raises:
            AssertionError: In ``'pair'`` mode when ``len(x)`` is odd. This
                applies only while assertions are enabled, i.e. not under
                ``python -O``.
        """
        mode = self.mode
        if mode == 'pair':
            # Only pair mode keeps the upstream even-batch requirement; it is
            # checked before any mixing happens.
            assert len(x) % 2 == 0, 'Batch size should be even when using this'
            lam = self._mix_pair(x)
        elif mode == 'elem':
            lam = self._mix_elem(x)
        else:
            # Every other mode value falls through to whole-batch mixing.
            lam = self._mix_batch(x)

        # NOTE(review): x.device is passed as the fifth positional argument.
        # This relies on a timm release whose mixup_target still accepts a
        # device parameter; confirm the pinned timm version.
        soft_target = mixup_target(
            target, self.num_classes, lam, self.label_smoothing, x.device
        )
        return x, soft_target
|