
Break It, Fix It, Reverse It: Transactional ML Pipelines with Prefect 3.0

September 30, 2024
Jeff Hale
Head of Developer Education

In the world of data engineering, failure isn't just a possibility—it's an inevitability. As Werner Vogels, AWS CTO, famously said: "Everything fails all the time." Engineers who build and rely on data pipelines know this quote is all too true.

Imagine this scenario: you’re presenting a dashboard when suddenly, the numbers in your dashboard plummet by 50%. As you investigate, you discover an upstream data pipeline failed midway through its execution. Sound familiar?

This is where the power of transactional semantics in data pipelines becomes crucial. Resilient pipelines should maintain a clean, correct output state even when failures occur, ensuring data integrity and preserving trust with your stakeholders. Transactional capabilities allow your pipelines to gracefully roll back when partial operations fail, returning your workspace to a known good state instead of leaving you with corrupted data or half-baked models.

In this post, we'll explore how to leverage Prefect 3.0's new transaction semantics and rollback hooks, combined with DVC (Data Version Control), to build ML training workflows you can truly trust. You'll learn how to:

  • Implement automatic rollbacks that undo problematic changes
  • Maintain data and model version consistency even during failures
  • Enhance pipeline resilience without sacrificing development speed

By the end, you'll have the tools to create ML pipelines that don't just break—they break, fix themselves, and reverse to a safe state, all automatically. Let's dive in and see how Prefect 3.0 is revolutionizing the way we handle failure in data orchestration.

Transactional rollbacks and versioning with Prefect and DVC

Prefect: Orchestrating workflow with confidence

Prefect is an open source Python package that gives data and ML engineers the tools they need to build resilient workflows and orchestrate them all the way from script to scale.

Prefect recently released version 3.0 for general availability. Highlights include transactional semantics (the focus of this post), a new events backend powering event-driven workflows, automations for observability, improved runtime performance, and autonomous task execution.

Here, we’ll focus on an enhancement that can make your workflows more resilient to failure: transactions. Transactions help you recover from failure by returning your workspace to a desired point in the past. We’ll also show how transactional semantics can help you save time and money through caching, so that you don’t run workflows unnecessarily.

DVC: Version control for data and ML models

Data Version Control (DVC) complements Prefect as another open source Python package. DVC is especially helpful for versioning large datasets and machine learning models that don't play well with traditional version control systems like Git.

We’ll use Git and DVC to version our dataset, model, and model artifacts. DVC uses Git-like commands (add, commit, checkout, etc.) to help you keep track of everything.

Model overview

This example is adapted from DVC's Data and Model Versioning tutorial, which was based on a Keras blog post. We’ll train an image classification model to detect whether an image is of a cat or a dog. We’ll modify the script with Prefect for enhanced observability, resiliency, and orchestration. Importantly, we’ll add a check to see if our second run of model training reaches the threshold we need to use the new model. If it doesn’t make the cut, Prefect 3.0’s transaction semantics will allow us to easily roll back our model and data to a previous version.

We’ll download a few thousand images and train the model. With a fast connection and most computers, the tutorial script will only take a few minutes to run.

Setup and implementation

Step 1: clone the repo

Clone my repo from GitHub and move into it.

git clone https://github.com/discdiver/prefect-transactions-dvc.git
cd prefect-transactions-dvc

Step 2: virtual environment setup

We recommend using a Python virtual environment to run this example. To create a virtual environment with venv and install the required packages, run the following commands:

python -m venv .env
source .env/bin/activate
pip install -r requirements.txt

Note that I’m using Python 3.12 in this guide, but other recent versions should work fine, too.

Step 3: set up a Prefect API

This example assumes you have a Prefect 3.0 server instance running or have a Prefect Cloud account with your CLI authenticated. See the Quickstart instructions, if needed.

The original DVC tutorial contains a number of CLI commands. I’ve modified the code so that, once you’ve cloned the repo, the only command you need is the one that runs the train-prefect.py script.

Run a Python script with retries and rollback

Let’s run the script and inspect the code while it’s running.

Step 4: run the script

Navigate to the example-versioning folder and run:

python train-prefect.py

While the model trains, let’s examine a few components, starting with the Prefect imports in train-prefect.py:

from prefect import flow, task
from prefect.transactions import get_transaction, transaction

➡️ Prefect task decorator: The Prefect @task decorator denotes a discrete unit of work in a Prefect workflow. You can turn any Python function into a task by adding an @task decorator to it.

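import subprocess
from shlex import split


@task(retries=3)
def add_data(dataset_name: str = "data"):
    """Fetch and add data for model training"""
    # check=True raises CalledProcessError on a non-zero exit code,
    # so a failed download actually triggers Prefect's retries
    subprocess.run(
        split(
            f"dvc get https://github.com/iterative/dataset-registry "
            f"tutorials/versioning/{dataset_name}.zip"
        ),
        check=True,
    )
    ...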

➡️ Automatic retries: We’ve added retries=3 to the decorator of the first task, which downloads 1,000 images, so that Prefect retries the task up to three times if it fails to fetch the data. Retries are one way Prefect makes your code more resilient to intermittent failures.
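One detail worth knowing: Prefect retries a task when the task raises an exception, and subprocess.run does not raise on a non-zero exit status unless you pass check=True. The sketch below shows the difference with a hypothetical run_with_retries helper. It only mimics the retry idea and is not Prefect's implementation; the "failing download" is a Python one-liner that exits with status 1.

```python
import subprocess
import sys

# Stand-in for a failing download: a Python one-liner that exits with status 1
FAILING_CMD = [sys.executable, "-c", "import sys; sys.exit(1)"]
attempt_counts = {"no_check": 0, "check": 0}


def run_with_retries(fn, retries: int = 3):
    """Hypothetical helper: call fn once, then retry up to `retries` times on an exception.

    It only mimics the idea behind @task(retries=3); it is not Prefect code.
    """
    for attempt in range(retries + 1):
        try:
            return fn()
        except Exception:
            if attempt == retries:
                raise  # out of retries: surface the last error


def fetch_without_check():
    attempt_counts["no_check"] += 1
    # The non-zero exit status is only recorded in CompletedProcess.returncode
    return subprocess.run(FAILING_CMD)


def fetch_with_check():
    attempt_counts["check"] += 1
    # check=True turns a non-zero exit status into CalledProcessError
    return subprocess.run(FAILING_CMD, check=True)


result = run_with_retries(fetch_without_check)
print(result.returncode, attempt_counts["no_check"])  # failure went unnoticed, no retry

try:
    run_with_retries(fetch_with_check)
except subprocess.CalledProcessError:
    print("gave up after", attempt_counts["check"], "attempts")
```

Without check=True, the failed command counts as a success on the first attempt; with it, the error propagates and every retry is used. In the add_data task, passing check=True to subprocess.run is what lets a failed dvc get trigger Prefect's retries.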

The next task prepares the images, loads a pre-trained VGG16 model, and fine-tunes it using TensorFlow Keras.

@task(log_prints=True)
def train_model():
    """Train model for image classification"""
    ...

➡️ Printing logs: Specifying log_prints=True sends any print statements to Prefect’s logs, which is one way Prefect helps you observe the state of your workflows and recover from any issues. There’s a lot going on in this task, but the takeaway is that the trained model’s weights are saved and the training history is written to a file named metrics.csv. The task also returns the Keras training history, which we use later to check validation accuracy.

The next task uses Git and DVC to track the data, model, and related artifacts (these commands are the same as in the original DVC tutorial).

@task
def git_track(tag: str = "v1.0", img_count: int = 1000):
    """Track model changes in git"""
    subprocess.run(split("dvc add model.weights.h5"))
    subprocess.run(
        split("git add data.dvc model.weights.h5.dvc metrics.csv .gitignore")
    )
    subprocess.run(
        split(f"git commit -m '{tag} model, trained with {img_count} images'")
    )
    subprocess.run(split(f"git tag -a '{tag}' -m 'model {tag}, {img_count} images'"))
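The tasks run each command through shlex.split, so subprocess.run receives a list of arguments and no shell is involved. A quick check shows that the single-quoted commit and tag messages survive as single arguments. The commit_argv helper is hypothetical: it builds the same list directly, which avoids quoting problems if a message ever contains an apostrophe.

```python
from shlex import split


def commit_argv(tag: str, img_count: int) -> list[str]:
    """Hypothetical helper: build the git commit argv without any shell quoting."""
    return ["git", "commit", "-m", f"{tag} model, trained with {img_count} images"]


tag, img_count = "v2.0", 2000

# The single-quoted message becomes one argument after splitting
from_split = split(f"git commit -m '{tag} model, trained with {img_count} images'")
print(from_split)  # ['git', 'commit', '-m', 'v2.0 model, trained with 2000 images']
assert from_split == commit_argv(tag, img_count)

# Same idea for the annotated tag command
tag_argv = split(f"git tag -a '{tag}' -m 'model {tag}, {img_count} images'")
print(tag_argv)  # ['git', 'tag', '-a', 'v2.0', '-m', 'model v2.0, 2000 images']
```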

The next task checks the accuracy score of our model on our validation dataset. We raise an error if our latest model doesn’t meet our threshold of 95% accuracy.

@task
def check_model_val(history):
    """Check accuracy score of validation set"""
    if history.history["val_accuracy"][-1] < 0.95:
        raise ValueError(f"Validation accuracy is too low:{history.history['val_accuracy'][-1]}")
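Because the threshold logic is plain Python, you can exercise it without training a model. Here is a sketch of the same check as an undecorated function, assuming only that the history object exposes a history dict with a val_accuracy list, as the Keras History object does. SimpleNamespace stands in for that object, and check_val_accuracy and its threshold parameter are hypothetical names.

```python
from types import SimpleNamespace


def check_val_accuracy(history, threshold: float = 0.95) -> float:
    """Hypothetical undecorated version of check_model_val.

    Returns the final validation accuracy, or raises ValueError if it is below threshold.
    """
    final = history.history["val_accuracy"][-1]
    if final < threshold:
        raise ValueError(f"Validation accuracy is too low: {final}")
    return final


# Stand-ins for Keras History objects; only the .history dict is used
second_run = SimpleNamespace(history={"val_accuracy": [0.9012500047683716]})
good_run = SimpleNamespace(history={"val_accuracy": [0.93, 0.96]})

print(check_val_accuracy(good_run))  # 0.96

try:
    check_val_accuracy(second_run)
except ValueError as exc:
    print(exc)  # the second run misses the 0.95 threshold
```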

➡️ Rollback logic to previous state: The next function is where the magic happens. When it fires, it rolls the workspace back to the previous state. Note that it is decorated with @git_track.on_rollback, which registers it as a rollback hook for the git_track task. The hook fires automatically if an error is raised later in the transaction that git_track runs in.

@git_track.on_rollback
def rollback_workspace(txn):
    """Automatically roll back the workspace to the previous commit if model evaluation fails"""
    subprocess.run(split("git checkout HEAD~1"))  # move Git back one commit
    subprocess.run(split("dvc checkout"))  # sync DVC-tracked files with that commit
    tag = txn.get("tagging")  # value stored with txn.set() in the pipeline flow
    subprocess.run(split(f"git tag -d {tag}"))  # f-string, so the actual tag is deleted
    print(f"Rolling back workspace from {tag} to previous commit because validation accuracy was too low")

Inside the function, we move the workspace back to the previous Git commit (HEAD~1) and call dvc checkout, which restores the DVC-tracked data and model files to the versions recorded in that commit.

We fetch the value of our tag from the key we will set in the transaction.

Then we remove the git tag for v2.0 to keep things tidy.
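Make sure the git tag -d command is an f-string. Without the f prefix, Python passes the literal text {tag} to Git instead of the tag name, so the v2.0 tag would not be deleted. A quick check:

```python
from shlex import split

tag = "v2.0"

# Without the f prefix, the placeholder is passed through literally
plain = split("git tag -d {tag}")
print(plain)  # ['git', 'tag', '-d', '{tag}']

# With the f prefix, the tag value is interpolated
interpolated = split(f"git tag -d {tag}")
print(interpolated)  # ['git', 'tag', '-d', 'v2.0']
```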

➡️ Organize everything under a flow: Finally, we have an @flow decorated pipeline function. A flow is a container for workflow logic. This assembly function calls all the functions discussed above.

@flow
def pipeline(dataset_name: str, tag: str, img_count: int, initial_run: bool = False):
    """Pipeline for training model and checking validation accuracy"""
    add_data(dataset_name=dataset_name)
    history = train_model()
    with transaction() as txn:
        git_track(tag=tag, img_count=img_count)
        txn.set("tagging", tag)
        if not initial_run:
            check_model_val(history=history)

First we fetch the data and train the model.

Then we nest the git_track task call in the with transaction context block to add transactional capabilities.

The following line allows us to access the value of tag with the key tagging elsewhere in the transaction, even if an error was raised.

txn.set("tagging", tag)

We could set other key-value pairs for use within the transaction if desired.

Finally, the if statement is nested inside the transaction. On every run except the initial one (when initial_run is False), check_model_val is called, and a ValueError from it rolls back the transaction. We added this check because we only want to roll back after we have a model to roll back to.

With transactions, Prefect provides you with the flexibility to handle failure gracefully, running on_rollback hooks and setting and getting any values you need along the way.

In the if __name__ == "__main__" block, we call our pipeline flow function twice: once with the arguments for the initial dataset and once with the second dataset.

if __name__ == "__main__":
    dataset_info = [["data", "v1.0", 1000], ["new-labels", "v2.0", 2000]]
    pipeline(*dataset_info[0], initial_run=True)
    pipeline(*dataset_info[1])

Step 5: inspect results

As your script downloads the data and your model trains, you’ll notice artifacts and DVC files being added to your file system. You can watch the model training progress in the terminal.

The first metrics.csv file is interesting to examine. After the first run, you’ll see something like this:

epoch,accuracy,loss,val_accuracy,val_loss
...
9,0.9570000171661377,0.13074757158756256,0.90625,0.34719613194465637

Looks like my validation accuracy is 0.90625 in the final epoch. The results are not deterministic, so you’ll likely see something different when you train the model.

As the second run completes these values are replaced by the results of the second model. Then, faster than you can say “ml model magic” those get rolled back.

If you’re using VS Code and want to see the second model’s results, select metrics.csv and open the TIMELINE view at the bottom of the Explorer sidebar. Check out the diff.

epoch,accuracy,loss,val_accuracy,val_loss
...
9,0.9539999961853027,0.12576799094676971,0.9012500047683716,0.3672797381877899
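If you’d rather compare runs in code than in an editor, a few lines of the standard csv module will pull the final epoch’s val_accuracy out of a metrics.csv file. This sketch inlines the final-epoch rows of the two runs shown above (earlier epochs omitted) through io.StringIO instead of reading the real file, and final_val_accuracy is a hypothetical helper name.

```python
import csv
import io

HEADER = "epoch,accuracy,loss,val_accuracy,val_loss\n"
# Final-epoch rows from the two runs shown above (earlier epochs omitted)
RUN_1 = HEADER + "9,0.9570000171661377,0.13074757158756256,0.90625,0.34719613194465637\n"
RUN_2 = HEADER + "9,0.9539999961853027,0.12576799094676971,0.9012500047683716,0.3672797381877899\n"


def final_val_accuracy(csv_text: str) -> float:
    """Hypothetical helper: return val_accuracy from the last row of a metrics.csv."""
    rows = list(csv.DictReader(io.StringIO(csv_text)))
    return float(rows[-1]["val_accuracy"])


for name, text in [("run 1", RUN_1), ("run 2", RUN_2)]:
    acc = final_val_accuracy(text)
    verdict = "meets" if acc >= 0.95 else "misses"
    print(f"{name}: val_accuracy={acc:.4f} ({verdict} the 0.95 threshold)")
```

To read the real file instead, open it with open("metrics.csv", newline="") and pass the file object to csv.DictReader. Note that the first run also falls short of 0.95; it is only accepted because the pipeline skips the check when initial_run is True.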

We’ve got new results, but because they didn’t meet the threshold in our check_model_val task, we raised an error.

The error would normally just terminate our pipeline on the spot. What is done is done. No take backs. But with transactions, we can hit the undo button, figuratively.

Let’s look at the Git logs in the terminal with git log. You should see something like this at the top:

commit 0d3c061338cfa0ae94c2d1ab2b9ef485b4f2030a (HEAD)
Author: Jeff Hale <jeffmshale@gmail.com>
Date:   Mon Sep 9 15:43:57 2024 -0400

    v1.0 model, trained with 1000 images

Our rollback_workspace function moved us backward in our git history from the commit made in the second run of our git_track task. Further, our function code deleted our v2.0 tag.

If we take a peek in the UI, we’ll see the results of our two flow runs.

The most recent flow run failed, but the first one succeeded. Clicking on the failed run, we see that the rollback hook ran and finished successfully.

The nicely formatted logs make it easy to see what happened within our workflow.

Using rollbacks, we could even undo code changes in other training scripts if we were experimenting with different architectures. Also, we could use other Git or DVC commands to undo to other states, if desired.

Highlight: setting and getting key-value pairs within transactions

In our example we used the transaction's set method to save data for use if our transaction didn’t get committed. Then, we accessed the key-value pair that we set from within the rollback_workspace function. This was slick. In regular Python code we couldn’t have easily accessed the value once an error was raised. Our code would have immediately failed, without additional context that could be used by another function.

Bonus: Improve efficiency with result caching and transactions

Transactions are the backbone of Prefect’s task run and caching features. Transactions operate beneath the surface, and are not something you need to think about most of the time. However, to help build understanding, let’s look at the lifecycle of a transaction.

Transaction stages overview

Each transaction goes through at most four lifecycle stages.

BEGIN: the transaction’s key is computed and looked up. If a record already exists at the key’s location, the transaction considers itself committed. For a task, this is a cache hit: nothing runs.

STAGE: stages a piece of data to be committed to its result location. 

ROLLBACK: if the transaction encounters any error after staging, it rolls itself back and does not commit anything. Any on_rollback hooks fire.

COMMIT: the transaction writes its record to its configured location.

In the first run of our model training the transaction made it from BEGIN, to STAGE, to COMMIT. The flow run succeeded.

In the second run of our model training the transaction made it through BEGIN and STAGE and then ROLLBACK, but never reached COMMIT. The flow run failed and the on_rollback hook fired. If a task within the transaction had a return value, that result would never have been committed.
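To make the four stages concrete, here is a deliberately tiny, pure-Python model of them. It is not Prefect’s implementation: ToyTransaction, its in-memory store dict, its stage method, and its hook list are all hypothetical. It only mirrors the behavior described above: a record found at BEGIN means nothing runs, an error after STAGE fires the rollback hooks and commits nothing, and otherwise COMMIT writes the record.

```python
class ToyTransaction:
    """Hypothetical in-memory model of the BEGIN/STAGE/ROLLBACK/COMMIT lifecycle."""

    def __init__(self, key, store):
        self.key = key
        self.store = store        # stands in for a configured result location
        self.staged = None
        self.values = {}          # scratch space, like txn.set()/txn.get()
        self.rollback_hooks = []
        self.state = "PENDING"

    def set(self, name, value):
        self.values[name] = value

    def get(self, name, default=None):
        return self.values.get(name, default)

    def is_committed(self):
        return self.state == "COMMITTED"

    def __enter__(self):
        # BEGIN: look up the key; an existing record counts as already committed
        self.state = "COMMITTED" if self.key in self.store else "ACTIVE"
        return self

    def stage(self, value, on_rollback=None):
        # STAGE: hold the value (and an optional rollback hook) until commit
        self.staged = value
        if on_rollback is not None:
            self.rollback_hooks.append(on_rollback)

    def __exit__(self, exc_type, exc, tb):
        if self.state == "COMMITTED":
            return False          # record already existed: nothing to write or undo
        if exc_type is not None:
            # ROLLBACK: drop the staged value, fire hooks, let the error propagate
            self.staged = None
            for hook in self.rollback_hooks:
                hook(self)
            self.state = "ROLLED_BACK"
            return False
        # COMMIT: write the staged record to the store
        self.store[self.key] = self.staged
        self.state = "COMMITTED"
        return False


store, log = {}, []


def undo(txn):
    # Hypothetical rollback hook that reads a value saved earlier with set()
    log.append(f"rolled back {txn.get('tagging')}")


# First run: no record yet, so the work runs and commits
with ToyTransaction("model-v1", store) as t1:
    t1.set("tagging", "v1.0")
    t1.stage("weights-v1", on_rollback=undo)

# Second run: an error after STAGE triggers ROLLBACK, so nothing is committed
try:
    with ToyTransaction("model-v2", store) as t2:
        t2.set("tagging", "v2.0")
        t2.stage("weights-v2", on_rollback=undo)
        raise ValueError("Validation accuracy is too low")
except ValueError:
    pass

# Rerunning the first key: BEGIN finds the record, so the body skips the work
with ToyTransaction("model-v1", store) as t3:
    if not t3.is_committed():
        t3.stage("should not run")

print(store)  # {'model-v1': 'weights-v1'}
print(log)    # ['rolled back v2.0']
```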

These semantics are inspired by SQL database semantics and help engineers in their drive for idempotent, resilient operations.

Let’s look at one more place where Prefect’s transactions can facilitate improved workflow orchestration.

Caching to save on compute

If you want to store the return value of a task run and don’t want to have to rerun the task unless certain circumstances are met, you can use Prefect’s caching feature to save compute, time, and money. For example, you might only want to rerun a task if different parameter values are passed to the function.

You can accomplish that goal by passing INPUTS as the cache_policy argument to the @task decorator.

from prefect import task
from prefect.cache_policies import INPUTS

@task(cache_policy=INPUTS)
def my_cached_task(x: int):
    print('running...')
    return x + 42


my_cached_task(1)  # Task runs
my_cached_task(1)  # Task doesn't run, uses the cached result
my_cached_task(34) # Task runs

Prefect 3.0 provides other built-in cache policies, and you can create your own custom cache policies and cache key functions.

If you want a cache to expire after a certain amount of time, pass a datetime.timedelta to the cache_expiration keyword on the task decorator. The cached value stays valid for that duration.
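Conceptually, an INPUTS-style cache keys each result by a hash of the call’s arguments, and an expiration adds a time limit to each entry. The sketch below models that idea in plain Python with a hypothetical toy_cached decorator and an injectable clock so expiry can be shown instantly. It is not how Prefect computes or stores its cache keys; with Prefect itself you would use cache_policy=INPUTS and cache_expiration on @task.

```python
import hashlib
import json
import time
from datetime import timedelta
from functools import wraps
from typing import Callable, Optional


def toy_cached(expiration: Optional[timedelta] = None,
               clock: Callable[[], float] = time.monotonic):
    """Hypothetical decorator: reuse a result when the inputs hash to a known key."""

    def decorator(fn):
        cache = {}  # key -> (result, time stored)

        @wraps(fn)
        def wrapper(*args, **kwargs):
            # Input-based key: identical arguments produce an identical key
            payload = json.dumps([args, kwargs], sort_keys=True, default=repr)
            key = hashlib.sha256(payload.encode()).hexdigest()
            now = clock()
            if key in cache:
                result, stored_at = cache[key]
                if expiration is None or now - stored_at < expiration.total_seconds():
                    return result  # cache hit: the function body does not run
            result = fn(*args, **kwargs)  # miss or expired entry: run and store
            cache[key] = (result, now)
            return result

        return wrapper

    return decorator


fake_now = [0.0]  # controllable clock so expiry can be demonstrated instantly
calls = []


@toy_cached(expiration=timedelta(minutes=10), clock=lambda: fake_now[0])
def my_cached_task(x: int):
    calls.append(x)
    return x + 42


my_cached_task(1)    # runs
my_cached_task(1)    # cache hit
my_cached_task(34)   # runs: different input
fake_now[0] = 601.0  # just over 10 minutes later
my_cached_task(1)    # runs again: the entry expired
print(calls)         # [1, 34, 1]
```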

Because Prefect 3.0 task caching is built on transactions, you can group tasks that should only rerun as a group by calling them in a with transaction block. This way either all tasks in the block run together or not at all.

See the task caching docs for more details.

Conclusion

You’ve seen how DVC and Prefect 3.0’s transaction capabilities allow you to easily roll back a pipeline to a previous state if a model evaluation fails. Prefect’s transaction semantics give you the flexibility to automatically take healing actions when things fail, and when it comes to data and ML pipelines, that’s the resiliency you need.

Connect with one of our engineers for a more detailed demo of how Prefect can help you scale efficiently and build trust: book a demo here.