deep-ml is a high-level PyTorch training framework that simplifies deep learning workflows for computer vision tasks. It provides easy-to-use trainers with distributed training support, comprehensive task implementations, and seamless experiment tracking.
# Core package (required)
pip install deepml
# Optional integrations — install only the ones you need:
# For Lightning Fabric
pip install deepml lightning-fabric
# For HuggingFace Accelerate
pip install deepml accelerate
# For MLflow tracking
pip install deepml mlflow
# For Weights & Biases
pip install deepml wandb
# For Albumentations (segmentation)
pip install deepml albumentations
import torch
from torch.optim import Adam
from torchvision.models import resnet18

from deepml.fabric_trainer import FabricTrainer
from deepml.tasks import ImageClassification

# NOTE: `train_loader` and `val_loader` are data loaders you build yourself
# (e.g. torch.utils.data.DataLoader); they are not defined in this snippet.

# 1. Define your model (10 output classes)
model = resnet18(num_classes=10)

# 2. Create a task
task = ImageClassification(
    model=model,
    model_dir="./checkpoints",
    # Optional: human-readable class names, one per model output
    # (10 here, matching num_classes above). These are the CIFAR-10 classes.
    classes=[
        "airplane", "automobile", "bird", "cat", "deer",
        "dog", "frog", "horse", "ship", "truck",
    ],
)

# 3. Setup optimizer and loss
optimizer = Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.CrossEntropyLoss()

# 4. Create trainer
trainer = FabricTrainer(
    task=task,
    optimizer=optimizer,
    criterion=criterion,
    accelerator="auto",  # Use GPU if available
    devices="auto",  # Use all available devices
    precision="16-mixed",  # Mixed precision training
)

# 5. Train!
trainer.fit(
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=50,
)

# 6. Visualize predictions
task.show_predictions(loader=val_loader, samples=9)
import torch

from deepml.fabric_trainer import FabricTrainer
from deepml.losses import JaccardLoss  # another loss shipped with deep-ml; not used below
from deepml.tasks import Segmentation

# NOTE: `UNet` is a placeholder for your own segmentation network; it is not
# imported or defined in this snippet. `train_loader` / `val_loader` are your
# own data loaders.

# Define model (e.g., U-Net) with a single output channel for binary masks
model = UNet(in_channels=3, out_channels=1)

# Create a binary segmentation task
task = Segmentation(
    model=model,
    model_dir="./checkpoints",
    mode="binary",
    num_classes=1,
    threshold=0.5,
)

# Setup training.
# BCEWithLogitsLoss applies the sigmoid itself, so the model should output raw logits.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = torch.nn.BCEWithLogitsLoss()

trainer = FabricTrainer(task=task, optimizer=optimizer, criterion=criterion)

# Train
trainer.fit(
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=100,
)
📚 Full documentation is available at: https://deep-ml.readthedocs.io/
Getting Started
User Guide
API Reference
Additional Resources
📖 See the complete tutorials on ReadTheDocs.
# Multi-GPU training with the DDP (DistributedDataParallel) strategy
trainer = FabricTrainer(
    task=task,
    optimizer=optimizer,
    criterion=criterion,
    strategy="ddp",
    accelerator="gpu",
    devices="auto",  # Use all GPUs
)
# Simulate larger batch sizes via gradient accumulation
trainer.fit(
    epochs=50,
    train_loader=train_loader,
    val_loader=val_loader,
    gradient_accumulation_steps=4,  # Effective batch = 4x the loader's batch size
)
from deepml.lr_scheduler_utils import setup_one_cycle_lr_scheduler_with_warmup

# Keep this in sync with the `epochs` value you pass to trainer.fit(); the
# scheduler is sized from num_epochs * steps_per_epoch.
NUM_EPOCHS = 50


def lr_scheduler_fn(opt):
    """Build a one-cycle LR scheduler with warmup for optimizer ``opt``.

    The trainer receives this factory rather than a scheduler instance;
    presumably it calls it with the optimizer it actually trains with.
    """
    return setup_one_cycle_lr_scheduler_with_warmup(
        optimizer=opt,
        steps_per_epoch=len(train_loader),
        warmup_ratio=0.1,
        num_epochs=NUM_EPOCHS,
        max_lr=1e-3,
    )


trainer = FabricTrainer(
    task=task,
    optimizer=optimizer,
    criterion=criterion,
    lr_scheduler_fn=lr_scheduler_fn,
)
from deepml.tracking import MLFlowLogger, WandbLogger

# Pick ONE experiment logger and pass it to trainer.fit().

# Option A: MLflow
logger = MLFlowLogger(
    experiment_name='my-experiment',
    tracking_uri='./mlruns',
)

# Option B: Weights & Biases (use instead of Option A)
# logger = WandbLogger(
#     project='my-project',
#     name='experiment-1',
# )

trainer.fit(
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=50,
    logger=logger,
)
| Task | Description | Typical Use Cases |
|---|---|---|
| `ImageClassification` | Single-label classification | CIFAR-10, ImageNet |
| `MultiLabelImageClassification` | Multi-label classification | Object attributes |
| `Segmentation` | Pixel-level classification | Medical imaging, autonomous driving |
| `ImageRegression` | Continuous value prediction | Age estimation, depth prediction |
| `NeuralNetTask` | Generic task template | Custom tasks |
Contributions are welcome! See our Contributing Guide for details.
# Clone the repository and install it in editable mode with the "dev" extras
git clone https://github.com/sagar100rathod/deep-ml.git
cd deep-ml
pip install -e ".[dev]"
pytest # Run tests
This project is licensed under the MIT License - see the LICENSE file for details.
If you find this project useful, please consider giving it a star!
If you use deep-ml in your research, please cite:
% NOTE: the DOI below (zenodo.XXXXXXX) is a placeholder; replace it with the real DOI once minted.
@software{deepml2026,
author = {Rathod, Sagar},
title = {deep-ml: High-level PyTorch Training Framework for Computer Vision},
year = {2026},
version = {3.0.0},
url = {https://github.com/sagar100rathod/deep-ml},
doi = {10.5281/zenodo.XXXXXXX}
}