Experimental Detection API

obia.detection is experimental. Its interfaces and workflow may change in future releases.

obia.detection.dataset

TreeDetectionDataset

Bases: Dataset

Represents a dataset for tree detection tasks.

This class handles loading, preprocessing, and transforming tree detection datasets. Images and annotations are loaded and preprocessed for deep learning models. It supports geometric and color augmentations if transforms are provided, and optional scaling of pixel values.

Attributes:

    images_dir (str): Path to the directory containing image files.
    annotations (dict): Parsed annotations for the dataset, loaded from the JSON file.
    image_ids (list): List of image IDs corresponding to the keys in the annotations.
    transforms (callable, optional): A callable for data augmentation and transformations.
        It must support the image, bboxes, and labels keys for input and output.
    do_scale (bool): Whether to scale image pixel values to the range 0-255.

Source code in obia/detection/dataset.py
class TreeDetectionDataset(Dataset):
    """
    Represents a dataset for tree detection tasks.

    This class handles loading, preprocessing, and transforming tree detection
    datasets. Images and annotations are loaded and preprocessed for deep learning
    models. It supports geometric and color augmentations if transforms are provided,
    and optional scaling of pixel values.

    :ivar images_dir: Path to the directory containing image files.
    :type images_dir: str
    :ivar annotations: Parsed annotations for the dataset, loaded from the JSON file.
    :type annotations: dict
    :ivar image_ids: List of image IDs corresponding to the keys in the annotations.
    :type image_ids: list
    :ivar transforms: A callable for data augmentation and transformations. It must support
        the `image`, `bboxes`, and `labels` keys for input and output.
    :type transforms: callable, optional
    :ivar do_scale: Whether to scale image pixel values to the range 0-255.
    :type do_scale: bool
    """
    def __init__(self, images_dir, annotations_path, transforms=None, do_scale=True):
        self.images_dir = images_dir
        self.transforms = transforms
        self.do_scale = do_scale

        with open(annotations_path, "r") as f:
            self.annotations = json.load(f)
        self.image_ids = list(self.annotations.keys())

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        image_id = self.image_ids[idx]
        ann = self.annotations[image_id]
        image_path = os.path.join(self.images_dir, ann["file_name"])

        with rasterio.open(image_path) as src:
            image_array = src.read()

        image_array = np.transpose(image_array, (1, 2, 0))

        if self.do_scale:
            data_min = image_array.min()
            data_max = image_array.max()
            if data_max > data_min:
                image_array = 255.0 * (image_array - data_min) / (data_max - data_min + 1e-8)
            image_array = np.clip(image_array, 0, 255).astype(np.uint8)

        boxes = ann["boxes"]
        labels = ann["labels"]

        if self.transforms is not None:
            augmented = self.transforms(
                image=image_array,
                bboxes=boxes,
                labels=labels
            )
            image_array = augmented["image"]
            boxes = augmented["bboxes"]
            labels = augmented["labels"]

        image_tensor = torch.tensor(image_array, dtype=torch.float32).permute(2, 0, 1)
        boxes_tensor = torch.tensor(boxes, dtype=torch.float32)
        labels_tensor = torch.tensor(labels, dtype=torch.int64)

        target = {"boxes": boxes_tensor, "labels": labels_tensor}
        return image_tensor, target
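
The annotations JSON layout is not documented here, but it can be inferred from __getitem__: top-level keys are image IDs, "file_name" is resolved relative to images_dir, "boxes" are [x_min, y_min, x_max, y_max] pixel coordinates (pascal_voc, matching get_transforms), and "labels" are integer class IDs. A minimal usage sketch under those assumptions, with hypothetical paths:

from obia.detection.dataset import TreeDetectionDataset
from obia.detection.utils import get_transforms

# annotations.json (hypothetical), inferred layout:
# {
#     "img_001": {
#         "file_name": "img_001.tif",
#         "boxes": [[34, 50, 120, 140], [200, 210, 260, 290]],
#         "labels": [1, 1]
#     }
# }
dataset = TreeDetectionDataset(
    images_dir="tiles/",
    annotations_path="annotations.json",
    transforms=get_transforms(train=True),
    do_scale=True,
)
image_tensor, target = dataset[0]
print(image_tensor.shape)     # (C, H, W) float32 tensor
print(target["boxes"].shape)  # (N, 4) float32 tensor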

obia.detection.models

RetinaNet-based Detection Model (Modified for N-Channel Input)

Allows specifying 'in_channels' for multi-band data. The pretrained backbone expects 3 channels, so the first conv layer is replaced to match in_channels, copying the pretrained weights for up to the first three channels (all three when in_channels >= 3, fewer otherwise).

build_detection_model(num_classes=2, in_channels=3)

Builds a RetinaNet model with optional adjustment for N-band input.

Parameters:

    num_classes (int): Number of classes (including background if you prefer). Default: 2
    in_channels (int): Number of input channels (e.g., 4 for RGB+CHM). Default: 3

Returns:

    model (nn.Module): The modified RetinaNet model.

Source code in obia/detection/models.py
def build_detection_model(num_classes=2, in_channels=3):
    """
    Builds a RetinaNet model with optional adjustment for N-band input.

    Args:
        num_classes (int): Number of classes (including background if you prefer).
        in_channels (int): Number of input channels (e.g., 4 for RGB+CHM).

    Returns:
        model (nn.Module): The modified RetinaNet model.
    """
    model = retinanet_resnet50_fpn(weights=None)
    # Disable ImageNet normalization; inputs are expected in the 0-255 range.
    # The mean/std lists must match the number of input channels.
    model.transform.image_mean = [0.0] * in_channels
    model.transform.image_std = [1.0] * in_channels
    model.train()

    anchor_generator = model.anchor_generator
    out_channels = model.backbone.out_channels
    num_anchors = anchor_generator.num_anchors_per_location()[0]

    model.head.classification_head = RetinaNetClassificationHead(
        out_channels,
        num_anchors,
        num_classes
    )

    if in_channels != 3:
        old_conv = model.backbone.body.conv1
        new_conv = nn.Conv2d(
            in_channels,
            old_conv.out_channels,
            kernel_size=old_conv.kernel_size,
            stride=old_conv.stride,
            padding=old_conv.padding,
            bias=False
        )

        with torch.no_grad():
            n_to_copy = min(in_channels, 3)
            new_conv.weight[:, :n_to_copy, :, :] = old_conv.weight[:, :n_to_copy, :, :]

        model.backbone.body.conv1 = new_conv

    return model
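
A short usage sketch: building a model for four-band input (e.g. RGB + CHM) and checking the replaced first convolution.

from obia.detection.models import build_detection_model

# Two classes (background + tree), four input bands
model = build_detection_model(num_classes=2, in_channels=4)

# The ResNet-50 stem now accepts 4 channels; out_channels stays 64
print(model.backbone.body.conv1.weight.shape)  # torch.Size([64, 4, 7, 7])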

obia.detection.train

Training Script for a RetinaNet-based Detection Model

train_model(model, train_loader, num_epochs, device='cpu')

Trains the RetinaNet detection model.

Parameters:

    model (nn.Module): The detection model (e.g., from build_detection_model). Required.
    train_loader (DataLoader): DataLoader for training data. Required.
    num_epochs (int): Number of epochs to train. Required.
    device (str): Device to use ("cpu", "cuda", or "mps"). Default: "cpu"

Returns:

    model (nn.Module): Trained model.

Source code in obia/detection/train.py
def train_model(model, train_loader, num_epochs, device="cpu"):
    """
    Trains the RetinaNet detection model.

    Args:
        model (nn.Module): The detection model (e.g., from build_detection_model).
        train_loader (DataLoader): DataLoader for training data.
        num_epochs (int): Number of epochs to train.
        device (str): Device to use ("cpu", "cuda", or "mps").

    Returns:
        model (nn.Module): Trained model.
    """
    model.to(device)

    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(params, lr=1e-4)

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0

        for images, targets in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            images = [img.to(device) for img in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())

            optimizer.zero_grad()
            losses.backward()
            optimizer.step()

            total_loss += losses.item()

        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1}/{num_epochs} - Loss: {avg_loss:.4f}")

    return model
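
A minimal end-to-end training sketch, assuming the dataset and model built in the sections above; batch size, epoch count, and the checkpoint filename are illustrative.

import torch
from torch.utils.data import DataLoader
from obia.detection.train import train_model
from obia.detection.utils import collate_fn

train_loader = DataLoader(
    dataset,                # TreeDetectionDataset from above
    batch_size=2,
    shuffle=True,
    collate_fn=collate_fn,  # keeps images and targets as lists
)
model = train_model(model, train_loader, num_epochs=10, device="cpu")
torch.save(model.state_dict(), "retinanet_trees.pth")  # hypothetical filename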

obia.detection.predict

Object Detection Prediction Script (Multi-band + 0..255 Scaling)

Provides a function predict() for running inference with a custom RetinaNet-based model on an N-band raster. Uses rasterio to read the data, scales pixel values to [0, 255] using the global min/max across all bands, then feeds the result to the model.

predict(model, image_path, device='cpu', score_threshold=0.5)

Parameters:

    model (nn.Module): Trained RetinaNet model (with in_channels matching your data). Required.
    image_path (str): Path to the multi-band raster (GeoTIFF, etc.). Required.
    device (str): "cpu", "cuda", or "mps". Default: "cpu"
    score_threshold (float): Minimum confidence for detection. Default: 0.5

Returns:

    dict: { "boxes": np.ndarray, "scores": np.ndarray, "labels": np.ndarray }

Source code in obia/detection/predict.py
def predict(model, image_path, device="cpu", score_threshold=0.5) -> dict:
    """
    Args:
        model (nn.Module): Trained RetinaNet model (with in_channels matching your data).
        image_path (str): Path to the multi-band raster (GeoTIFF, etc.).
        device (str): "cpu", "cuda", or "mps".
        score_threshold (float): Minimum confidence for detection.

    Returns:
        dict with { "boxes": np.ndarray, "scores": np.ndarray, "labels": np.ndarray }
    """
    with rasterio.open(image_path) as src:
        image_array = src.read()

    image_array = np.transpose(image_array, (1, 2, 0))

    data_min = image_array.min()
    data_max = image_array.max()
    if data_max > data_min:
        image_array = 255 * (image_array - data_min) / (data_max - data_min + 1e-8)
    image_array = np.clip(image_array, 0, 255).astype(np.uint8)

    image_tensor = torch.tensor(image_array, dtype=torch.float32).permute(2, 0, 1)

    model.to(device)
    model.eval()

    with torch.no_grad():
        outputs = model([image_tensor.to(device)])

    boxes = outputs[0]["boxes"].cpu().numpy()
    scores = outputs[0]["scores"].cpu().numpy()
    labels = outputs[0]["labels"].cpu().numpy()

    keep = scores >= score_threshold
    boxes = boxes[keep]
    scores = scores[keep]
    labels = labels[keep]

    return {
        "boxes": boxes,
        "scores": scores,
        "labels": labels
    }
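
A usage sketch with a hypothetical raster path, assuming the model trained above:

from obia.detection.predict import predict

result = predict(model, "tile_0042.tif", device="cpu", score_threshold=0.5)
print(f"{len(result['boxes'])} detections")
for box, score, label in zip(result["boxes"], result["scores"], result["labels"]):
    print(label, f"{score:.2f}", box)  # box is [x_min, y_min, x_max, y_max]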

obia.detection.utils

Transforms, Collate Function, and Utility Helpers

Provides:

- get_transforms(): Albumentations pipelines for train/val
- collate_fn(): custom collate for object detection
- calculate_iou(): compute IoU between two boxes
- visualize_predictions(): draw detection boxes and scores on an image

calculate_iou(box1, box2)

Compute Intersection over Union (IoU) for two bounding boxes. Boxes assumed in format [x_min, y_min, x_max, y_max].

Source code in obia/detection/utils.py
def calculate_iou(box1, box2):
    """
    Compute Intersection over Union (IoU) for two bounding boxes.
    Boxes assumed in format [x_min, y_min, x_max, y_max].
    """
    x1 = max(box1[0], box2[0])
    y1 = max(box1[1], box2[1])
    x2 = min(box1[2], box2[2])
    y2 = min(box1[3], box2[3])

    inter_width = max(0, x2 - x1)
    inter_height = max(0, y2 - y1)
    inter_area = inter_width * inter_height

    box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
    box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])

    iou = inter_area / float(box1_area + box2_area - inter_area + 1e-6)
    return iou
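
A worked example: two 100 x 100 boxes offset by 50 pixels overlap in a 50 x 50 region, so IoU = 2500 / (10000 + 10000 - 2500) ≈ 0.143.

from obia.detection.utils import calculate_iou

box_a = [0, 0, 100, 100]
box_b = [50, 50, 150, 150]
print(calculate_iou(box_a, box_b))  # ~0.1429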

collate_fn(batch)

Custom collate function for object detection in PyTorch. Returns lists of images and targets.

Source code in obia/detection/utils.py
def collate_fn(batch):
    """
    Custom collate function for object detection in PyTorch.
    Returns lists of images and targets.
    """
    images = []
    targets = []
    for img, tgt in batch:
        images.append(img)
        targets.append(tgt)
    return images, targets
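
Lists (rather than stacked tensors) matter here because images may differ in size and each target holds a different number of boxes; torchvision detection models accept exactly this list form. A quick check, assuming the dataset from above:

from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=4, collate_fn=collate_fn)
images, targets = next(iter(loader))
# images: list of 4 (C, H, W) tensors, possibly with different H and W
# targets: list of 4 dicts, each with "boxes" and "labels"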

get_transforms(train=True)

Returns Albumentations transforms for bounding-box tasks. If you have more than 3 channels, consider removing any 3-channel-specific Normalization or adjusting mean/std.

Source code in obia/detection/utils.py
def get_transforms(train=True):
    """
    Returns Albumentations transforms for bounding-box tasks.
    If you have more than 3 channels, consider removing any
    3-channel-specific Normalization or adjusting mean/std.
    """
    if train:
        return A.Compose(
            [
                A.RandomRotate90(p=0.5),
                A.Flip(p=0.5),
            ],
            bbox_params=A.BboxParams(
                format='pascal_voc',
                label_fields=['labels'],
                min_area=0,
                min_visibility=0.0
            )
        )
    else:
        return A.Compose(
            [
                # A.Normalize(mean=(...), std=(...)),
            ],
            bbox_params=A.BboxParams(
                format='pascal_voc',
                label_fields=['labels'],
                min_area=0,
                min_visibility=0.0
            )
        )
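
A quick sketch applying the training pipeline to a dummy image; box and label values are illustrative.

import numpy as np
from obia.detection.utils import get_transforms

image = np.zeros((256, 256, 3), dtype=np.uint8)  # dummy H x W x C image
transforms = get_transforms(train=True)
out = transforms(image=image, bboxes=[[10, 10, 60, 80]], labels=[1])
print(out["image"].shape, out["bboxes"], out["labels"])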

visualize_predictions(image_path, detection_output, score_threshold=0.0)

Draws bounding boxes (and scores) on an image and displays it using matplotlib.

Parameters:

    image_path (str): Path to the image file. Required.
    detection_output (dict): Must contain "boxes", "scores", "labels". Required.
    score_threshold (float): Only visualize boxes with score >= threshold. Default: 0.0

Source code in obia/detection/utils.py
def visualize_predictions(image_path, detection_output, score_threshold=0.0):
    """
    Draws bounding boxes (and scores) on an image and displays it using matplotlib.

    Args:
        image_path (str): Path to the image file.
        detection_output (dict): Must contain "boxes", "scores", "labels".
        score_threshold (float): Only visualize boxes with score >= threshold.
    """
    image = cv2.imread(image_path)
    if image is None:
        raise ValueError(f"Could not read image from {image_path}")
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    boxes = detection_output["boxes"]
    scores = detection_output["scores"]
    labels = detection_output["labels"]

    for box, score, label in zip(boxes, scores, labels):
        if score >= score_threshold:
            x1, y1, x2, y2 = [int(coord) for coord in box]
            cv2.rectangle(
                image_rgb,
                (x1, y1),
                (x2, y2),
                (0, 255, 0), 2
            )
            text = f"Class {label}, {score:.2f}"
            cv2.putText(
                image_rgb,
                text,
                (x1, y1 - 5),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.5,
                (0, 255, 0),
                1
            )

    plt.figure(figsize=(8, 8))
    plt.imshow(image_rgb)
    plt.axis("off")
    plt.show()
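
A sketch chaining predict() and visualize_predictions(); the path is hypothetical. Note that cv2.imread handles standard 1-, 3-, or 4-channel images, so a multi-band raster may need an RGB export for visualization.

from obia.detection.predict import predict
from obia.detection.utils import visualize_predictions

output = predict(model, "tile_rgb.tif", device="cpu", score_threshold=0.5)
visualize_predictions("tile_rgb.tif", output, score_threshold=0.5)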