Source Code added

This commit is contained in:
Fr4nz D13trich 2026-02-02 15:06:40 +01:00
parent 800376eafd
commit 9efa9bc6dd
3912 changed files with 754770 additions and 2 deletions

View file

@ -0,0 +1,48 @@
from typing import Any
from immich_ml.models.base import InferenceModel
from immich_ml.models.clip.textual import MClipTextualEncoder, OpenClipTextualEncoder
from immich_ml.models.clip.visual import OpenClipVisualEncoder
from immich_ml.models.ocr.detection import TextDetector
from immich_ml.models.ocr.recognition import TextRecognizer
from immich_ml.schemas import ModelSource, ModelTask, ModelType
from .constants import get_model_source
from .facial_recognition.detection import FaceDetector
from .facial_recognition.recognition import FaceRecognizer
def get_model_class(model_name: str, model_type: ModelType, model_task: ModelTask) -> type[InferenceModel]:
source = get_model_source(model_name)
match source, model_type, model_task:
case ModelSource.OPENCLIP | ModelSource.MCLIP, ModelType.VISUAL, ModelTask.SEARCH:
return OpenClipVisualEncoder
case ModelSource.OPENCLIP, ModelType.TEXTUAL, ModelTask.SEARCH:
return OpenClipTextualEncoder
case ModelSource.MCLIP, ModelType.TEXTUAL, ModelTask.SEARCH:
return MClipTextualEncoder
case ModelSource.INSIGHTFACE, ModelType.DETECTION, ModelTask.FACIAL_RECOGNITION:
return FaceDetector
case ModelSource.INSIGHTFACE, ModelType.RECOGNITION, ModelTask.FACIAL_RECOGNITION:
return FaceRecognizer
case ModelSource.PADDLE, ModelType.DETECTION, ModelTask.OCR:
return TextDetector
case ModelSource.PADDLE, ModelType.RECOGNITION, ModelTask.OCR:
return TextRecognizer
case _:
raise ValueError(f"Unknown model combination: {source}, {model_type}, {model_task}")
def from_model_type(model_name: str, model_type: ModelType, model_task: ModelTask, **kwargs: Any) -> InferenceModel:
return get_model_class(model_name, model_type, model_task)(model_name, **kwargs)
def get_model_deps(model_name: str, model_type: ModelType, model_task: ModelTask) -> list[tuple[ModelType, ModelTask]]:
return get_model_class(model_name, model_type, model_task).depends

View file

@ -0,0 +1,176 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from pathlib import Path
from shutil import rmtree
from typing import Any, ClassVar
from huggingface_hub import snapshot_download
import immich_ml.sessions.ann.loader
import immich_ml.sessions.rknn as rknn
from immich_ml.sessions.ort import OrtSession
from ..config import clean_name, log, settings
from ..schemas import ModelFormat, ModelIdentity, ModelSession, ModelTask, ModelType
from ..sessions.ann import AnnSession
class InferenceModel(ABC):
depends: ClassVar[list[ModelIdentity]]
identity: ClassVar[ModelIdentity]
def __init__(
self,
model_name: str,
cache_dir: Path | str | None = None,
model_format: ModelFormat | None = None,
session: ModelSession | None = None,
**model_kwargs: Any,
) -> None:
self.loaded = session is not None
self.load_attempts = 0
self.model_name = clean_name(model_name)
self.cache_dir = Path(cache_dir) if cache_dir is not None else self._cache_dir_default
self.model_format = model_format if model_format is not None else self._model_format_default
if session is not None:
self.session = session
def download(self) -> None:
if not self.cached:
model_type = self.model_type.replace("-", " ")
log.info(f"Downloading {model_type} model '{self.model_name}' to {self.model_path}. This may take a while.")
self._download()
def load(self) -> None:
if self.loaded:
return
self.load_attempts += 1
self.download()
attempt = f"Attempt #{self.load_attempts} to load" if self.load_attempts > 1 else "Loading"
log.info(f"{attempt} {self.model_type.replace('-', ' ')} model '{self.model_name}' to memory")
self.session = self._load()
self.loaded = True
def predict(self, *inputs: Any, **model_kwargs: Any) -> Any:
self.load()
if model_kwargs:
self.configure(**model_kwargs)
return self._predict(*inputs)
@abstractmethod
def _predict(self, *inputs: Any, **model_kwargs: Any) -> Any: ...
def configure(self, **kwargs: Any) -> None:
pass
def _download(self) -> None:
ignored_patterns: dict[ModelFormat, list[str]] = {
ModelFormat.ONNX: ["*.armnn", "*.rknn"],
ModelFormat.ARMNN: ["*.rknn"],
ModelFormat.RKNN: ["*.armnn"],
}
snapshot_download(
f"immich-app/{clean_name(self.model_name)}",
cache_dir=self.cache_dir,
local_dir=self.cache_dir,
ignore_patterns=ignored_patterns.get(self.model_format, []),
)
def _load(self) -> ModelSession:
return self._make_session(self.model_path)
def clear_cache(self) -> None:
if not self.cache_dir.exists():
log.warning(
f"Attempted to clear cache for model '{self.model_name}', but cache directory does not exist",
)
return
if not rmtree.avoids_symlink_attacks:
raise RuntimeError("Attempted to clear cache, but rmtree is not safe on this platform")
if self.cache_dir.is_dir():
log.info(f"Cleared cache directory for model '{self.model_name}'.")
rmtree(self.cache_dir)
else:
log.warning(
(
f"Encountered file instead of directory at cache path "
f"for '{self.model_name}'. Removing file and replacing with a directory."
),
)
self.cache_dir.unlink()
self.cache_dir.mkdir(parents=True, exist_ok=True)
def _make_session(self, model_path: Path) -> ModelSession:
if not model_path.is_file():
raise FileNotFoundError(f"Model file not found: {model_path}")
match model_path.suffix:
case ".armnn":
session: ModelSession = AnnSession(model_path)
case ".onnx":
session = OrtSession(model_path)
case ".rknn":
session = rknn.RknnSession(model_path)
case _:
raise ValueError(f"Unsupported model file type: {model_path.suffix}")
return session
def model_path_for_format(self, model_format: ModelFormat) -> Path:
model_path_prefix = rknn.model_prefix if model_format == ModelFormat.RKNN else None
if model_path_prefix:
return self.model_dir / model_path_prefix / f"model.{model_format}"
return self.model_dir / f"model.{model_format}"
@property
def model_dir(self) -> Path:
return self.cache_dir / self.model_type.value
@property
def model_path(self) -> Path:
return self.model_path_for_format(self.model_format)
@property
def model_task(self) -> ModelTask:
return self.identity[1]
@property
def model_type(self) -> ModelType:
return self.identity[0]
@property
def cache_dir(self) -> Path:
return self._cache_dir
@cache_dir.setter
def cache_dir(self, cache_dir: Path) -> None:
self._cache_dir = cache_dir
@property
def _cache_dir_default(self) -> Path:
return settings.cache_folder / self.model_task.value / self.model_name
@property
def cached(self) -> bool:
return self.model_path.is_file()
@property
def model_format(self) -> ModelFormat:
return self._model_format
@model_format.setter
def model_format(self, model_format: ModelFormat) -> None:
log.debug(f"Setting model format to {model_format}")
self._model_format = model_format
@property
def _model_format_default(self) -> ModelFormat:
if rknn.is_available:
return ModelFormat.RKNN
elif immich_ml.sessions.ann.loader.is_available and settings.ann:
return ModelFormat.ARMNN
else:
return ModelFormat.ONNX

View file

@ -0,0 +1,60 @@
from typing import Any
from aiocache.backends.memory import SimpleMemoryCache
from aiocache.lock import OptimisticLock
from aiocache.plugins import TimingPlugin
from immich_ml.models import from_model_type
from immich_ml.models.base import InferenceModel
from ..schemas import ModelTask, ModelType, has_profiling
class ModelCache:
"""Fetches a model from an in-memory cache, instantiating it if it's missing."""
def __init__(
self,
revalidate: bool = False,
timeout: int | None = None,
profiling: bool = False,
) -> None:
"""
Args:
revalidate: Resets TTL on cache hit. Useful to keep models in memory while active. Defaults to False.
timeout: Maximum allowed time for model to load. Disabled if None. Defaults to None.
profiling: Collects metrics for cache operations, adding slight overhead. Defaults to False.
"""
plugins = []
if profiling:
plugins.append(TimingPlugin())
self.should_revalidate = revalidate
self.cache = SimpleMemoryCache(timeout=timeout, plugins=plugins, namespace=None)
async def get(
self, model_name: str, model_type: ModelType, model_task: ModelTask, **model_kwargs: Any
) -> InferenceModel:
key = f"{model_name}{model_type}{model_task}"
async with OptimisticLock(self.cache, key) as lock:
model: InferenceModel | None = await self.cache.get(key)
if model is None:
model = from_model_type(model_name, model_type, model_task, **model_kwargs)
await lock.cas(model, ttl=model_kwargs.get("ttl", None))
elif self.should_revalidate:
await self.revalidate(key, model_kwargs.get("ttl", None))
return model
async def get_profiling(self) -> dict[str, float] | None:
if not has_profiling(self.cache):
return None
return self.cache.profiling
async def revalidate(self, key: str, ttl: int | None) -> None:
if ttl is not None and key in self.cache._handlers:
await self.cache.expire(key, ttl)

View file

@ -0,0 +1,120 @@
import json
from abc import abstractmethod
from functools import cached_property
from pathlib import Path
from typing import Any
import numpy as np
from numpy.typing import NDArray
from tokenizers import Encoding, Tokenizer
from immich_ml.config import log
from immich_ml.models.base import InferenceModel
from immich_ml.models.constants import WEBLATE_TO_FLORES200
from immich_ml.models.transforms import clean_text, serialize_np_array
from immich_ml.schemas import ModelSession, ModelTask, ModelType
class BaseCLIPTextualEncoder(InferenceModel):
depends = []
identity = (ModelType.TEXTUAL, ModelTask.SEARCH)
def _predict(self, inputs: str, language: str | None = None) -> str:
tokens = self.tokenize(inputs, language=language)
res: NDArray[np.float32] = self.session.run(None, tokens)[0][0]
return serialize_np_array(res)
def _load(self) -> ModelSession:
session = super()._load()
log.debug(f"Loading tokenizer for CLIP model '{self.model_name}'")
self.tokenizer = self._load_tokenizer()
tokenizer_kwargs: dict[str, Any] | None = self.text_cfg.get("tokenizer_kwargs")
self.canonicalize = tokenizer_kwargs is not None and tokenizer_kwargs.get("clean") == "canonicalize"
self.is_nllb = self.model_name.startswith("nllb")
log.debug(f"Loaded tokenizer for CLIP model '{self.model_name}'")
return session
@abstractmethod
def _load_tokenizer(self) -> Tokenizer:
pass
@abstractmethod
def tokenize(self, text: str, language: str | None = None) -> dict[str, NDArray[np.int32]]:
pass
@property
def model_cfg_path(self) -> Path:
return self.cache_dir / "config.json"
@property
def tokenizer_file_path(self) -> Path:
return self.model_dir / "tokenizer.json"
@property
def tokenizer_cfg_path(self) -> Path:
return self.model_dir / "tokenizer_config.json"
@cached_property
def model_cfg(self) -> dict[str, Any]:
log.debug(f"Loading model config for CLIP model '{self.model_name}'")
model_cfg: dict[str, Any] = json.load(self.model_cfg_path.open())
log.debug(f"Loaded model config for CLIP model '{self.model_name}'")
return model_cfg
@property
def text_cfg(self) -> dict[str, Any]:
text_cfg: dict[str, Any] = self.model_cfg["text_cfg"]
return text_cfg
@cached_property
def tokenizer_file(self) -> dict[str, Any]:
log.debug(f"Loading tokenizer file for CLIP model '{self.model_name}'")
tokenizer_file: dict[str, Any] = json.load(self.tokenizer_file_path.open())
log.debug(f"Loaded tokenizer file for CLIP model '{self.model_name}'")
return tokenizer_file
@cached_property
def tokenizer_cfg(self) -> dict[str, Any]:
log.debug(f"Loading tokenizer config for CLIP model '{self.model_name}'")
tokenizer_cfg: dict[str, Any] = json.load(self.tokenizer_cfg_path.open())
log.debug(f"Loaded tokenizer config for CLIP model '{self.model_name}'")
return tokenizer_cfg
class OpenClipTextualEncoder(BaseCLIPTextualEncoder):
def _load_tokenizer(self) -> Tokenizer:
context_length: int = self.text_cfg.get("context_length", 77)
pad_token: str = self.tokenizer_cfg["pad_token"]
tokenizer: Tokenizer = Tokenizer.from_file(self.tokenizer_file_path.as_posix())
pad_id: int = tokenizer.token_to_id(pad_token)
tokenizer.enable_padding(length=context_length, pad_token=pad_token, pad_id=pad_id)
tokenizer.enable_truncation(max_length=context_length)
return tokenizer
def tokenize(self, text: str, language: str | None = None) -> dict[str, NDArray[np.int32]]:
text = clean_text(text, canonicalize=self.canonicalize)
if self.is_nllb and language is not None:
flores_code = WEBLATE_TO_FLORES200.get(language)
if flores_code is None:
no_country = language.split("-")[0]
flores_code = WEBLATE_TO_FLORES200.get(no_country)
if flores_code is None:
log.warning(f"Language '{language}' not found, defaulting to 'en'")
flores_code = "eng_Latn"
text = f"{flores_code}{text}"
tokens: Encoding = self.tokenizer.encode(text)
return {"text": np.array([tokens.ids], dtype=np.int32)}
class MClipTextualEncoder(OpenClipTextualEncoder):
def tokenize(self, text: str, language: str | None = None) -> dict[str, NDArray[np.int32]]:
text = clean_text(text, canonicalize=self.canonicalize)
tokens: Encoding = self.tokenizer.encode(text)
return {
"input_ids": np.array([tokens.ids], dtype=np.int32),
"attention_mask": np.array([tokens.attention_mask], dtype=np.int32),
}

View file

@ -0,0 +1,77 @@
import json
from abc import abstractmethod
from functools import cached_property
from pathlib import Path
from typing import Any
import numpy as np
from numpy.typing import NDArray
from PIL import Image
from immich_ml.config import log
from immich_ml.models.base import InferenceModel
from immich_ml.models.transforms import (
crop_pil,
decode_pil,
get_pil_resampling,
normalize,
resize_pil,
serialize_np_array,
to_numpy,
)
from immich_ml.schemas import ModelSession, ModelTask, ModelType
class BaseCLIPVisualEncoder(InferenceModel):
depends = []
identity = (ModelType.VISUAL, ModelTask.SEARCH)
def _predict(self, inputs: Image.Image | bytes) -> str:
image = decode_pil(inputs)
res: NDArray[np.float32] = self.session.run(None, self.transform(image))[0][0]
return serialize_np_array(res)
@abstractmethod
def transform(self, image: Image.Image) -> dict[str, NDArray[np.float32]]:
pass
@property
def model_cfg_path(self) -> Path:
return self.cache_dir / "config.json"
@property
def preprocess_cfg_path(self) -> Path:
return self.model_dir / "preprocess_cfg.json"
@cached_property
def model_cfg(self) -> dict[str, Any]:
log.debug(f"Loading model config for CLIP model '{self.model_name}'")
model_cfg: dict[str, Any] = json.load(self.model_cfg_path.open())
log.debug(f"Loaded model config for CLIP model '{self.model_name}'")
return model_cfg
@cached_property
def preprocess_cfg(self) -> dict[str, Any]:
log.debug(f"Loading visual preprocessing config for CLIP model '{self.model_name}'")
preprocess_cfg: dict[str, Any] = json.load(self.preprocess_cfg_path.open())
log.debug(f"Loaded visual preprocessing config for CLIP model '{self.model_name}'")
return preprocess_cfg
class OpenClipVisualEncoder(BaseCLIPVisualEncoder):
def _load(self) -> ModelSession:
size: list[int] | int = self.preprocess_cfg["size"]
self.size = size[0] if isinstance(size, list) else size
self.resampling = get_pil_resampling(self.preprocess_cfg["interpolation"])
self.mean = np.array(self.preprocess_cfg["mean"], dtype=np.float32)
self.std = np.array(self.preprocess_cfg["std"], dtype=np.float32)
return super()._load()
def transform(self, image: Image.Image) -> dict[str, NDArray[np.float32]]:
image = resize_pil(image, self.size)
image = crop_pil(image, self.size)
image_np = to_numpy(image)
image_np = normalize(image_np, self.mean, self.std)
return {"image": np.expand_dims(image_np.transpose(2, 0, 1), 0)}

View file

@ -0,0 +1,178 @@
from immich_ml.config import clean_name
from immich_ml.schemas import ModelSource
_OPENCLIP_MODELS = {
"RN101__openai",
"RN101__yfcc15m",
"RN50__cc12m",
"RN50__openai",
"RN50__yfcc15m",
"RN50x16__openai",
"RN50x4__openai",
"RN50x64__openai",
"ViT-B-16-SigLIP-256__webli",
"ViT-B-16-SigLIP-384__webli",
"ViT-B-16-SigLIP-512__webli",
"ViT-B-16-SigLIP-i18n-256__webli",
"ViT-B-16-SigLIP__webli",
"ViT-B-16-plus-240__laion400m_e31",
"ViT-B-16-plus-240__laion400m_e32",
"ViT-B-16__laion400m_e31",
"ViT-B-16__laion400m_e32",
"ViT-B-16__openai",
"ViT-B-32__laion2b-s34b-b79k",
"ViT-B-32__laion2b_e16",
"ViT-B-32__laion400m_e31",
"ViT-B-32__laion400m_e32",
"ViT-B-32__openai",
"ViT-H-14-378-quickgelu__dfn5b",
"ViT-H-14-quickgelu__dfn5b",
"ViT-H-14__laion2b-s32b-b79k",
"ViT-L-14-336__openai",
"ViT-L-14-quickgelu__dfn2b",
"ViT-L-14__laion2b-s32b-b82k",
"ViT-L-14__laion400m_e31",
"ViT-L-14__laion400m_e32",
"ViT-L-14__openai",
"ViT-L-16-SigLIP-256__webli",
"ViT-L-16-SigLIP-384__webli",
"ViT-SO400M-14-SigLIP-384__webli",
"ViT-g-14__laion2b-s12b-b42k",
"XLM-Roberta-Base-ViT-B-32__laion5b_s13b_b90k",
"XLM-Roberta-Large-ViT-H-14__frozen_laion5b_s13b_b90k",
"nllb-clip-base-siglip__mrl",
"nllb-clip-base-siglip__v1",
"nllb-clip-large-siglip__mrl",
"nllb-clip-large-siglip__v1",
"ViT-B-16-SigLIP2__webli",
"ViT-B-32-SigLIP2-256__webli",
"ViT-L-16-SigLIP2-256__webli",
"ViT-L-16-SigLIP2-384__webli",
"ViT-L-16-SigLIP2-512__webli",
"ViT-SO400M-14-SigLIP2-378__webli",
"ViT-SO400M-14-SigLIP2__webli",
"ViT-SO400M-16-SigLIP2-256__webli",
"ViT-SO400M-16-SigLIP2-384__webli",
"ViT-SO400M-16-SigLIP2-512__webli",
"ViT-gopt-16-SigLIP2-256__webli",
"ViT-gopt-16-SigLIP2-384__webli",
}
_MCLIP_MODELS = {
"LABSE-Vit-L-14",
"XLM-Roberta-Large-Vit-B-16Plus",
"XLM-Roberta-Large-Vit-B-32",
"XLM-Roberta-Large-Vit-L-14",
}
_INSIGHTFACE_MODELS = {
"antelopev2",
"buffalo_s",
"buffalo_m",
"buffalo_l",
}
_PADDLE_MODELS = {
"PP-OCRv5_server",
"PP-OCRv5_mobile",
"CH__PP-OCRv5_server",
"CH__PP-OCRv5_mobile",
"EL__PP-OCRv5_mobile",
"EN__PP-OCRv5_mobile",
"ESLAV__PP-OCRv5_mobile",
"KOREAN__PP-OCRv5_mobile",
"LATIN__PP-OCRv5_mobile",
"TH__PP-OCRv5_mobile",
}
SUPPORTED_PROVIDERS = [
"CUDAExecutionProvider",
"ROCMExecutionProvider",
"OpenVINOExecutionProvider",
"CoreMLExecutionProvider",
"CPUExecutionProvider",
]
RKNN_SUPPORTED_SOCS = ["rk3566", "rk3568", "rk3576", "rk3588"]
RKNN_COREMASK_SUPPORTED_SOCS = ["rk3576", "rk3588"]
WEBLATE_TO_FLORES200 = {
"af": "afr_Latn",
"ar": "arb_Arab",
"az": "azj_Latn",
"be": "bel_Cyrl",
"bg": "bul_Cyrl",
"ca": "cat_Latn",
"cs": "ces_Latn",
"da": "dan_Latn",
"de": "deu_Latn",
"el": "ell_Grek",
"en": "eng_Latn",
"es": "spa_Latn",
"et": "est_Latn",
"fa": "pes_Arab",
"fi": "fin_Latn",
"fr": "fra_Latn",
"he": "heb_Hebr",
"hi": "hin_Deva",
"hr": "hrv_Latn",
"hu": "hun_Latn",
"hy": "hye_Armn",
"id": "ind_Latn",
"it": "ita_Latn",
"ja": "jpn_Hira",
"kmr": "kmr_Latn",
"ko": "kor_Hang",
"lb": "ltz_Latn",
"lt": "lit_Latn",
"lv": "lav_Latn",
"mfa": "zsm_Latn",
"mk": "mkd_Cyrl",
"mn": "khk_Cyrl",
"mr": "mar_Deva",
"ms": "zsm_Latn",
"nb-NO": "nob_Latn",
"nn": "nno_Latn",
"nl": "nld_Latn",
"pl": "pol_Latn",
"pt-BR": "por_Latn",
"pt": "por_Latn",
"ro": "ron_Latn",
"ru": "rus_Cyrl",
"sk": "slk_Latn",
"sl": "slv_Latn",
"sr-Cyrl": "srp_Cyrl",
"sv": "swe_Latn",
"ta": "tam_Taml",
"te": "tel_Telu",
"th": "tha_Thai",
"tr": "tur_Latn",
"uk": "ukr_Cyrl",
"ur": "urd_Arab",
"vi": "vie_Latn",
"zh-CN": "zho_Hans",
"zh-Hans": "zho_Hans",
"zh-TW": "zho_Hant",
}
def get_model_source(model_name: str) -> ModelSource | None:
cleaned_name = clean_name(model_name)
if cleaned_name in _INSIGHTFACE_MODELS:
return ModelSource.INSIGHTFACE
if cleaned_name in _MCLIP_MODELS:
return ModelSource.MCLIP
if cleaned_name in _OPENCLIP_MODELS:
return ModelSource.OPENCLIP
if cleaned_name in _PADDLE_MODELS:
return ModelSource.PADDLE
return None

View file

@ -0,0 +1,41 @@
from typing import Any
import numpy as np
from insightface.model_zoo import RetinaFace
from numpy.typing import NDArray
from immich_ml.models.base import InferenceModel
from immich_ml.models.transforms import decode_cv2
from immich_ml.schemas import FaceDetectionOutput, ModelSession, ModelTask, ModelType
class FaceDetector(InferenceModel):
depends = []
identity = (ModelType.DETECTION, ModelTask.FACIAL_RECOGNITION)
def __init__(self, model_name: str, min_score: float = 0.7, **model_kwargs: Any) -> None:
self.min_score = model_kwargs.pop("minScore", min_score)
super().__init__(model_name, **model_kwargs)
def _load(self) -> ModelSession:
session = self._make_session(self.model_path)
self.model = RetinaFace(session=session)
self.model.prepare(ctx_id=0, det_thresh=self.min_score, input_size=(640, 640))
return session
def _predict(self, inputs: NDArray[np.uint8] | bytes) -> FaceDetectionOutput:
inputs = decode_cv2(inputs)
bboxes, landmarks = self._detect(inputs)
return {
"boxes": bboxes[:, :4].round(),
"scores": bboxes[:, 4],
"landmarks": landmarks,
}
def _detect(self, inputs: NDArray[np.uint8] | bytes) -> tuple[NDArray[np.float32], NDArray[np.float32]]:
return self.model.detect(inputs) # type: ignore
def configure(self, **kwargs: Any) -> None:
self.model.det_thresh = kwargs.pop("minScore", self.model.det_thresh)

View file

@ -0,0 +1,92 @@
from pathlib import Path
from typing import Any
import numpy as np
import onnx
import onnxruntime as ort
from insightface.model_zoo import ArcFaceONNX
from insightface.utils.face_align import norm_crop
from numpy.typing import NDArray
from onnx.tools.update_model_dims import update_inputs_outputs_dims
from PIL import Image
from immich_ml.config import log, settings
from immich_ml.models.base import InferenceModel
from immich_ml.models.transforms import decode_cv2, serialize_np_array
from immich_ml.schemas import (
FaceDetectionOutput,
FacialRecognitionOutput,
ModelFormat,
ModelSession,
ModelTask,
ModelType,
)
class FaceRecognizer(InferenceModel):
depends = [(ModelType.DETECTION, ModelTask.FACIAL_RECOGNITION)]
identity = (ModelType.RECOGNITION, ModelTask.FACIAL_RECOGNITION)
def __init__(self, model_name: str, **model_kwargs: Any) -> None:
super().__init__(model_name, **model_kwargs)
max_batch_size = settings.max_batch_size.facial_recognition if settings.max_batch_size else None
self.batch_size = max_batch_size if max_batch_size else self._batch_size_default
def _load(self) -> ModelSession:
session = self._make_session(self.model_path)
if (not self.batch_size or self.batch_size > 1) and str(session.get_inputs()[0].shape[0]) != "batch":
self._add_batch_axis(self.model_path)
session = self._make_session(self.model_path)
self.model = ArcFaceONNX(
self.model_path_for_format(ModelFormat.ONNX).as_posix(),
session=session,
)
return session
def _predict(
self, inputs: NDArray[np.uint8] | bytes | Image.Image, faces: FaceDetectionOutput
) -> FacialRecognitionOutput:
if faces["boxes"].shape[0] == 0:
return []
inputs = decode_cv2(inputs)
cropped_faces = self._crop(inputs, faces)
embeddings = self._predict_batch(cropped_faces)
return self.postprocess(faces, embeddings)
def _predict_batch(self, cropped_faces: list[NDArray[np.uint8]]) -> NDArray[np.float32]:
if not self.batch_size or len(cropped_faces) <= self.batch_size:
embeddings: NDArray[np.float32] = self.model.get_feat(cropped_faces)
return embeddings
batch_embeddings: list[NDArray[np.float32]] = []
for i in range(0, len(cropped_faces), self.batch_size):
batch_embeddings.append(self.model.get_feat(cropped_faces[i : i + self.batch_size]))
return np.concatenate(batch_embeddings, axis=0)
def postprocess(self, faces: FaceDetectionOutput, embeddings: NDArray[np.float32]) -> FacialRecognitionOutput:
return [
{
"boundingBox": {"x1": x1, "y1": y1, "x2": x2, "y2": y2},
"embedding": serialize_np_array(embedding),
"score": score,
}
for (x1, y1, x2, y2), embedding, score in zip(faces["boxes"], embeddings, faces["scores"])
]
def _crop(self, image: NDArray[np.uint8], faces: FaceDetectionOutput) -> list[NDArray[np.uint8]]:
return [norm_crop(image, landmark) for landmark in faces["landmarks"]]
def _add_batch_axis(self, model_path: Path) -> None:
log.debug(f"Adding batch axis to model {model_path}")
proto = onnx.load(model_path)
static_input_dims = [shape.dim_value for shape in proto.graph.input[0].type.tensor_type.shape.dim[1:]]
static_output_dims = [shape.dim_value for shape in proto.graph.output[0].type.tensor_type.shape.dim[1:]]
input_dims = {proto.graph.input[0].name: ["batch"] + static_input_dims}
output_dims = {proto.graph.output[0].name: ["batch"] + static_output_dims}
updated_proto = update_inputs_outputs_dims(proto, input_dims, output_dims)
onnx.save(updated_proto, model_path)
@property
def _batch_size_default(self) -> int | None:
providers = ort.get_available_providers()
return None if self.model_format == ModelFormat.ONNX and "OpenVINOExecutionProvider" not in providers else 1

View file

@ -0,0 +1,125 @@
from typing import Any
import cv2
import numpy as np
from numpy.typing import NDArray
from PIL import Image
from rapidocr.ch_ppocr_det.utils import DBPostProcess
from rapidocr.inference_engine.base import FileInfo, InferSession
from rapidocr.utils.download_file import DownloadFile, DownloadFileInput
from rapidocr.utils.typings import EngineType, LangDet, OCRVersion, TaskType
from rapidocr.utils.typings import ModelType as RapidModelType
from immich_ml.config import log
from immich_ml.models.base import InferenceModel
from immich_ml.schemas import ModelFormat, ModelSession, ModelTask, ModelType
from immich_ml.sessions.ort import OrtSession
from .schemas import TextDetectionOutput
class TextDetector(InferenceModel):
depends = []
identity = (ModelType.DETECTION, ModelTask.OCR)
def __init__(self, model_name: str, **model_kwargs: Any) -> None:
super().__init__(model_name.split("__")[-1], **model_kwargs, model_format=ModelFormat.ONNX)
self.max_resolution = 736
self.mean = np.array([0.5, 0.5, 0.5], dtype=np.float32)
self.std_inv = np.float32(1.0) / (np.array([0.5, 0.5, 0.5], dtype=np.float32) * 255.0)
self._empty: TextDetectionOutput = {
"boxes": np.empty(0, dtype=np.float32),
"scores": np.empty(0, dtype=np.float32),
}
self.postprocess = DBPostProcess(
thresh=0.3,
box_thresh=model_kwargs.get("minScore", 0.5),
max_candidates=1000,
unclip_ratio=1.6,
use_dilation=True,
score_mode="fast",
)
def _download(self) -> None:
model_info = InferSession.get_model_url(
FileInfo(
engine_type=EngineType.ONNXRUNTIME,
ocr_version=OCRVersion.PPOCRV5,
task_type=TaskType.DET,
lang_type=LangDet.CH,
model_type=RapidModelType.MOBILE if "mobile" in self.model_name else RapidModelType.SERVER,
)
)
download_params = DownloadFileInput(
file_url=model_info["model_dir"],
sha256=model_info["SHA256"],
save_path=self.model_path,
logger=log,
)
DownloadFile.run(download_params)
def _load(self) -> ModelSession:
# TODO: support other runtime sessions
return OrtSession(self.model_path)
# partly adapted from RapidOCR
def _predict(self, inputs: Image.Image) -> TextDetectionOutput:
w, h = inputs.size
if w < 32 or h < 32:
return self._empty
out = self.session.run(None, {"x": self._transform(inputs)})[0]
boxes, scores = self.postprocess(out, (h, w))
if len(boxes) == 0:
return self._empty
return {
"boxes": self.sorted_boxes(boxes),
"scores": np.array(scores, dtype=np.float32),
}
# adapted from RapidOCR
def _transform(self, img: Image.Image) -> NDArray[np.float32]:
if img.height < img.width:
ratio = float(self.max_resolution) / img.height
else:
ratio = float(self.max_resolution) / img.width
ratio = min(ratio, 1.0)
resize_h = int(img.height * ratio)
resize_w = int(img.width * ratio)
resize_h = int(round(resize_h / 32) * 32)
resize_w = int(round(resize_w / 32) * 32)
resized_img = img.resize((int(resize_w), int(resize_h)), resample=Image.Resampling.LANCZOS)
img_np: NDArray[np.float32] = cv2.cvtColor(np.array(resized_img, dtype=np.float32), cv2.COLOR_RGB2BGR) # type: ignore
img_np -= self.mean
img_np *= self.std_inv
img_np = np.transpose(img_np, (2, 0, 1))
return np.expand_dims(img_np, axis=0)
def sorted_boxes(self, dt_boxes: NDArray[np.float32]) -> NDArray[np.float32]:
if len(dt_boxes) == 0:
return dt_boxes
# Sort by y, then identify lines, then sort by (line, x)
y_order = np.argsort(dt_boxes[:, 0, 1], kind="stable")
sorted_y = dt_boxes[y_order, 0, 1]
line_ids = np.empty(len(dt_boxes), dtype=np.int32)
line_ids[0] = 0
np.cumsum(np.abs(np.diff(sorted_y)) >= 10, out=line_ids[1:])
# Create composite sort key for final ordering
# Shift line_ids by large factor, add x for tie-breaking
sort_key = line_ids[y_order] * 1e6 + dt_boxes[y_order, 0, 0]
final_order = np.argsort(sort_key, kind="stable")
sorted_boxes: NDArray[np.float32] = dt_boxes[y_order[final_order]]
return sorted_boxes
def configure(self, **kwargs: Any) -> None:
if (max_resolution := kwargs.get("maxResolution")) is not None:
self.max_resolution = max_resolution
if (min_score := kwargs.get("minScore")) is not None:
self.postprocess.box_thresh = min_score
if (score_mode := kwargs.get("scoreMode")) is not None:
self.postprocess.score_mode = score_mode

View file

@ -0,0 +1,153 @@
from typing import Any
import numpy as np
from numpy.typing import NDArray
from PIL import Image
from rapidocr.ch_ppocr_rec import TextRecInput
from rapidocr.ch_ppocr_rec import TextRecognizer as RapidTextRecognizer
from rapidocr.inference_engine.base import FileInfo, InferSession
from rapidocr.utils.download_file import DownloadFile, DownloadFileInput
from rapidocr.utils.typings import EngineType, LangRec, OCRVersion, TaskType
from rapidocr.utils.typings import ModelType as RapidModelType
from rapidocr.utils.vis_res import VisRes
from immich_ml.config import log, settings
from immich_ml.models.base import InferenceModel
from immich_ml.models.transforms import pil_to_cv2
from immich_ml.schemas import ModelFormat, ModelSession, ModelTask, ModelType
from immich_ml.sessions.ort import OrtSession
from .schemas import OcrOptions, TextDetectionOutput, TextRecognitionOutput
class TextRecognizer(InferenceModel):
depends = [(ModelType.DETECTION, ModelTask.OCR)]
identity = (ModelType.RECOGNITION, ModelTask.OCR)
def __init__(self, model_name: str, **model_kwargs: Any) -> None:
self.language = LangRec[model_name.split("__")[0]] if "__" in model_name else LangRec.CH
self.min_score = model_kwargs.get("minScore", 0.9)
self._empty: TextRecognitionOutput = {
"box": np.empty(0, dtype=np.float32),
"boxScore": np.empty(0, dtype=np.float32),
"text": [],
"textScore": np.empty(0, dtype=np.float32),
}
VisRes.__init__ = lambda self, **kwargs: None # pyright: ignore[reportAttributeAccessIssue]
super().__init__(model_name, **model_kwargs, model_format=ModelFormat.ONNX)
def _download(self) -> None:
model_info = InferSession.get_model_url(
FileInfo(
engine_type=EngineType.ONNXRUNTIME,
ocr_version=OCRVersion.PPOCRV5,
task_type=TaskType.REC,
lang_type=self.language,
model_type=RapidModelType.MOBILE if "mobile" in self.model_name else RapidModelType.SERVER,
)
)
download_params = DownloadFileInput(
file_url=model_info["model_dir"],
sha256=model_info["SHA256"],
save_path=self.model_path,
logger=log,
)
DownloadFile.run(download_params)
def _load(self) -> ModelSession:
# TODO: support other runtimes
session = OrtSession(self.model_path)
self.model = RapidTextRecognizer(
OcrOptions(
session=session.session,
rec_batch_num=settings.max_batch_size.text_recognition if settings.max_batch_size is not None else 6,
rec_img_shape=(3, 48, 320),
lang_type=self.language,
)
)
return session
def _predict(self, img: Image.Image, texts: TextDetectionOutput) -> TextRecognitionOutput:
boxes, box_scores = texts["boxes"], texts["scores"]
if boxes.shape[0] == 0:
return self._empty
rec = self.model(TextRecInput(img=self.get_crop_img_list(img, boxes)))
if rec.txts is None:
return self._empty
boxes[:, :, 0] /= img.width
boxes[:, :, 1] /= img.height
text_scores = np.array(rec.scores)
valid_text_score_idx = text_scores > self.min_score
valid_score_idx_list = valid_text_score_idx.tolist()
return {
"box": boxes.reshape(-1, 8)[valid_text_score_idx].reshape(-1),
"text": [rec.txts[i] for i in range(len(rec.txts)) if valid_score_idx_list[i]],
"boxScore": box_scores[valid_text_score_idx],
"textScore": text_scores[valid_text_score_idx],
}
def get_crop_img_list(self, img: Image.Image, boxes: NDArray[np.float32]) -> list[NDArray[np.uint8]]:
img_crop_width = np.maximum(
np.linalg.norm(boxes[:, 1] - boxes[:, 0], axis=1), np.linalg.norm(boxes[:, 2] - boxes[:, 3], axis=1)
).astype(np.int32)
img_crop_height = np.maximum(
np.linalg.norm(boxes[:, 0] - boxes[:, 3], axis=1), np.linalg.norm(boxes[:, 1] - boxes[:, 2], axis=1)
).astype(np.int32)
pts_std = np.zeros((img_crop_width.shape[0], 4, 2), dtype=np.float32)
pts_std[:, 1:3, 0] = img_crop_width[:, None]
pts_std[:, 2:4, 1] = img_crop_height[:, None]
img_crop_sizes = np.stack([img_crop_width, img_crop_height], axis=1)
all_coeffs = self._get_perspective_transform(pts_std, boxes)
imgs: list[NDArray[np.uint8]] = []
for coeffs, dst_size in zip(all_coeffs, img_crop_sizes):
dst_img = img.transform(
size=tuple(dst_size),
method=Image.Transform.PERSPECTIVE,
data=tuple(coeffs),
resample=Image.Resampling.BICUBIC,
)
dst_width, dst_height = dst_img.size
if dst_height * 1.0 / dst_width >= 1.5:
dst_img = dst_img.rotate(90, expand=True)
imgs.append(pil_to_cv2(dst_img))
return imgs
def _get_perspective_transform(self, src: NDArray[np.float32], dst: NDArray[np.float32]) -> NDArray[np.float32]:
N = src.shape[0]
x, y = src[:, :, 0], src[:, :, 1]
u, v = dst[:, :, 0], dst[:, :, 1]
A = np.zeros((N, 8, 9), dtype=np.float32)
# Fill even rows (0, 2, 4, 6): [x, y, 1, 0, 0, 0, -u*x, -u*y, -u]
A[:, ::2, 0] = x
A[:, ::2, 1] = y
A[:, ::2, 2] = 1
A[:, ::2, 6] = -u * x
A[:, ::2, 7] = -u * y
A[:, ::2, 8] = -u
# Fill odd rows (1, 3, 5, 7): [0, 0, 0, x, y, 1, -v*x, -v*y, -v]
A[:, 1::2, 3] = x
A[:, 1::2, 4] = y
A[:, 1::2, 5] = 1
A[:, 1::2, 6] = -v * x
A[:, 1::2, 7] = -v * y
A[:, 1::2, 8] = -v
# Solve using SVD for all matrices at once
_, _, Vt = np.linalg.svd(A)
H = Vt[:, -1, :].reshape(N, 3, 3)
H = H / H[:, 2:3, 2:3]
# Extract the 8 coefficients for each transformation
return np.column_stack(
[H[:, 0, 0], H[:, 0, 1], H[:, 0, 2], H[:, 1, 0], H[:, 1, 1], H[:, 1, 2], H[:, 2, 0], H[:, 2, 1]]
) # pyright: ignore[reportReturnType]
def configure(self, **kwargs: Any) -> None:
self.min_score = kwargs.get("minScore", self.min_score)

View file

@ -0,0 +1,27 @@
from typing import Any, Iterable
import numpy as np
import numpy.typing as npt
from rapidocr.utils.typings import EngineType, LangRec
from typing_extensions import TypedDict
class TextDetectionOutput(TypedDict):
boxes: npt.NDArray[np.float32]
scores: npt.NDArray[np.float32]
class TextRecognitionOutput(TypedDict):
box: npt.NDArray[np.float32]
boxScore: npt.NDArray[np.float32]
text: Iterable[str]
textScore: npt.NDArray[np.float32]
# RapidOCR expects `engine_type`, `lang_type`, and `font_path` to be attributes
class OcrOptions(dict[str, Any]):
def __init__(self, lang_type: LangRec | None = None, **options: Any) -> None:
super().__init__(**options)
self.engine_type = EngineType.ONNXRUNTIME
self.lang_type = lang_type
self.font_path = None

View file

@ -0,0 +1,80 @@
import string
from io import BytesIO
from typing import IO
import cv2
import numpy as np
import orjson
from numpy.typing import NDArray
from PIL import Image
_PIL_RESAMPLING_METHODS = {resampling.name.lower(): resampling for resampling in Image.Resampling}
_PUNCTUATION_TRANS = str.maketrans("", "", string.punctuation)
def resize_pil(img: Image.Image, size: int) -> Image.Image:
if img.width < img.height:
return img.resize((size, int((img.height / img.width) * size)), resample=Image.Resampling.BICUBIC)
else:
return img.resize((int((img.width / img.height) * size), size), resample=Image.Resampling.BICUBIC)
# https://stackoverflow.com/a/60883103
def crop_pil(img: Image.Image, size: int) -> Image.Image:
left = int((img.size[0] / 2) - (size / 2))
upper = int((img.size[1] / 2) - (size / 2))
right = left + size
lower = upper + size
return img.crop((left, upper, right, lower))
def to_numpy(img: Image.Image) -> NDArray[np.float32]:
return np.asarray(img if img.mode == "RGB" else img.convert("RGB"), dtype=np.float32) / 255.0
def normalize(
img: NDArray[np.float32], mean: float | NDArray[np.float32], std: float | NDArray[np.float32]
) -> NDArray[np.float32]:
return (img - mean) / std
def get_pil_resampling(resample: str) -> Image.Resampling:
return _PIL_RESAMPLING_METHODS[resample.lower()]
def pil_to_cv2(image: Image.Image) -> NDArray[np.uint8]:
return cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR) # type: ignore
def decode_pil(image_bytes: bytes | IO[bytes] | Image.Image) -> Image.Image:
if isinstance(image_bytes, Image.Image):
return image_bytes
image: Image.Image = Image.open(BytesIO(image_bytes) if isinstance(image_bytes, bytes) else image_bytes)
image.load()
if not image.mode == "RGB":
image = image.convert("RGB")
return image
def decode_cv2(image_bytes: NDArray[np.uint8] | bytes | Image.Image) -> NDArray[np.uint8]:
match image_bytes:
case bytes() | memoryview() | bytearray():
return pil_to_cv2(decode_pil(image_bytes)) # pillow is much faster than cv2
case Image.Image():
return pil_to_cv2(image_bytes)
case _:
return image_bytes
def clean_text(text: str, canonicalize: bool = False) -> str:
text = " ".join(text.split())
if canonicalize:
text = text.translate(_PUNCTUATION_TRANS).lower()
return text
# this allows the client to use the array as a string without deserializing only to serialize back to a string
# TODO: use this in a less invasive way
def serialize_np_array(arr: NDArray[np.float32]) -> str:
return orjson.dumps(arr, option=orjson.OPT_SERIALIZE_NUMPY).decode()