Monday, September 11, 2023

Mocking async openai package calls with pytest

As part of my role on the Python advocacy team for Azure, I am now one of the maintainers of several ChatGPT samples, like my simple chat app and the very popular chat + search app. Both of those samples use Quart, the asynchronous version of Flask, which enables them to use the asynchronous versions of the functions from the openai package.

Making async openai calls

A synchronous call to the streaming ChatCompletion API looks like:

response = openai.ChatCompletion.create(
  messages=[{"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": request_message}],
  stream=True)

An asynchronous call to that same API looks like:

response = await openai.ChatCompletion.acreate(
  messages=[{"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": request_message}],
  stream=True)

The difference is just the addition of await to wait for the results of the asynchronous function (and to signal that the event loop can work on other tasks in the meantime), along with the change in method name from create to acreate. That's a small difference in our app code, but it's a significant difference when it comes to mocking those calls, so it's worth pointing out.
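
Consuming the streamed response also changes: the synchronous version is iterated with a plain for loop, while the asynchronous version must be iterated with async for. Here's a minimal sketch of that (the stream_answer helper name is my own, not from the samples):

async def stream_answer(response):
    # Each chunk is a ChatCompletion delta; not every delta carries content
    async for chunk in response:
        delta = chunk["choices"][0]["delta"]
        if "content" in delta:
            yield delta["content"]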

Mocking a streaming call

In our tests of the apps, we don't want to actually make calls to the OpenAI servers, since that'd require authentication and would use up quota needlessly. Instead, we can mock the calls using the built-in pytest fixture monkeypatch with code that mimics the openai package's response.

Here's the fixture that I use to mock the asynchronous acreate call:

import openai
import pytest


@pytest.fixture
def mock_openai_chatcompletion(monkeypatch):

    class AsyncChatCompletionIterator:
        def __init__(self, answer: str):
            self.answer_index = 0
            # Split the answer into word-sized chunks to mimic streamed deltas
            self.answer_deltas = answer.split(" ")

        def __aiter__(self):
            return self

        async def __anext__(self):
            if self.answer_index < len(self.answer_deltas):
                answer_chunk = self.answer_deltas[self.answer_index]
                self.answer_index += 1
                return openai.util.convert_to_openai_object(
                    {"choices": [{"delta": {"content": answer_chunk}}]})
            else:
                raise StopAsyncIteration

    async def mock_acreate(*args, **kwargs):
        return AsyncChatCompletionIterator("The capital of France is Paris.")

    monkeypatch.setattr(openai.ChatCompletion, "acreate", mock_acreate)

The final line of that fixture swaps the acreate method with my mock method, which returns an instance of a class that acts like an asynchronous iterator thanks to its __aiter__ and __anext__ dunder methods. The __anext__ method returns a chunk of the answer each time it's called, until there are no chunks left.
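
To see the fixture in action, here's a sketch of a test that requests a streamed response and collects the chunks (assuming pytest-asyncio is installed; the test name is illustrative):

@pytest.mark.asyncio
async def test_mock_streaming(mock_openai_chatcompletion):
    response = await openai.ChatCompletion.acreate(
        messages=[{"role": "user", "content": "What is the capital of France?"}],
        stream=True)
    chunks = [chunk["choices"][0]["delta"]["content"] async for chunk in response]
    assert chunks == ["The", "capital", "of", "France", "is", "Paris."]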

Mocking a non-streaming call

For the other repo, which supports both streaming and non-streaming responses, the mock acreate method must account for the non-streaming case by immediately returning the full answer.

    async def mock_acreate(*args, **kwargs):
        answer = "The capital of France is Paris."
        if kwargs.get("stream") is True:
            # Streaming: return the async iterator defined above
            return AsyncChatCompletionIterator(answer)
        else:
            # Non-streaming: return the full answer in a single response object
            return openai.util.convert_to_openai_object(
                {"choices": [{"message": {"content": answer}}]})

Mocking multiple answers

If necessary, it's possible to make the mock respond with different answers based on the last message passed in. We need that for the chat + search app, since we also use a ChatGPT call to generate keyword searches based on the user question.

Just change the answer based on the messages keyword arg (the rest of the mock stays the same):

    async def mock_acreate(*args, **kwargs):
        messages = kwargs["messages"]
        if messages[-1]["content"] == "Generate search query for: What is the capital of France?":
            answer = "capital of France"
        else:
            answer = "The capital of France is Paris."
        if kwargs.get("stream") is True:
            return AsyncChatCompletionIterator(answer)
        else:
            return openai.util.convert_to_openai_object(
                {"choices": [{"message": {"content": answer}}]})

Mocking other openai calls

We also make other calls through the openai package, like to create embeddings. That's a much simpler mock, since there's no streaming involved:

@pytest.fixture
def mock_openai_embedding(monkeypatch):
    async def mock_acreate(*args, **kwargs):
        # A plain dict is enough here, since the app only reads the embedding values
        return {"data": [{"embedding": [0.1, 0.2, 0.3]}]}

    monkeypatch.setattr(openai.Embedding, "acreate", mock_acreate)
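
A test using that fixture just checks that the fake vector comes back (the engine name here is only an example):

@pytest.mark.asyncio
async def test_mock_embedding(mock_openai_embedding):
    response = await openai.Embedding.acreate(
        engine="text-embedding-ada-002", input="What is the capital of France?")
    assert response["data"][0]["embedding"] == [0.1, 0.2, 0.3]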

More resources

For more context and example tests, view the full test suites in the two sample repos linked above.
