Maintaining state

In our examples so far, we have passed data from step to step using properties of custom events. This is a powerful way to pass data around, but it has limitations. For example, if you want to pass data between steps that are not directly connected, you need to pass the data through all the steps in between. This can make your code harder to read and maintain.
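To make the limitation concrete, here is a minimal plain-Python sketch with no LlamaIndex imports. The event classes and step functions are hypothetical stand-ins, modelled as dataclasses and ordinary functions. Only the first and last step care about `user_id`, yet the middle event has to carry it.

```python
from dataclasses import dataclass


# Hypothetical events, written as plain dataclasses for illustration only.
@dataclass
class FirstEvent:
    query: str
    user_id: str


@dataclass
class MiddleEvent:
    cleaned_query: str
    user_id: str  # relayed only so the last step can read it


def first_step(query: str, user_id: str) -> FirstEvent:
    return FirstEvent(query=query, user_id=user_id)


def middle_step(ev: FirstEvent) -> MiddleEvent:
    # This step only works on the query, yet it must forward user_id.
    return MiddleEvent(
        cleaned_query=ev.query.strip().lower(), user_id=ev.user_id
    )


def last_step(ev: MiddleEvent) -> str:
    return f"{ev.user_id}: {ev.cleaned_query}"


result = last_step(middle_step(first_step("  Hello World ", "u42")))
print(result)  # u42: hello world
```

Every additional step between the first and the last would need the same pass-through field, which is the maintenance cost described above.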

To avoid this pitfall, every step in the workflow has access to a Context object. To use it, declare an argument of type Context on your step. Here's how you do that.

We need one new import, the Context type:

from llama_index.core.workflow import (
    StartEvent,
    StopEvent,
    Workflow,
    step,
    Event,
    Context,
)

Now we define a start step that checks whether data has been loaded into the context. If not, it returns a SetupEvent, which triggers the setup step that loads the data and loops back to start.

class SetupEvent(Event):
    query: str


class StepTwoEvent(Event):
    query: str


class StatefulFlow(Workflow):
    @step
    async def start(
        self, ctx: Context, ev: StartEvent
    ) -> SetupEvent | StepTwoEvent:
        db = await ctx.store.get("some_database", default=None)
        if db is None:
            print("Need to load data")
            return SetupEvent(query=ev.query)

        # do something with the query
        return StepTwoEvent(query=ev.query)

    @step
    async def setup(self, ctx: Context, ev: SetupEvent) -> StartEvent:
        # load data
        await ctx.store.set("some_database", [1, 2, 3])
        return StartEvent(query=ev.query)

Then in step_two we can access the data directly from the context without having it passed explicitly. In generative AI applications, this is useful for loading indexes and for other large data operations.

    # step_two is a method of StatefulFlow; it continues the class body above
    @step
    async def step_two(self, ctx: Context, ev: StepTwoEvent) -> StopEvent:
        # do something with the data
        data = await ctx.store.get("some_database")
        print("Data is ", data)

        return StopEvent(result=data)


w = StatefulFlow(timeout=10, verbose=False)
result = await w.run(query="Some query")
print(result)
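The order in which the steps fire can be traced without the workflow runtime. The sketch below is a simulation under stated assumptions: `FakeStore` is a hypothetical in-memory stand-in that only mimics the `get(key, default=...)` and `set(key, value)` calls used above, and `run_flow` is a hand-written loop, not how LlamaIndex dispatches events.

```python
import asyncio
from typing import Any


class FakeStore:
    """Hypothetical in-memory stand-in for ctx.store (illustration only)."""

    def __init__(self) -> None:
        self._data: dict[str, Any] = {}

    async def get(self, key: str, default: Any = None) -> Any:
        return self._data.get(key, default)

    async def set(self, key: str, value: Any) -> None:
        self._data[key] = value


async def run_flow(store: FakeStore) -> tuple[list[str], Any]:
    trace: list[str] = []  # names of the steps, in the order they run
    while True:
        trace.append("start")
        db = await store.get("some_database", default=None)
        if db is None:
            # Equivalent of returning SetupEvent: load data, loop back.
            trace.append("setup")
            await store.set("some_database", [1, 2, 3])
            continue
        # Equivalent of returning StepTwoEvent: read the shared data.
        trace.append("step_two")
        return trace, await store.get("some_database")


trace, result = asyncio.run(run_flow(FakeStore()))
print(trace)   # ['start', 'setup', 'start', 'step_two']
print(result)  # [1, 2, 3]
```

The trace shows start running twice: once to discover the data is missing, and once more after setup has stored it.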

Adding Typed State

Often, you'll have some preset shape that you want to use as the state for your workflow. The best way to do this is to use a Pydantic model to define the state. This way, you:

  • Get type hints for your state
  • Get automatic validation of your state
  • (Optionally) Have full control over the serialization and deserialization of your state using validators and serializers

NOTE: You should use a Pydantic model that has defaults for all fields. This enables the Context object to automatically initialize the state with the defaults.
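The reason for this requirement can be seen with Pydantic alone. A model whose fields all have defaults can be constructed with no arguments, while a model with a required field cannot. The two model names below are made up for this illustration.

```python
from pydantic import BaseModel, Field, ValidationError


class StateWithDefaults(BaseModel):
    counter: int = 0
    history: list[str] = Field(default_factory=list)


class StateWithoutDefaults(BaseModel):
    counter: int  # required field: no default


# Works: every field falls back to its default.
state = StateWithDefaults()
print(state.counter, state.history)

# Fails: Pydantic cannot build the model without a value for `counter`.
try:
    StateWithoutDefaults()
    constructed = True
except ValidationError:
    constructed = False
print(constructed)
```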

Here's a quick example of how you can combine workflows and Pydantic to take advantage of all these features:

from pydantic import BaseModel, Field, field_validator, field_serializer
from typing import Union

from llama_index.core.workflow import (
    Context,
    Workflow,
    StartEvent,
    StopEvent,
    step,
)


# This is a random object that we want to use in our state
class MyRandomObject:
    def __init__(self, name: str = "default"):
        self.name = name


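# This is our state model
# NOTE: all fields must have defaults
class MyState(BaseModel):
    # MyRandomObject is not a Pydantic type, so arbitrary types must be allowed
    model_config = {"arbitrary_types_allowed": True}

    my_obj: MyRandomObject = Field(default_factory=MyRandomObject)
    some_key: str = Field(default="some_value")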

    # This is optional, but can be useful if you want to control the serialization of your state!

    @field_serializer("my_obj", when_used="always")
    def serialize_my_obj(self, my_obj: MyRandomObject) -> str:
        return my_obj.name

    @field_validator("my_obj", mode="before")
    @classmethod
    def deserialize_my_obj(
        cls, v: Union[str, MyRandomObject]
    ) -> MyRandomObject:
        if isinstance(v, MyRandomObject):
            return v
        if isinstance(v, str):
            return MyRandomObject(v)

        raise ValueError(f"Invalid type for my_obj: {type(v)}")


class MyStatefulFlow(Workflow):
    @step
    async def start(self, ctx: Context[MyState], ev: StartEvent) -> StopEvent:
        # Returns MyState directly
        state = await ctx.store.get_state()
        state.my_obj.name = "new_name"
        await ctx.store.set_state(state)

        # Can also access fields directly if needed
        name = await ctx.store.get("my_obj.name")
        await ctx.store.set("my_obj.name", "newer_name")

        return StopEvent(result="Done!")


w = MyStatefulFlow(timeout=10, verbose=False)

ctx = Context(w)
result = await w.run(ctx=ctx)
state = await ctx.store.get_state()
print(state)
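The serializer and validator on the state model can be exercised on their own, without running a workflow. The following is a small sketch using only Pydantic v2 calls (`model_dump` and `model_validate`). It repeats the model definition so that it runs standalone, and includes a `model_config` line so that Pydantic accepts the plain `MyRandomObject` class as a field type.

```python
from typing import Union

from pydantic import BaseModel, Field, field_serializer, field_validator


class MyRandomObject:
    def __init__(self, name: str = "default"):
        self.name = name


class MyState(BaseModel):
    # Lets Pydantic accept a plain (non-Pydantic) class as a field type.
    model_config = {"arbitrary_types_allowed": True}

    my_obj: MyRandomObject = Field(default_factory=MyRandomObject)
    some_key: str = Field(default="some_value")

    @field_serializer("my_obj", when_used="always")
    def serialize_my_obj(self, my_obj: MyRandomObject) -> str:
        return my_obj.name

    @field_validator("my_obj", mode="before")
    @classmethod
    def deserialize_my_obj(
        cls, v: Union[str, MyRandomObject]
    ) -> MyRandomObject:
        if isinstance(v, MyRandomObject):
            return v
        if isinstance(v, str):
            return MyRandomObject(v)
        raise ValueError(f"Invalid type for my_obj: {type(v)}")


state = MyState(my_obj=MyRandomObject("newer_name"))

# The serializer turns my_obj into its name string.
dumped = state.model_dump()
print(dumped)

# The "before" validator rebuilds a MyRandomObject from that string.
restored = MyState.model_validate(dumped)
print(restored.my_obj.name)
```

Because the dumped form contains only plain strings, the state can be stored or transmitted and then rebuilt into the same shape.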

Up next we'll learn how to stream events from an in-progress workflow.