Source code for feste.backend.openai

from typing import Any, NamedTuple, Optional, Union

import openai

from feste.optimization import BatchOptimization, Optimization
from feste.task import FesteBase, feste_task


class CompleteParams(NamedTuple):
    """Parameters for the OpenAI Complete API.

    Every field is forwarded as a keyword argument to
    ``openai.Completion.create`` by ``OpenAI.complete`` /
    ``OpenAI.complete_batch``; fields whose value is ``None`` are dropped
    from the request by ``OpenAI._prepare_parameters``.

    Being a ``NamedTuple``, instances are immutable, which is what makes a
    shared ``CompleteParams()`` instance safe to use as a default argument.
    """
    model: str = "text-davinci-003"
    suffix: Optional[str] = None
    max_tokens: int = 16
    temperature: float = 1.0
    top_p: float = 1.0
    n: int = 1
    stream: bool = False
    logprobs: Optional[int] = None
    echo: bool = False
    # The Completions endpoint accepts either a single stop sequence or a
    # list of stop sequences, so both forms are allowed here.
    stop: Optional[Union[str, list[str]]] = None
    presence_penalty: float = 0.0
    frequency_penalty: float = 0.0
    best_of: int = 1
    logit_bias: Optional[dict[str, int]] = None
    user: Optional[str] = None
class OpenAI(FesteBase):
    """This is the OpenAI API main class.

    :param api_key: the OpenAI API key
    :param organization: optional organization
    """

    def __init__(self, api_key: str,
                 organization: Optional[str] = None) -> None:
        super().__init__()
        self.set_api_key(api_key, organization)
        # Keep our own copy: the openai module only holds process-wide
        # globals, which _api_key_guard re-applies before every request.
        self.api_key = api_key
        self.organization = organization

    def _api_key_guard(self) -> None:
        """Make sure the ``openai`` module is using this instance's credentials.

        The OpenAI Python client doesn't do proper encapsulation of API keys,
        see: https://github.com/openai/openai-python/issues/233. The key and
        organization live in module-level globals shared by everything in
        the process, and they may be unset in a fresh worker process.
        Checking only for ``None`` is not enough, because another ``OpenAI``
        instance (or any other code assigning ``openai.api_key``) may have
        left different credentials in place. They are therefore re-applied
        whenever they differ from the ones stored on this instance.
        """
        if (openai.api_key != self.api_key
                or openai.organization != self.organization):
            self.set_api_key(self.api_key, self.organization)

    @staticmethod
    def set_api_key(api_key: str, organization: Optional[str] = None) -> None:
        """Sets the API key and organization in the OpenAI module.

        These are module attributes of ``openai``, so the assignment affects
        every user of the module in the current process.

        :param api_key: the OpenAI API key
        :param organization: optional organization
        """
        openai.organization = organization
        openai.api_key = api_key

    @classmethod
    def optimizations(cls) -> list[Optimization]:
        """Optimizations implemented for OpenAI API.

        Registers ``complete_batch`` as the batched counterpart of
        ``complete`` (presumably so Feste can merge several ``complete``
        tasks into a single ``complete_batch`` request; see
        ``BatchOptimization``).
        """
        batch_optim = BatchOptimization({
            cls.complete._obj: cls.complete_batch._obj,
        })
        return [batch_optim]

    @staticmethod
    def _prepare_parameters(complete_params: CompleteParams) -> dict[str, Any]:
        """Turn ``complete_params`` into request kwargs, dropping ``None`` fields.

        :param complete_params: the API parameters
        :return: a new dict of field name -> value without ``None`` values
        """
        all_params = complete_params._asdict()
        return {k: v for k, v in all_params.items() if v is not None}

    def _create_completion(self, prompt: Union[str, list[str]],
                           complete_params: CompleteParams) -> Any:
        """Issue one ``openai.Completion.create`` call and return its response.

        Shared by ``complete`` and ``complete_batch`` so both build the
        request and guard the credentials the same way.

        :param prompt: a single prompt or a list of prompts
        :param complete_params: the API parameters
        :return: the raw response object from the OpenAI client
        """
        all_params = self._prepare_parameters(complete_params)
        all_params["prompt"] = prompt
        self._api_key_guard()
        return openai.Completion.create(**all_params)

    @feste_task
    def complete(self, prompt: str,
                 complete_params: CompleteParams = CompleteParams()) -> str:
        """This is the OpenAI official complete() API.

        :param prompt: input prompt text
        :param complete_params: the API parameters (e.g. temperature, etc)
        :return: the text of the first returned choice
        """
        ret = self._create_completion(prompt, complete_params)
        return str(ret.choices[0].text)

    @feste_task
    def complete_batch(self, prompt: list[str],
                       complete_params: CompleteParams = CompleteParams()) \
            -> list[str]:
        """This is the OpenAI official complete() API, but batched.

        :param prompt: input prompt text list
        :param complete_params: the API parameters (e.g. temperature, etc)
        :return: the text of every returned choice, ordered by the choice
            ``index`` so results line up with the order of ``prompt``
            (an empty ``prompt`` list yields an empty list without calling
            the API)
        """
        if not prompt:
            # Nothing to complete; skip the network round-trip.
            return []
        ret = self._create_completion(prompt, complete_params)
        # The API does not guarantee choices come back in prompt order; each
        # choice carries an ``index`` tying it to its prompt, so sort on it.
        # Callers such as the batch optimization rely on positional order.
        ordered = sorted(ret.choices, key=lambda choice: choice.index)
        return [str(choice.text) for choice in ordered]