Can only automatically infer lengths for datasets whose items are dictionaries with an '{self.model_input_name}' key.
Package:
transformers

Exception Class:
ValueError
Raise code
    ch_size = batch_size
    self.model_input_name = model_input_name if model_input_name is not None else "input_ids"
    if lengths is None:
        if (
            not (isinstance(dataset[0], dict) or isinstance(dataset[0], BatchEncoding))
            or self.model_input_name not in dataset[0]
        ):
            raise ValueError(
                "Can only automatically infer lengths for datasets whose items are dictionaries with an "
                f"'{self.model_input_name}' key."
            )
        lengths = [len(feature[self.model_input_name]) for feature in dataset]
    self.lengths = lengths
    self.generator = generator
def __len__(
Links to the raise (2)
https://github.com/huggingface/transformers/blob/bd9871657bb9500a9f4437a873db6df5f1ae6dbb/src/transformers/trainer_pt_utils.py#L523
https://github.com/huggingface/transformers/blob/bd9871657bb9500a9f4437a873db6df5f1ae6dbb/src/transformers/trainer_pt_utils.py#L588
Ways to fix
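LengthGroupedSampler is a sampler that samples indices in a way that groups together features of the dataset of roughly the same length, while keeping a bit of randomness. If you do not pass lengths, it tries to infer each item's length from the dataset itself, which only works when the items are dictionaries (or BatchEncoding objects) containing the model_input_name key ('input_ids' by default).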
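The grouping idea can be sketched in plain Python. This is a simplified, hypothetical re-implementation (the function name length_grouped_indices and the seed parameter are assumptions, and it uses Python's random module instead of torch), loosely following the approach of the get_length_grouped_indices helper in trainer_pt_utils.py: shuffle the indices, cut them into "megabatches", sort each megabatch by length, and move the longest item to the front. It is not the library's code.

```python
import random


def length_grouped_indices(lengths, batch_size, mega_batch_mult=None, seed=None):
    """Hypothetical, simplified sketch of length-grouped sampling (not the transformers implementation)."""
    if not lengths:
        return []
    if mega_batch_mult is None:
        # Megabatches of up to 50 batches, fewer for small datasets (at least 1)
        mega_batch_mult = min(len(lengths) // (batch_size * 4), 50) or 1
    rng = random.Random(seed)
    indices = list(range(len(lengths)))
    rng.shuffle(indices)  # the "bit of randomness"
    megabatch_size = mega_batch_mult * batch_size
    megabatches = [indices[i:i + megabatch_size] for i in range(0, len(indices), megabatch_size)]
    # Sort each megabatch longest-first, so consecutive batches hold items of similar length
    megabatches = [sorted(mb, key=lambda i: lengths[i], reverse=True) for mb in megabatches]
    # Swap the overall longest item into the very first position
    max_idx = max(range(len(megabatches)), key=lambda k: lengths[megabatches[k][0]])
    megabatches[0][0], megabatches[max_idx][0] = megabatches[max_idx][0], megabatches[0][0]
    return [i for mb in megabatches for i in mb]


lengths = [3, 10, 7, 1, 50, 2, 8, 4]
order = length_grouped_indices(lengths, batch_size=2, seed=0)
print(order)  # a permutation of range(8) whose first entry is 4, the index of the length-50 item
```

Putting the longest item first matches what the test code below checks for (the length-50 item ends up in first position).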
Error code:
import torch
from transformers.trainer_pt_utils import LengthGroupedSampler
lengths = torch.randint(0, 25, (100,)).tolist()
# Put one bigger than the others to check it ends up in first position
lengths[32] = 50
# The list is passed as the dataset and lengths is not given, so the sampler tries to infer it
indices = list(LengthGroupedSampler(lengths, 4))  # <--- raises ValueError
print(indices)
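Because lengths is not passed, it is None, so the sampler tries to infer the lengths from the dataset. Here the dataset is our lengths list (passed as the first positional argument), whose items are plain integers rather than dictionaries with an 'input_ids' key, so this check fails and the error is raised: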
if lengths is None:
    if (
        not (isinstance(dataset[0], dict) or isinstance(dataset[0], BatchEncoding))
        or self.model_input_name not in dataset[0]
    ):
        raise ValueError(
            "Can only automatically infer lengths for datasets whose items are dictionaries with an "
            f"'{self.model_input_name}' key."
        )
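To see exactly which inputs pass the check, here is a hypothetical stand-alone function, infer_lengths (the name is an assumption, not part of transformers), that mirrors it in plain Python. It accepts only plain dicts, whereas the real code also accepts BatchEncoding. It also illustrates the other way out: if every dataset item has an 'input_ids' key, the lengths can be inferred and no lengths argument is needed.

```python
def infer_lengths(dataset, model_input_name="input_ids"):
    """Hypothetical helper mirroring the sampler's check (plain dicts only, no BatchEncoding)."""
    first = dataset[0]
    if not isinstance(first, dict) or model_input_name not in first:
        raise ValueError(
            "Can only automatically infer lengths for datasets whose items are dictionaries with an "
            f"'{model_input_name}' key."
        )
    # One length per item, taken from its model_input_name field
    return [len(item[model_input_name]) for item in dataset]


# The failing case: the "dataset" is a list of ints
try:
    infer_lengths([3, 7, 1])
except ValueError as err:
    print(err)

# Items shaped like tokenizer output: lengths can be inferred
dataset = [{"input_ids": [101, 2023, 102]}, {"input_ids": [101, 102]}]
print(infer_lengths(dataset))  # [3, 2]
```

If your items use a different key, the model_input_name argument controls which key is looked up.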
Fix code:
import torch
from transformers.trainer_pt_utils import LengthGroupedSampler
lengths = torch.randint(0, 25, (100,)).tolist()
# Put one bigger than the others to check it ends up in first position
lengths[32] = 50
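# Pass the lengths explicitly, so nothing has to be inferred from the dataset items
# (later transformers releases take batch_size first; there the call is LengthGroupedSampler(4, lengths=lengths))
indices = list(LengthGroupedSampler(lengths, 4, lengths=lengths))
print(indices)
# The longest item (index 32, length 50) should come first
assert lengths[indices[0]] == 50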