ultralytics/callbacks/paths.py
Fatih Akyon c87765a383
feat: accept data= in paths.patch_resume
Ultralytics check_resume (trainer.py:841) restores the checkpoint's data path verbatim and does not honor caller overrides; cross-host resumes where the dataset lives at a different mount point (e.g. ultra5 NFS outage) previously needed a manual torch.load/save dance to rewrite train_args. Mirrors the existing name/device override branches so one helper call covers all four non-whitelisted fields (project, name, save_dir, data).
2026-04-18 11:31:06 -05:00
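The commit message's core claim is that ``check_resume`` ignores a caller-supplied ``data`` but honors whitelisted keys such as ``device``. A simplified emulation makes this concrete. It is a hedged sketch, not Ultralytics code. ``merge_resume_args`` and ``RESUME_WHITELIST`` are hypothetical names. The whitelist is copied from the module docstring below, and the real list in ``trainer.py`` may differ by version. The dataset paths are made up.

```python
"""Simplified, hypothetical emulation of the resume-merge behavior (not Ultralytics source)."""

# Keys copied from the module docstring's whitelist; the real trainer.py list may differ by version.
RESUME_WHITELIST = frozenset({
    "imgsz", "batch", "device", "close_mosaic", "augmentations", "save_period",
    "workers", "cache", "patience", "time", "freeze", "val", "plots",
})


def merge_resume_args(ckpt_train_args: dict, caller_overrides: dict) -> dict:
    """Start from the checkpoint's train_args and apply only whitelisted caller overrides."""
    merged = dict(ckpt_train_args)
    for key, value in caller_overrides.items():
        if key in RESUME_WHITELIST:
            merged[key] = value
    return merged


ckpt_args = {"data": "/mnt/old-nfs/imagenet.yaml", "device": 0, "name": "exp1"}
caller = {"data": "/mnt/new-ssd/imagenet.yaml", "device": 1}

resumed = merge_resume_args(ckpt_args, caller)
# device is whitelisted, so the caller wins; data is not, so the stale path survives.
assert resumed["device"] == 1
assert resumed["data"] == "/mnt/old-nfs/imagenet.yaml"

# Baking the new path into the checkpoint first (the role of patch_resume(data=...)) fixes it.
patched = {**ckpt_args, "data": "/mnt/new-ssd/imagenet.yaml"}
assert merge_resume_args(patched, caller)["data"] == "/mnt/new-ssd/imagenet.yaml"
```

In this model of the merge, the only way a non-whitelisted key like ``data`` can change is to edit the checkpoint's ``train_args`` before resuming. That is why the helper gains a ``data=`` branch rather than relying on ``model.train`` kwargs.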

"""Central config for W&B project + local SSD / NFS mirror roots.

Why this module exists:

- Ultralytics' built-in W&B callback (ultralytics/utils/callbacks/wb.py:138) does
  ``project=str(trainer.args.project).replace("/", "-")``. Passing an absolute path as
  ``project=`` (e.g. ``/home/fatih/runs/yolo-next-encoder``) mangles the W&B project to
  ``-home-fatih-runs-yolo-next-encoder``. The fix is to pass a clean W&B project name as
  ``project=`` and an absolute local path as ``save_dir=``; Ultralytics' ``get_save_dir``
  (ultralytics/cfg/__init__.py:395) honors ``save_dir`` verbatim, bypassing the
  project+name join.
- Ultralytics' ``check_resume`` (ultralytics/engine/trainer.py:841) overwrites most
  caller-supplied args with the checkpoint's ``train_args`` (whitelist: imgsz, batch,
  device, close_mosaic, augmentations, save_period, workers, cache, patience, time,
  freeze, val, plots). For cross-machine or relocated resumes, project/name/save_dir
  (and data, if the dataset moved) must be patched on the checkpoint itself, not passed
  from the caller side.

Callers use ``run_paths(name)`` for fresh runs and ``patch_resume(ckpt)`` for resumes.
"""

from __future__ import annotations

from pathlib import Path

WANDB_PROJECT = "yolo-next-encoder"
LOCAL_ROOT = Path("/home/fatih/runs/yolo-next-encoder")
NFS_MIRROR_ROOT = Path("/data/shared-datasets/fatih-runs/classify/yolo-next-encoder")
SYNC_INTERVAL_SEC = 600

assert LOCAL_ROOT.is_absolute() and str(LOCAL_ROOT).startswith("/home/"), (
    f"LOCAL_ROOT must be absolute and under /home/ to decouple from NFS, got {LOCAL_ROOT}"
)


def run_paths(name: str, exist_ok: bool = False) -> dict:
    """Return W&B project + absolute local save_dir kwargs for ``model.train``.

    Args:
        name (str): Run name, used as W&B display name and ``save_dir`` leaf.
        exist_ok (bool, optional): Allow ``save_dir`` to already exist.

    Returns:
        (dict): Kwargs with ``project``, ``name``, ``save_dir``, ``exist_ok``.
    """
    return dict(project=WANDB_PROJECT, name=name, save_dir=str(LOCAL_ROOT / name), exist_ok=exist_ok)


def patch_resume(ckpt_path, name: str | None = None, device=None, data: str | None = None) -> str:
    """Rewrite a checkpoint's ``train_args`` in place to a clean W&B project + absolute local save_dir.

    Needed because Ultralytics' ``check_resume`` restores project/name/save_dir/data from the checkpoint, not from
    caller kwargs; without this, a resume on a different machine or save_dir inherits whatever the original trainer
    baked in.

    Args:
        ckpt_path (str | Path): Checkpoint to patch (local or NFS path).
        name (str, optional): Override run name. Defaults to the checkpoint's existing name, else the name of the
            checkpoint's grandparent directory (the ``save_dir`` leaf for ``<save_dir>/weights/last.pt``).
        device (int | str, optional): Override CUDA device (whitelisted for resume, e.g. when the new machine exposes
            the target physical GPU as a different CUDA index).
        data (str, optional): Override dataset path, e.g. when the resuming host mounts the dataset at a different
            location. ``data`` is NOT in ``check_resume``'s override whitelist, so it must be baked into the
            checkpoint.

    Returns:
        (str): Absolute, resolved path of the patched checkpoint (for chaining).
    """
    import torch  # local import: the path constants above stay importable without torch

    ckpt_path = Path(ckpt_path).expanduser().resolve()
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    # train_args may be missing or None; normalize to a dict before patching.
    train_args = ckpt.setdefault("train_args", {}) or {}
    # Name precedence: explicit override > checkpoint's name > save_dir leaf (<save_dir>/weights/last.pt).
    run_name = name or train_args.get("name") or ckpt_path.parent.parent.name
    train_args.update(project=WANDB_PROJECT, name=run_name, save_dir=str(LOCAL_ROOT / run_name), exist_ok=True)
    if device is not None:
        train_args["device"] = device
    if data is not None:
        train_args["data"] = data
    ckpt["train_args"] = train_args
    torch.save(ckpt, ckpt_path)
    return str(ckpt_path)
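The W&B naming problem in the module docstring's first bullet reduces to one string expression. The sketch below assumes that the ``str(project).replace("/", "-")`` expression quoted from ``wb.py`` is the only thing shaping the W&B project name. ``wandb_project_name`` is a hypothetical helper, not part of Ultralytics or this module. The sketch contrasts an absolute path passed as ``project=`` with the kwargs shape that ``run_paths`` returns.

```python
from pathlib import Path

# Same values as the module constants above, repeated so the sketch is self-contained.
WANDB_PROJECT = "yolo-next-encoder"
LOCAL_ROOT = Path("/home/fatih/runs/yolo-next-encoder")


def wandb_project_name(project) -> str:
    """Hypothetical helper reproducing the project-name expression quoted from wb.py."""
    return str(project).replace("/", "-")


# Anti-pattern: an absolute path as project= leaks into the W&B project name.
bad = wandb_project_name(LOCAL_ROOT)

# run_paths()-style kwargs: clean project name, absolute local path routed through save_dir=.
kwargs = dict(project=WANDB_PROJECT, name="exp1", save_dir=str(LOCAL_ROOT / "exp1"), exist_ok=False)
good = wandb_project_name(kwargs["project"])

print(bad)   # -home-fatih-runs-yolo-next-encoder
print(good)  # yolo-next-encoder
```

Splitting the two concerns this way keeps the W&B project stable across hosts, while ``save_dir`` alone decides where files land on disk.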