AI/ML Framework Integration with DataHub
Why Integrate Your AI/ML System with DataHub?
As a data practitioner, keeping track of your AI experiments, models, and their relationships can be challenging. DataHub makes this easier by providing a central place to organize and track your AI assets.
This guide will show you how to integrate your AI workflows with DataHub. With integrations for popular ML platforms like MLflow and Amazon SageMaker, DataHub enables you to easily find and share AI models across your organization, track how models evolve over time, and understand how training data connects to each model. Most importantly, it enables seamless collaboration on AI projects by making everything discoverable and connected.
Goals Of This Guide
In this guide, you'll learn how to:
- Create your basic AI components (models, experiments, runs)
- Connect these components to build a complete AI system
- Track relationships between models, data, and experiments
Core AI Concepts
Here's what you need to know about the key components in DataHub:
- Experiments are collections of training runs for the same project, like all attempts to build a churn predictor
- Training Runs are attempts to train a model within an experiment, capturing parameters and results
- Model Groups organize related models together, like all versions of your churn predictor
- Models are successful training runs registered for production use
The hierarchy works like this:
- Every run belongs to an experiment
- Successful runs can be registered as models
- Models belong to a model group
- Not every run becomes a model
Different AI platforms (MLflow, Amazon SageMaker) have their own terminology. To keep things consistent, we'll use DataHub's terms throughout this guide. Here's how DataHub's terminology maps to these platforms:
DataHub | Description | MLflow | SageMaker |
---|---|---|---|
ML Model Group | Collection of related models | Model | Model Group |
ML Model | Versioned artifact in a model group | Model Version | Model Version |
ML Training Run | Single training attempt | Run | Training Job |
ML Experiment | Collection of training runs | Experiment | Experiment |
For platform-specific details, see our integration guides for MLflow and Amazon SageMaker.
Basic Setup
To follow this tutorial, you'll need DataHub Quickstart deployed locally. For detailed steps, see the DataHub Quickstart Guide.
Next, set up the Python client for DataHub using the DatahubAIClient defined here.
Create an access token in the DataHub UI and replace <your_token> with your token:
from dh_ai_client import DatahubAIClient
client = DatahubAIClient(token="<your_token>", server_url="http://localhost:9002")
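The advanced examples in this guide also use DataHub's generated metadata classes, which ship with the DataHub Python SDK and are typically imported as models:

import datahub.metadata.schema_classes as models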
Throughout this guide, we'll show how to verify changes using GraphQL queries. You can run these queries in the DataHub UI at http://localhost:9002/api/graphiql.
Create Simple AI Assets
Let's create the basic building blocks of your ML system. These components will help you organize your AI work and make it discoverable by your team.
Create a Model Group
A model group contains different versions of a similar model. For example, all versions of your "Customer Churn Predictor" would go in one group.
Basic:
client.create_model_group(
    group_id="airline_forecast_models_group",
)
Advanced:
client.create_model_group(
    group_id="airline_forecast_models_group",
    properties=models.MLModelGroupPropertiesClass(
        name="Airline Forecast Models Group",
        description="Group of models for airline passenger forecasting",
        created=models.TimeStampClass(
            time=1628580000000, actor="urn:li:corpuser:datahub"
        ),
    ),
)
Let's verify that our model group was created, either in the DataHub UI or with the following GraphQL query:
query {
mlModelGroup(
urn:"urn:li:mlModelGroup:(urn:li:dataPlatform:mlflow,airline_forecast_models_group,PROD)"
) {
name
description
}
}
The response will show your model group's details:
{
"data": {
"mlModelGroup": {
"name": "airline_forecast_models_group",
"description": "Group of models for airline passenger forecasting"
}
}
}
Create a Model
Next, let's create a specific model version that represents a trained model ready for deployment.
Basic:
client.create_model(
    model_id="arima_model",
    version="1.0",
)
Advanced:
client.create_model(
    model_id="arima_model",
    properties=models.MLModelPropertiesClass(
        name="ARIMA Model",
        description="ARIMA model for airline passenger forecasting",
        customProperties={"team": "forecasting"},
        trainingMetrics=[
            models.MLMetricClass(name="accuracy", value="0.9"),
            models.MLMetricClass(name="precision", value="0.8"),
        ],
        hyperParams=[
            models.MLHyperParamClass(name="learning_rate", value="0.01"),
            models.MLHyperParamClass(name="batch_size", value="32"),
        ],
        externalUrl="https://localhost:5000",
        created=models.TimeStampClass(
            time=1628580000000, actor="urn:li:corpuser:datahub"
        ),
        lastModified=models.TimeStampClass(
            time=1628580000000, actor="urn:li:corpuser:datahub"
        ),
        tags=["forecasting", "arima"],
    ),
    version="1.0",
    alias="champion",
)
Let's verify our model in the DataHub UI or with the following GraphQL query:
query {
mlModel(
urn:"urn:li:mlModel:(urn:li:dataPlatform:mlflow,arima_model,PROD)"
) {
name
description
versionProperties {
version {
versionTag
}
}
}
}
The response will show your model's details:
{
"data": {
"mlModel": {
"name": "arima_model",
"description": "ARIMA model for airline passenger forecasting",
"versionProperties": {
"version": {
"versionTag": "1.0"
}
}
}
}
}
Create an Experiment
An experiment helps organize multiple training runs for a specific project.
Basic:
client.create_experiment(
    experiment_id="airline_forecast_experiment",
)
Advanced:
client.create_experiment(
    experiment_id="airline_forecast_experiment",
    properties=models.ContainerPropertiesClass(
        name="Airline Forecast Experiment",
        description="Experiment to forecast airline passenger numbers",
        customProperties={"team": "forecasting"},
        created=models.TimeStampClass(
            time=1628580000000, actor="urn:li:corpuser:datahub"
        ),
        lastModified=models.TimeStampClass(
            time=1628580000000, actor="urn:li:corpuser:datahub"
        ),
    ),
)
Verify your experiment in the DataHub UI or with the following GraphQL query:
query {
container(
urn:"urn:li:container:airline_forecast_experiment"
) {
name
description
properties {
customProperties
}
}
}
Check the response:
{
"data": {
"container": {
"name": "Airline Forecast Experiment",
"description": "Experiment to forecast airline passenger numbers",
"properties": {
"customProperties": {
"team": "forecasting"
}
}
}
}
}
Create a Training Run
A training run captures all details about a specific model training attempt.
Basic:
client.create_training_run(
    run_id="simple_training_run_4",
)
Advanced:
client.create_training_run(
    run_id="simple_training_run_4",
    properties=models.DataProcessInstancePropertiesClass(
        name="Simple Training Run 4",
        created=models.AuditStampClass(
            time=1628580000000, actor="urn:li:corpuser:datahub"
        ),
        customProperties={"team": "forecasting"},
    ),
    training_run_properties=models.MLTrainingRunPropertiesClass(
        id="simple_training_run_4",
        outputUrls=["s3://my-bucket/output"],
        trainingMetrics=[models.MLMetricClass(name="accuracy", value="0.9")],
        hyperParams=[models.MLHyperParamClass(name="learning_rate", value="0.01")],
        externalUrl="https://localhost:5000",
    ),
    run_result=RunResultType.FAILURE,
    start_timestamp=1628580000000,
    end_timestamp=1628580001000,
)
Verify your training run in the DataHub UI or with the following GraphQL query:
query {
dataProcessInstance(
urn:"urn:li:dataProcessInstance:simple_training_run_4"
) {
name
created {
time
}
properties {
customProperties
}
}
}
Check the response:
{
"data": {
"dataProcessInstance": {
"name": "Simple Training Run 4",
"created": {
"time": 1628580000000
},
"properties": {
"customProperties": {
"team": "forecasting"
}
}
}
}
}
Define Entity Relationships
Now let's connect these components to create a comprehensive ML system. These connections enable you to track model lineage, monitor model evolution, understand dependencies, and search effectively across your ML assets.
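The connection snippets below reference URN variables such as model_urn, model_group_urn, run_urn, and experiment_urn. Here is a minimal sketch of defining them, based on the URN formats shown in this guide's GraphQL queries; if your DatahubAIClient's create_* helpers return the URNs they create, prefer those return values instead:

# URNs for the entities created earlier in this guide.
model_group_urn = "urn:li:mlModelGroup:(urn:li:dataPlatform:mlflow,airline_forecast_models_group,PROD)"
model_urn = "urn:li:mlModel:(urn:li:dataPlatform:mlflow,arima_model,PROD)"
experiment_urn = "urn:li:container:airline_forecast_experiment"
run_urn = "urn:li:dataProcessInstance:simple_training_run_4"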
Add Model To Model Group
Connect your model to its group:
client.add_model_to_model_group(model_urn=model_urn, group_urn=model_group_urn)
In the UI, you can view the model versions on the Model Group page under the Models section, and find group information on the Model page under the Group tab. Via GraphQL:
query {
mlModel(
urn:"urn:li:mlModel:(urn:li:dataPlatform:mlflow,arima_model,PROD)"
) {
name
properties {
groups {
urn
properties {
name
}
}
}
}
}
Check the response:
{
"data": {
"mlModel": {
"name": "arima_model",
"properties": {
"groups": [
{
"urn": "urn:li:mlModelGroup:(urn:li:dataPlatform:mlflow,airline_forecast_model_group,PROD)",
"properties": {
"name": "Airline Forecast Model Group"
}
}
]
}
}
}
}
Add Run To Experiment
Connect a training run to its experiment:
client.add_run_to_experiment(run_urn=run_urn, experiment_urn=experiment_urn)
In the UI, find your runs on the Experiment page under the Entities tab, and see the experiment details on the Run page. Via GraphQL:
query {
dataProcessInstance(
urn:"urn:li:dataProcessInstance:simple_training_run"
) {
name
parentContainers {
containers {
urn
properties {
name
}
}
}
}
}
View the relationship details:
{
"data": {
"dataProcessInstance": {
"name": "Simple Training Run",
"parentContainers": {
"containers": [
{
"urn": "urn:li:container:airline_forecast_experiment",
"properties": {
"name": "Airline Forecast Experiment"
}
}
]
}
}
}
}
Add Run To Model
Connect a training run to its resulting model:
client.add_run_to_model(model_urn=model_urn, run_urn=run_urn)
This relationship enables you to:
- Track which runs produced each model
- Understand model provenance
- Debug model issues
- Monitor model evolution
In the UI, find the source run on the Model page under the Summary tab, and see related models on the Run page under the Lineage tab. Via GraphQL:
query {
mlModel(
urn:"urn:li:mlModel:(urn:li:dataPlatform:mlflow,arima_model,PROD)"
) {
name
properties {
mlModelLineageInfo {
trainingJobs
}
}
}
}
View the relationship:
{
"data": {
"mlModel": {
"name": "arima_model",
"properties": {
"mlModelLineageInfo": {
"trainingJobs": [
"urn:li:dataProcessInstance:simple_training_run_test"
]
}
}
}
}
}
Add Run To Model Group
Create a direct connection between a run and a model group:
client.add_run_to_model_group(model_group_urn=model_group_urn, run_urn=run_urn)
This connection lets you:
- View model groups in the run's lineage
- Query training jobs at the group level
- Track training history for model families
In the UI, see the model groups on the Run page under the Lineage tab. Via GraphQL:
query {
mlModelGroup(
urn:"urn:li:mlModelGroup:(urn:li:dataPlatform:mlflow,airline_forecast_model_group,PROD)"
) {
name
properties {
mlModelLineageInfo {
trainingJobs
}
}
}
}
Check the relationship:
{
"data": {
"mlModelGroup": {
"name": "airline_forecast_model_group",
"properties": {
"mlModelLineageInfo": {
"trainingJobs": [
"urn:li:dataProcessInstance:simple_training_run_test"
]
}
}
}
}
}
Add Dataset To Run
Track input and output datasets for your training runs:
client.add_input_datasets_to_run(
    run_urn=run_urn,
    dataset_urns=[str(input_dataset_urn)]
)
client.add_output_datasets_to_run(
    run_urn=run_urn,
    dataset_urns=[str(output_dataset_urn)]
)
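The input_dataset_urn and output_dataset_urn values above are assumed to refer to datasets that already exist in DataHub. Here is a minimal sketch of constructing such URNs with the SDK's URN builder, using hypothetical dataset names:

from datahub.emitter.mce_builder import make_dataset_urn

# Hypothetical example datasets; replace the platform and name with your own.
input_dataset_urn = make_dataset_urn(platform="snowflake", name="forecasting.passenger_history", env="PROD")
output_dataset_urn = make_dataset_urn(platform="s3", name="my-bucket/output/predictions", env="PROD")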
These connections help you:
- Track data lineage
- Understand data dependencies
- Ensure reproducibility
- Monitor data quality impacts
Find dataset relationships in the Lineage tab of either the Dataset or Run page:
Full Overview
Here's your complete ML system with all components connected:
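As a condensed recap, here are the connection calls from this guide in sequence, assuming the client and the URN variables defined above:

client.add_model_to_model_group(model_urn=model_urn, group_urn=model_group_urn)
client.add_run_to_experiment(run_urn=run_urn, experiment_urn=experiment_urn)
client.add_run_to_model(model_urn=model_urn, run_urn=run_urn)
client.add_run_to_model_group(model_group_urn=model_group_urn, run_urn=run_urn)
client.add_input_datasets_to_run(run_urn=run_urn, dataset_urns=[str(input_dataset_urn)])
client.add_output_datasets_to_run(run_urn=run_urn, dataset_urns=[str(output_dataset_urn)])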
You now have a complete lineage view of your ML assets, from training data through runs to production models!
You can check out the full code for this tutorial here.
What's Next?
To see these integrations in action:
- Watch our Townhall demo showcasing the MLflow integration
- Explore our detailed documentation: