# Source code for feste.backend.cohere

from concurrent.futures import Executor, Future
from typing import NamedTuple, Optional

import cohere

from feste.task import FesteBase, feste_task


class GenerateParams(NamedTuple):
    """Bundle of keyword arguments for a Cohere ``generate`` request.

    Instances are immutable tuples; use ``_replace`` to derive a variant.
    ``Cohere.generate`` expands the fields into keyword arguments via
    ``_asdict()``, so the field names here are exactly the names that
    reach the client call.
    """

    # NOTE: the two ``{}`` defaults (``prompt_vars`` and ``logit_bias``)
    # are evaluated once, when the class is defined.  Every instance that
    # does not override them shares the same dict object, so treat those
    # defaults as read-only.
    prompt_vars: object = {}

    # The only field with a non-empty default value.
    model: Optional[str] = "xlarge"

    # Everything from here to ``truncate`` stays None unless the caller
    # provides a value.
    preset: Optional[str] = None
    num_generations: Optional[int] = None
    max_tokens: Optional[int] = None
    temperature: Optional[float] = None
    k: Optional[int] = None
    p: Optional[float] = None
    frequency_penalty: Optional[float] = None
    presence_penalty: Optional[float] = None
    end_sequences: Optional[list[str]] = None
    stop_sequences: Optional[list[str]] = None
    return_likelihoods: Optional[str] = None
    truncate: Optional[str] = None

    logit_bias: dict[int, float] = {}
class DummyExecutor(Executor):
    """Executor that runs every submitted callable inline.

    No worker threads or processes are involved: ``submit`` invokes the
    callable directly in the calling thread and hands back a future that
    is already completed.
    """

    def submit(self, fn, *args, **kwargs) -> Future:  # type: ignore
        """Call ``fn(*args, **kwargs)`` right away and wrap the outcome.

        :param fn: the callable to run
        :param args: positional arguments forwarded to ``fn``
        :param kwargs: keyword arguments forwarded to ``fn``
        :return: a finished future holding either the return value of
            ``fn`` or the exception it raised
        """
        future: Future = Future()
        try:
            outcome = fn(*args, **kwargs)
        except BaseException as exc:
            # Nothing propagates to the caller of submit(); the failure is
            # stored on the future and surfaces from result()/exception().
            future.set_exception(exc)
            return future
        future.set_result(outcome)
        return future
class Cohere(FesteBase):
    """Entry point for using the Cohere API from Feste.

    .. note::
        The Cohere client normally dispatches its calls through an
        internal thread pool. Feste already parallelizes calls from
        outside the Cohere client, so that pool is swapped for a dummy
        executor that runs each call serially.

    :param api_key: the Cohere API key
    :param client_name: optional client name
    :param check_api_key: if API key should be checked (offline)
    :param max_retries: default number of retries
    """

    def __init__(self, api_key: str,
                 client_name: Optional[str] = None,
                 check_api_key: bool = True,
                 max_retries: int = 3) -> None:
        super().__init__()
        self.client = cohere.Client(
            api_key=api_key,
            client_name=client_name,
            check_api_key=check_api_key,
            max_retries=max_retries,
            num_workers=1,
        )
        # The client's own thread pool is redundant here because calls are
        # already parallelized from outside, so its executor is replaced
        # with a serial stand-in.
        self.client._executor = DummyExecutor()

    @feste_task
    def generate(self, prompt: str,
                 complete_params: GenerateParams = GenerateParams()) -> str:
        """Call the Cohere ``generate()`` API and return the generated text.

        Only the text of the first returned generation is handed back.

        :param prompt: input prompt text
        :param complete_params: the API parameters (e.g. temperature, etc)
        :return: text of the first generation, converted to ``str``
        """
        # Expand the parameter tuple into keyword arguments and add the
        # prompt as the final entry.
        request = {**complete_params._asdict(), "prompt": prompt}
        response = self.client.generate(**request)
        first_generation = response.generations[0]
        return str(first_generation.text)