Source code for feste.prompt
import copyreg
import warnings
from pathlib import Path
from typing import Any, Optional, Union
import iso639
from cloudpickle import dumps, loads
from jinja2 import Environment, meta
from feste.task import FesteBase, feste_task
def _unpickle_Language(serialized) -> iso639.Language:  # type: ignore
    """Unpickle iso639 class (to support cloudpickle).

    :param serialized: cloudpickle payload holding the language code that
        was stored by ``_pickle_Language``.
    :returns: the ``iso639.Language`` identified by that code.
    """
    code = loads(serialized)
    # Payloads written by earlier versions of this module carry the
    # two-letter ISO 639-1 code, current ones carry the three-letter
    # ISO 639-3 code. Dispatch on the length so both stay loadable.
    if len(code) == 2:
        return iso639.Language.from_part1(code)
    return iso639.Language.from_part3(code)


def _pickle_Language(cp: iso639.Language) -> Any:
    """Pickle iso639 class (to support cloudpickle).

    The ISO 639-3 code (``part3``) is stored instead of the ISO 639-1 code
    (``part1``): ISO 639-1 does not assign a two-letter code to every
    language, so a language without one could not be rebuilt from it.

    :param cp: the language object being pickled.
    :returns: a ``(callable, args)`` reduce tuple, as expected by
        ``copyreg.pickle``.
    """
    serialized = dumps(cp.part3)
    return _unpickle_Language, (serialized,)


# Register pickle/unpickle for iso639.Language
copyreg.pickle(iso639.Language, _pickle_Language)
# Internal Feste globals (name -> object), currently empty.
# TODO: add globals and utilities that are internal to Feste
FESTE_TEMPLATE_GLOBALS: dict[str, Any] = {}
class LanguageMismatch(UserWarning):
    """Warning category used when languages are mixed across prompts."""
class FesteEnvironment(Environment):
    """Default Jinja2 environment used by Feste.

    It is a regular Jinja2 ``Environment`` whose template globals are
    extended with the entries of ``FESTE_TEMPLATE_GLOBALS``.
    """

    def __init__(self, **kwargs) -> None:  # type: ignore[no-untyped-def]
        # Every keyword argument is forwarded untouched to Jinja2.
        super().__init__(**kwargs)
        self.add_feste_globals()

    def add_feste_globals(self) -> None:
        """Copy Feste's internal globals into this environment's globals."""
        for name, value in FESTE_TEMPLATE_GLOBALS.items():
            self.globals[name] = value
class Prompt(FesteBase):
    """Prompt utility. This class represents a prompt and its
    associated language and environment.

    :param template: the prompt template (in Jinja2 format)
    :param language: language code, defaults to en (follows ISO639)
    :param environment: optional environment, defaults to Feste's env.
    """

    def __init__(self, template: str,
                 language: Union[str, iso639.Language] = "en",
                 environment: Optional[Environment] = None) -> None:
        # Only fall back to the default environment when none was given,
        # rather than relying on the truthiness of the environment object.
        if environment is None:
            environment = FesteEnvironment()
        self.environment = environment
        # Strings are resolved to a Language object; Language instances
        # are stored as they are.
        if isinstance(language, str):
            self.language_code = iso639.Language.match(language)
        else:
            self.language_code = language
        self.template = template

    def __add__(self, other: "Prompt") -> "Prompt":
        """Concatenate two different prompts and check if languages
        match.

        A ``LanguageMismatch`` warning is emitted when the languages
        differ; the concatenation is still performed. The result keeps
        the language and environment of the left operand.

        :param other: other Prompt to concatenate
        :returns: a new Prompt holding both templates, in order.
        """
        # Let Python raise the standard TypeError (or try the reflected
        # operation) for operands that are not prompts.
        if not isinstance(other, Prompt):
            return NotImplemented
        # Check if prompt languages are the same
        if self.language != other.language:
            # stacklevel=2 attributes the warning to the code doing the
            # concatenation instead of to this method.
            warnings.warn("You are concatenating prompts with "
                          "different languages.", LanguageMismatch,
                          stacklevel=2)
        # Concatenate the templates
        concat_template = self.template + other.template
        return Prompt(concat_template, self.language, self.environment)

    @property
    def language(self) -> iso639.Language:
        """Returns the ISO639 language of the prompt."""
        return self.language_code

    @classmethod
    def from_file(cls, filename: Union[Path, str], *,
                  encoding: str = "utf-8", **kwargs):  # type: ignore
        """Loads the prompt from a text file.

        :param filename: the filename or Python's native Path object.
        :param encoding: text encoding used to decode the file. Defaults
            to UTF-8 so the result does not depend on the platform's
            locale encoding.
        :param kwargs: extra arguments being passed to the Prompt constructor.
        :returns: a new instance of ``cls`` built from the file content.
        """
        template = Path(filename).read_text(encoding=encoding)
        return cls(template, **kwargs)

    def variables(self) -> set[str]:
        """Return the set of undeclared variables present in the template.

        :returns: set of variables.
        """
        parsed_content = self.environment.parse(self.template)
        return meta.find_undeclared_variables(parsed_content)

    @feste_task
    def __call__(self, **kwargs) -> str:  # type: ignore
        """Render the template with the given variables.

        The method is wrapped by ``feste_task``; how that decorator
        schedules or defers the call is defined in ``feste.task``.

        :param kwargs: values for the template variables.
        """
        compiled_template = self.environment.from_string(self.template)
        return compiled_template.render(**kwargs)

    def __len__(self) -> int:
        """Length of the raw (unrendered) template string."""
        return len(self.template)

    def __repr__(self) -> str:
        return f"<Prompt Language='{self.language.name}' Length={len(self)}>"