Object detection
Step-by-step tutorial explains how to integrate custom object detection neural network into Supervisely platform
Last updated
Step-by-step tutorial explains how to integrate custom object detection neural network into Supervisely platform
Last updated
This tutorial will teach you how to integrate your custom object detection model into Supervisely by using ObjectDetectionTrainDashboad class.
The full code of the object detection sample app can be found here
Step 1. Prepare ~/supervisely.env
file with credentials. Learn more here.
Step 2. Clone repository with source code and create Virtual Environment.
git clone https://github.com/supervisely-ecosystem/object-detection-training-template
cd object-detection-training-template
./create_venv.sh
Step 3. Open the repository directory in Visual Studio Code.
code -r .
Step 4. Start debugging src/main.py
The integration of your own NN with ObjectDetectionTrainDashboad is really simple:
Step 1. Define pytorch dataset
import json
import random
from typing import Literal, List, Dict, Optional, Any, Tuple, Union
import numpy as np
import imgaug.augmenters as iaa
import supervisely as sly
from torch.utils.data import Dataset
###########################
### 1. Redefine your dataset and methods or use them as is
###########################
class CustomDataset(Dataset):
    """PyTorch dataset that yields one (image, class id, bbox) sample per item.

    Each item must expose ``img_path`` and ``ann_path`` attributes. The image
    and its Supervisely annotation are loaded from disk, optionally augmented,
    resized to ``input_size`` and reduced to a single randomly chosen label,
    because the sample model predicts only one bounding box per image.
    """

    def __init__(self,
                 items: List[Any],
                 project_meta: sly.ProjectMeta,
                 classes: List[str],
                 input_size: Union[List, Tuple],
                 transforms: Optional[iaa.Sequential] = None):
        """Store the dataset configuration.

        Args:
            items: Indexable collection of items; each item must provide
                ``img_path`` and ``ann_path``.
            project_meta: Project meta used to deserialize annotations.
            classes: Class names; the position of a name is its class id.
            input_size: Target size, read as ``(width, height)``.
            transforms: Optional imgaug pipeline applied before resizing.
        """
        self.items = items
        self.project_meta = project_meta
        self.classes = classes
        self.input_size = input_size
        self.transforms = transforms

    def __getitem__(self, index: int):
        """Load, augment and resize the item at ``index``.

        Returns:
            A tuple ``(image, class_id, bbox)`` where ``image`` has its last
            axis moved to the front, ``class_id`` is the index of the chosen
            label's class name in ``self.classes`` and ``bbox`` is a numpy
            array ``[left, top, right, bottom]``.

        Raises:
            IndexError: If the annotation contains no labels.
            ValueError: If the chosen label's class is not in ``self.classes``.
        """
        item = self.items[index]
        image = sly.image.read(item.img_path)
        with open(item.ann_path, 'r') as f:
            ann = sly.Annotation.from_json(json.load(f), self.project_meta)

        if self.transforms:
            _, image, ann = sly.imgaug_utils.apply(
                self.transforms, self.project_meta, image, ann
            )

        # The resize is always applied so that every sample in a batch has
        # the same spatial size, whether or not augmentations are enabled.
        resize = iaa.Sequential([
            iaa.Resize({"height": self.input_size[1], "width": self.input_size[0]})
        ])
        _, image, ann = sly.imgaug_utils.apply(resize, self.project_meta, image, ann)

        # our simple model can work only with 1 bbox per image
        # you need to change this code if your model is more complicated
        if not ann.labels:
            # Same exception type random.choice would raise, with a message
            # that points at the offending annotation file.
            raise IndexError(
                f"Annotation has no labels to sample a bbox from: {item.ann_path}"
            )
        label = random.choice(ann.labels)
        class_id = self.classes.index(label.obj_class.name)
        rect = label.geometry.to_bbox()
        bbox = np.array([rect.left, rect.top, rect.right, rect.bottom])

        # Move axis 2 to the front (equivalent to np.rollaxis(image, 2)).
        image = np.moveaxis(image, 2, 0)
        return image, class_id, bbox

    def __len__(self):
        """Return the number of items in the dataset."""
        return len(self.items)
Step 2. Define pytorch object detection model
from typing import Literal, List, Dict, Optional, Any, Tuple, Union
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
###########################
### 2. Define neural network model
###########################
class CustomModel(nn.Module):
    """ResNet-34 backbone with a classification head and a bbox head.

    The first eight children of ``torchvision.models.resnet34`` are used as
    the feature extractor. Two parallel heads consume the pooled 512-d
    feature vector: ``classifier`` produces class logits and ``bb`` produces
    four bounding-box values.
    """

    def __init__(self, num_classes: int = 4):
        """Build the network.

        Args:
            num_classes: Number of logits produced by the classification
                head. Defaults to 4, the value previously hard-coded.
        """
        super().__init__()
        resnet = models.resnet34()
        layers = list(resnet.children())[:8]
        self.features1 = nn.Sequential(*layers[:6])
        self.features2 = nn.Sequential(*layers[6:])
        self.classifier = nn.Sequential(nn.BatchNorm1d(512), nn.Linear(512, num_classes))
        self.bb = nn.Sequential(nn.BatchNorm1d(512), nn.Linear(512, 4))

    def forward(self, x):
        """Return ``(class_logits, bbox_values)`` for a batch of images."""
        x = self.features1(x)
        x = self.features2(x)
        x = F.relu(x)
        # Functional pooling avoids building a new module on every call.
        x = F.adaptive_avg_pool2d(x, (1, 1))
        x = x.view(x.shape[0], -1)
        return self.classifier(x), self.bb(x)
Step 3. Define subclass ObjectDetectionTrainDashboad
and implement train
method
import os
from typing import Literal, List, Dict, Optional, Any, Tuple, Union

import torch
import torch.nn.functional as F
import supervisely as sly
from supervisely.app.widgets import InputNumber
from torch.utils.data import DataLoader

import src.sly_globals as g
from src.dashboard import ObjectDetectionTrainDashboad
###########################
### 3. Define dashboard
###########################
class CustomTrainDashboard(ObjectDetectionTrainDashboad):
    """Training dashboard that trains ``self.model`` on the selected project data."""

    def train(self):
        """Run the training loop configured through the dashboard UI.

        Reads the selected classes, train/val splits, hyperparameters and
        augmentations from the UI, builds the datasets and loaders, then
        trains ``self.model``. Validation, checkpointing and logging happen
        at the intervals found in ``hparams.intervals``.
        """
        # getting selected classes from UI
        classes = self._classes_table.get_selected_classes()
        # getting configuration of splits from UI
        train_set, val_set = self.get_splits()
        # getting hyperparameters from UI
        hparams = self.get_hyperparameters()
        # a pytorch optimizer initialized with the selected hyperparameters
        optimizer = self.get_optimizer(hparams)
        device = hparams.general.device
        # extra hparam to scale the bbox loss
        C = hparams.general.C
        # getting selected augmentation from UI
        transforms = self.get_transforms()

        train_dataset = CustomDataset(
            train_set,
            project_meta=g.project_fs.meta,
            classes=classes,
            input_size=hparams.general.input_image_size,
            transforms=transforms,
        )
        val_dataset = CustomDataset(
            val_set,
            project_meta=g.project_fs.meta,
            classes=classes,
            input_size=hparams.general.input_image_size,
        )
        train_loader = DataLoader(
            train_dataset,
            batch_size=hparams.general.batch_size,
            shuffle=True,
            num_workers=hparams.general.workers_number,
        )
        val_loader = DataLoader(
            val_dataset,
            batch_size=hparams.general.batch_size,
            num_workers=hparams.general.workers_number,
        )

        # the path is falsy if pretrained model weights aren't selected in UI
        if self.pretrained_weights_path:
            # map_location lets a checkpoint saved on another device load here
            state_dict = torch.load(self.pretrained_weights_path, map_location=device)
            self.model.load_state_dict(state_dict)
            sly.logger.info('Model checkpoint successfully loaded.')

        # Use the model owned by the dashboard rather than a bare global name.
        self.model.to(device)

        # Defined up front so logging works even if validation never runs.
        val_loss = float("nan")
        val_accuracy = float("nan")

        with self.progress_bar(message="Training...", total=hparams.general.number_of_epochs) as pbar:
            # change training and eval loops for your model if needed
            for epoch in range(hparams.general.number_of_epochs):
                # Re-enter train mode each epoch: validation switches to eval.
                self.model.train()
                train_total_samples = 0
                train_sum_loss = 0
                train_correct = 0
                for images, targets, bboxes in train_loader:
                    batch_size = images.shape[0]
                    images = images.to(device).float()
                    targets = targets.to(device)
                    bboxes = bboxes.to(device).float()

                    pred_classes, pred_bboxes = self.model(images)
                    loss_class = F.cross_entropy(pred_classes, targets, reduction="sum")
                    loss_bb = F.l1_loss(pred_bboxes, bboxes, reduction="none").sum(1)
                    loss_bb = loss_bb.sum()
                    loss = loss_class + loss_bb / C

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    train_total_samples += batch_size
                    train_sum_loss += loss.item()
                    _, pred = torch.max(pred_classes, 1)
                    train_correct += pred.eq(targets).sum().item()

                train_loss = train_sum_loss / train_total_samples
                train_accuracy = train_correct / train_total_samples

                if hasattr(hparams.intervals, 'validation') and epoch % hparams.intervals.validation == 0:
                    self.model.eval()
                    val_total_samples = 0
                    val_sum_loss = 0
                    val_correct = 0
                    # No gradients are needed while evaluating.
                    with torch.no_grad():
                        for images, targets, bboxes in val_loader:
                            batch_size = images.shape[0]
                            images = images.to(device).float()
                            targets = targets.to(device)
                            bboxes = bboxes.to(device).float()

                            pred_classes, pred_bboxes = self.model(images)
                            loss_class = F.cross_entropy(pred_classes, targets, reduction="sum")
                            loss_bb = F.l1_loss(pred_bboxes, bboxes, reduction="none").sum(1)
                            loss_bb = loss_bb.sum()
                            loss = loss_class + loss_bb / C

                            val_sum_loss += loss.item()
                            val_total_samples += batch_size
                            _, pred = torch.max(pred_classes, 1)
                            val_correct += pred.eq(targets).sum().item()

                    val_loss = val_sum_loss / val_total_samples
                    val_accuracy = val_correct / val_total_samples

                # NOTE(review): the attribute name below must match the key the
                # dashboard declares character for character - confirm its spelling.
                if hasattr(hparams.intervals, 'сheckpoints'):
                    if epoch % hparams.intervals.сheckpoints == 0:
                        temp_checkpoint_path = os.path.join(g.checkpoints_dir, f'model_epoch_{epoch}.pth')
                        torch.save(self.model.state_dict(), temp_checkpoint_path)
                        self.log('add_text', tag='Main logs', text_string=f"Model saved at:\n{temp_checkpoint_path}")

                if epoch % hparams.intervals.logging == 0:
                    # self.log forwards the call to the dashboard and the loggers.
                    # To log with one specific logger only, call it by its class name:
                    # self.loggers.YOUR_LOGGER.add_scalar(tag='Loss/train', scalar_value=train_loss, global_step=epoch)
                    self.log('add_scalar', tag='Loss/train', scalar_value=train_loss, global_step=epoch)
                    self.log('add_scalar', tag='Loss/val', scalar_value=val_loss, global_step=epoch)
                    self.log('add_scalar', tag='Accuracy/train', scalar_value=train_accuracy, global_step=epoch)
                    self.log('add_scalar', tag='Accuracy/val', scalar_value=val_accuracy, global_step=epoch)
                    self.log('add_text', tag='Main logs', text_string=f"Epoch: {epoch}\t|\tTrain loss: {train_loss:.3f}\t|\tVal loss: {val_loss:.3f}\t|\tTrain accuracy: {train_accuracy:.3f}\t|\tVal accuracy: {val_accuracy:.3f}")
                pbar.update(1)
            pbar.set_description_str("Training has been successfully finished")
Step 4. Configure your dashboard using parameters and run the app. That's all. 😎
###########################
### 4. Run dashboard app
###########################
# the neural network defined in Step 2
model = CustomModel()

dashboard = CustomTrainDashboard(
    # REQUIRED
    # your neural network to train
    model=model,
    # These titles will be used for logging values while training as part of tags
    # self.log('add_scalar', tag='Loss/train', scalar_value=train_loss, global_step=epoch)
    plots_titles=['Loss', 'Accuracy'],
    # OPTIONAL
    # This row will add an additional hyperparam to the UI.
    # You will have access to this hyperparam inside CustomTrainDashboard as
    # hparams = self.get_hyperparameters()
    # C = hparams.general.C
    extra_hyperparams={
        'general': [
            dict(key='C',
                 title='Bbox loss scale',
                 description='Divide bbox_loss for this value',
                 content=InputNumber(1000, min=1, max=100000, size='small')),
        ],
    },
)
app = dashboard.run()
This section provides detailed information about the parameters used to initialize ObjectDetectionTrainDashboad and how to change them.
class ObjectDetectionTrainDashboad:
    def __init__(
        self,
        # ↓↓↓ required ↓↓↓
        model,
        plots_titles,
        # ↓↓↓ optional ↓↓↓
        pretrained_weights,            # Dict: table of pretrained model weights shown in the UI
        hyperparameters_categories,    # List: tab names in the hyperparameters UI
        extra_hyperparams,             # Dict: extra hyperparameters, keyed by tab name
        hyperparams_edit_mode,         # String: 'ui', 'raw' or 'all'
        show_augmentations_ui,         # Bool: show/hide flag for the augmentations card
        extra_augmentation_templates,  # List: extra templates for the augmentations selector
        download_batch_size,           # int: how much data to download per batch
        loggers,                       # List: additional user loggers
    ):
        ...
pretrained_weights: Dict
- it defines the table of pretrained model weights in the UI
If the provided path doesn't exist in the local filesystem at sly_globals.checkpoints_dir
, it will be downloaded from Team files.
You can read more about sly_globals
in the Additional notes section
Example
# Table shown in the "Pretrained weights" tab: header names plus one row per
# checkpoint. A row's path may be a local path, a Team files path or a URL.
pretrained_weights = dict(
    columns=['Name', 'Description', 'Path'],
    rows=[
        # local path
        ['Unet', 'Vanilla Unet', 'data/checkpoints/unet.pth'],
        # team files path
        ['Unet-11', 'VGG16', '/data/checkpoints/unet11.pth'],
        # url (in the future releases)
        ['Unet-16', 'VGG11', 'https://your_file_server/unet16.pth'],
    ],
)
The "Pretrained weights" tab will appear in the model settings card automatically.
hyperparameters_categories: List
- list of tabs names in hyperparameters UI.
Default: ['general', 'checkpoints', 'optimizer', 'intervals', 'scheduler']
These names also will be used as parent keys for hyperparams from corresponding tabs. You can add/delete tabs by this parameter in hyperparameters card.
Example
dashboard = CustomTrainDashboard(
...
hyperparameters_categories = ['general', 'intervals']
)
Before
After
extra_hyperparams: Dict
- extra hyperparameters, grouped by tab name. Each one is appended to the end of the hyperparameter list of its tab, and the tab name is used as the parent key.
Extra hyperparam structure
{
'any_tab_name': [
{
'any_key': str, # by key you will get access to widget value. hparams.any_tab_name.any_key
'title': str,
'description': str,
'content': sly.Widget # any widget from sly.app.widgets with "get_value" method
}
]
}
any_tab_name
should be unique string.
any_key
should be a unique string within its tab, but it can be repeated on another tab. See the example below.
content
works correctly only with sly.app.widgets
, which have get_value
method.
In other cases you have two options:
implement get_value
method for your widget
modify get_hyperparameters
method for support custom widgets
Example:
extra_hyperparams = {
    # adding "addition_hparam_1" and "addition_hparam_2" to the "general" tab
    'general': [
        dict(key='addition_hparam_1',
             title='Addition hyperparameter 1',
             description='Some description about this hyperparameter',
             content=InputNumber(1, min=1, max=1000, size='small')),
        dict(key='addition_hparam_2',
             title='Addition hyperparameter 2',
             description='Some description about this hyperparameter',
             content=InputNumber(6, min=2, max=10, step=2, size='small')),
    ],
    # the key "addition_hparam_1" is reused on the "checkpoints" tab;
    # keys only have to be unique within one tab
    'checkpoints': [
        dict(key='addition_hparam_1',
             title='Addition hyperparameter 1',
             description='Some description about this hyperparameter',
             content=InputNumber(0.0001, min=0.0001, step=0.0001, size='small')),
    ],
}
dashboard = CustomTrainDashboard(
...
extra_hyperparams = extra_hyperparams
)
hyperparams_edit_mode: String
- the ways to define hyperparameters.
Default: `'ui'`
Supported values: ['ui', 'raw', 'all'
]
ui
- only 🟢 section will be shown.
raw
- only 🔴 section will be shown.
all
- 🟢 + 🔴 sections will be shown together.
The hyperparams from UI will overwrite hyperparams with the same names from the text editor widget.
For example, if you declare hparam_1
with "general" as the parent key in extra_hyperparams or in hyperparameters_ui method
'general': [
dict(key='hparam_1',
title='Hyperparameter 1',
description='Some description',
content=InputNumber(100, min=100, max=1000, size='small')),
]
and declare the same in the text editor widget
general:
hparam_1: 0.1
then when you call get_hyperparameters
method, the hparam_1
value will be equal to 100
, not 0.1
.
show_augmentations_ui: Bool
- show/hide flag for augmentations card
Default: True
extra_augmentation_templates: List
- these augmentation templates will be added to the beginning of the selector list in the augmentations card:
You can create your own augmentations template .json
using ImgAug Studio app.
Example:
# Each template is a selector option: "label" is the title shown in the
# selector and "value" is the local path to the template file.
AUG_TEMPLATES = [
    {'label': title, 'value': template_path}
    for title, template_path in (
        ('My aug 1', 'aug_templates/light.json'),
        ('My aug 2', 'aug_templates/light_corrupt.json'),
        ('My aug 3', 'aug_templates/medium.json'),
    )
]
If you set hyperparams_edit_mode to raw
or all
, this additional widget will be shown.
download_batch_size: int
- How much data to download per batch. Increase this value to speed up downloads on big projects.
Default: 100
loggers: List
- additional user loggers
Example:
from torch.utils.tensorboard import SummaryWriter
class CSVWriter:
    """Skeleton of a custom logger that can be passed in the ``loggers`` list."""

    def __init__(self, log_dir):
        """Create the writer for the given log directory."""
        # your code
        pass

    def add_scalar(self, tag, scalar_value, global_step):
        """Record one scalar value; called as ``add_scalar(tag=..., scalar_value=..., global_step=...)``."""
        # your code
        pass
my_csv_logger = CSVWriter(g.csv_log_dir)
my_tensorboard = SummaryWriter(g.tensorboard_runs_dir)
loggers=[my_csv_logger, my_tensorboard]
You can log value for all loggers by calling common method.
All passed loggers should have the called method.
self.log(method='add_scalar', tag='Loss/train', scalar_value=train_loss, global_step=epoch)
If you want to log value for specific logger, then use self.loggers.YOUR_LOGGER_CLASS
self.loggers.SummaryWriter.add_scalar(tag='Loss/train', scalar_value=train_loss, global_step=epoch)
All you need to do is redefine the hyperparameters_ui
method in subclass of ObjectDetectionTrainDashboad
class CustomTrainDashboard(ObjectDetectionTrainDashboad):
    def hyperparameters_ui(self):
        """Describe the hyperparameter widgets, grouped by tab name.

        Returns:
            A dict mapping each tab name (also used as the parent key of its
            hyperparameters) to the list of widget descriptions for that tab.
        """
        hparams_widgets = {}
        # adding widgets to "my_hparam_tab" in UI. "my_hparam_tab" will be used as parent key.
        if 'my_hparam_tab' in self._hyperparameters_categories:
            hparams_widgets['my_hparam_tab'] = [
                # adding hparam
                dict(key='key_for_hparam',
                     title='Name for added hyperparameter',
                     description='Any description for added hyperparameter',
                     # any widget from sly.app.widgets with "get_value" method or
                     # you should redefine "get_hyperparameters" method to handle these widgets
                     content=InputNumber(10, min=1, max=100000, size='small')),
                # this line is required so that extra widgets passed through
                # extra_hyperparams are also added to "my_hparam_tab"
                *self._extra_hyperparams.get('my_hparam_tab', [])
            ]
        # Without this return the dict built above would be discarded.
        # NOTE(review): assumes the base class consumes the returned dict - confirm.
        return hparams_widgets
# Add new tab name to the list for displaying them in UI
hparams_tabs = ['my_hparam_tab']
dashboard = CustomTrainDashboard(
...
hyperparameters_categories = hparams_tabs
)
Environment variable SLY_APP_DATA_DIR
in src.sly_globals
is used to provide access to app files when the app will be finished. If something went wrong in your training process at any moment - you won't lose checkpoints and other important artifacts. They will be available by SFTP.
By default, the object detection training template app uses this directory structure from src/sly_globals
:
|object-detection-training-template
├─ project_dir # project training data destination folder
└─ data_dir # All training artefacts. This dir will be saved in Team files at `remote_data_dir` at the end of training process.
├─ checkpoints_dir # Model checkpoints will be saved here. This dir is included in `data_dir`.
└─ tensorboard_runs_dir # This dir will be created if tensorboard ResultsWriter was passed in loggers list
remote_data_dir
= f"/train_dashboard/{project.name}/runs/{time.strftime('%Y-%m-%d %H:%M:%S')}"
- the destination dir in Team files for all training artefacts.
The General tab
The Checkpoints tab