diff --git a/lmms_eval/api/samplers.py b/lmms_eval/api/samplers.py index f77065e8..b257c7db 100755 --- a/lmms_eval/api/samplers.py +++ b/lmms_eval/api/samplers.py @@ -37,7 +37,9 @@ def get_context(self, doc, num_fewshot): + ( str(self.doc_to_target(doc)[0]) if type(self.doc_to_target(doc)) is list - else self.doc_to_target(doc) if (self.config.doc_to_choice is None or type(self.doc_to_target(doc)) is str) else str(self.doc_to_choice(doc)[self.doc_to_target(doc)]) + else self.doc_to_target(doc) + if (self.config.doc_to_choice is None or type(self.doc_to_target(doc)) is str) + else str(self.doc_to_choice(doc)[self.doc_to_target(doc)]) ) for doc in selected_docs ] @@ -91,4 +93,4 @@ def get_sampler(name): try: return SAMPLER_REGISTRY[name] except KeyError: - raise ValueError(f"Attempted to use contextsampler '{name}', but no sampling strategy for this name found! Supported model names: {', '.join(SAMPLER_REGISTRY.keys())}") + raise ValueError(f"Attempted to use contextsampler '{name}', but no sampling strategy for this name found! Supported model names: {', '.join(SAMPLER_REGISTRY.keys())}") \ No newline at end of file