Source Code added

Fr4nz D13trich 2026-02-02 15:06:40 +01:00
parent 800376eafd
commit 9efa9bc6dd
3912 changed files with 754770 additions and 2 deletions

View file

@@ -0,0 +1,57 @@
import os
import signal
import subprocess
from ipaddress import ip_address
from pathlib import Path
from .config import log, non_prefixed_settings, settings
if source_ref := os.getenv("IMMICH_SOURCE_REF"):
log.info(f"Initializing Immich ML [{source_ref}]")
else:
log.info("Initializing Immich ML")
module_dir = Path(__file__).parent
def is_ipv6(host: str) -> bool:
try:
return ip_address(host).version == 6
except ValueError:
return False
bind_host = non_prefixed_settings.immich_host
if is_ipv6(bind_host):
bind_host = f"[{bind_host}]"
bind_address = f"{bind_host}:{non_prefixed_settings.immich_port}"
try:
with subprocess.Popen(
[
"python",
"-m",
"gunicorn",
"immich_ml.main:app",
"-k",
"immich_ml.config.CustomUvicornWorker",
"-c",
module_dir / "gunicorn_conf.py",
"-b",
bind_address,
"-w",
str(settings.workers),
"-t",
str(settings.worker_timeout),
"--log-config-json",
module_dir / "log_conf.json",
"--keep-alive",
str(settings.http_keepalive_timeout_s),
"--graceful-timeout",
"10",
],
) as cmd:
cmd.wait()
except KeyboardInterrupt:
cmd.send_signal(signal.SIGINT)
exit(cmd.returncode)
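
A small sketch of the bind-address handling above, with hypothetical values: IPv6 hosts are wrapped in brackets before the port is appended, IPv4 hosts are left as-is.

assert is_ipv6("::") and not is_ipv6("0.0.0.0")

host = "::"  # the default immich_host
if is_ipv6(host):
    host = f"[{host}]"
assert f"{host}:3003" == "[::]:3003"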

View file

@@ -0,0 +1,165 @@
import concurrent.futures
import logging
import os
import sys
from pathlib import Path
from socket import socket
from gunicorn.arbiter import Arbiter
from pydantic import BaseModel
from pydantic_settings import BaseSettings, SettingsConfigDict
from rich.console import Console
from rich.logging import RichHandler
from uvicorn import Server
from uvicorn.workers import UvicornWorker
from .schemas import ModelPrecision
class ClipSettings(BaseModel):
textual: str | None = None
visual: str | None = None
class FacialRecognitionSettings(BaseModel):
recognition: str | None = None
detection: str | None = None
class OcrSettings(BaseModel):
recognition: str | None = None
detection: str | None = None
class PreloadModelData(BaseModel):
clip_fallback: str | None = os.getenv("MACHINE_LEARNING_PRELOAD__CLIP", None)
facial_recognition_fallback: str | None = os.getenv("MACHINE_LEARNING_PRELOAD__FACIAL_RECOGNITION", None)
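# Map the deprecated single preload env variables onto the nested form before settings are parsed.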
if clip_fallback is not None:
os.environ["MACHINE_LEARNING_PRELOAD__CLIP__TEXTUAL"] = clip_fallback
os.environ["MACHINE_LEARNING_PRELOAD__CLIP__VISUAL"] = clip_fallback
del os.environ["MACHINE_LEARNING_PRELOAD__CLIP"]
if facial_recognition_fallback is not None:
os.environ["MACHINE_LEARNING_PRELOAD__FACIAL_RECOGNITION__RECOGNITION"] = facial_recognition_fallback
os.environ["MACHINE_LEARNING_PRELOAD__FACIAL_RECOGNITION__DETECTION"] = facial_recognition_fallback
del os.environ["MACHINE_LEARNING_PRELOAD__FACIAL_RECOGNITION"]
clip: ClipSettings = ClipSettings()
facial_recognition: FacialRecognitionSettings = FacialRecognitionSettings()
ocr: OcrSettings = OcrSettings()
class MaxBatchSize(BaseModel):
facial_recognition: int | None = None
text_recognition: int | None = None
class Settings(BaseSettings):
model_config = SettingsConfigDict(
env_prefix="MACHINE_LEARNING_",
case_sensitive=False,
env_nested_delimiter="__",
protected_namespaces=("settings_",),
)
cache_folder: Path = (Path.home() / ".cache" / "immich_ml").resolve()
model_ttl: int = 300
model_ttl_poll_s: int = 10
workers: int = 1
worker_timeout: int = 300
http_keepalive_timeout_s: int = 2
test_full: bool = False
request_threads: int = os.cpu_count() or 4
model_inter_op_threads: int = 0
model_intra_op_threads: int = 0
model_arena: bool = True
ann: bool = True
ann_fp16_turbo: bool = False
ann_tuning_level: int = 2
rknn: bool = True
rknn_threads: int = 1
preload: PreloadModelData | None = None
max_batch_size: MaxBatchSize | None = None
openvino_precision: ModelPrecision = ModelPrecision.FP32
@property
def device_id(self) -> str:
return os.environ.get("MACHINE_LEARNING_DEVICE_ID", "0")
class NonPrefixedSettings(BaseSettings):
model_config = SettingsConfigDict(case_sensitive=False)
immich_host: str = "[::]"
immich_port: int = 3003
immich_log_level: str = "info"
no_color: bool = False
_clean_name = str.maketrans(":\\/", "___", ".")
def clean_name(model_name: str) -> str:
return model_name.split("/")[-1].translate(_clean_name)
LOG_LEVELS: dict[str, int] = {
"critical": logging.ERROR,
"error": logging.ERROR,
"warning": logging.WARNING,
"warn": logging.WARNING,
"info": logging.INFO,
"log": logging.INFO,
"debug": logging.DEBUG,
"verbose": logging.DEBUG,
}
settings = Settings()
non_prefixed_settings = NonPrefixedSettings()
LOG_LEVEL = LOG_LEVELS.get(non_prefixed_settings.immich_log_level.lower(), logging.INFO)
class CustomRichHandler(RichHandler):
def __init__(self) -> None:
console = Console(color_system="standard", no_color=non_prefixed_settings.no_color)
self.excluded = ["uvicorn", "starlette", "fastapi"]
super().__init__(
show_path=False,
omit_repeated_times=False,
console=console,
rich_tracebacks=True,
tracebacks_suppress=[*self.excluded, concurrent.futures],
tracebacks_show_locals=LOG_LEVEL == logging.DEBUG,
)
# hack to exclude certain modules from rich tracebacks
def emit(self, record: logging.LogRecord) -> None:
if record.exc_info is not None:
tb = record.exc_info[2]
while tb is not None:
if any(excluded in tb.tb_frame.f_code.co_filename for excluded in self.excluded):
tb.tb_frame.f_locals["_rich_traceback_omit"] = True
tb = tb.tb_next
return super().emit(record)
log = logging.getLogger("ml.log")
log.setLevel(LOG_LEVEL)
# patches this issue https://github.com/encode/uvicorn/discussions/1803
class CustomUvicornServer(Server):
async def shutdown(self, sockets: list[socket] | None = None) -> None:
for sock in sockets or []:
sock.close()
await super().shutdown()
class CustomUvicornWorker(UvicornWorker):
async def _serve(self) -> None:
self.config.app = self.wsgi
server = CustomUvicornServer(config=self.config)
self._install_sigquit_handler()
await server.serve(sockets=self.sockets)
if not server.started:
sys.exit(Arbiter.WORKER_BOOT_ERROR)
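
A minimal sketch of how the nested settings resolve, assuming pydantic-settings' env_prefix + env_nested_delimiter semantics as configured above (the model name is hypothetical):

import os

os.environ["MACHINE_LEARNING_PRELOAD__CLIP__TEXTUAL"] = "ViT-B-32__openai"
s = Settings()
assert s.preload is not None and s.preload.clip.textual == "ViT-B-32__openai"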

View file

@@ -0,0 +1,12 @@
import os
from gunicorn.arbiter import Arbiter
from gunicorn.workers.base import Worker
device_ids = os.environ.get("MACHINE_LEARNING_DEVICE_IDS", "0").replace(" ", "").split(",")
env = os.environ
# Round-robin device assignment for each worker
def pre_fork(arbiter: Arbiter, _: Worker) -> None:
env["MACHINE_LEARNING_DEVICE_ID"] = device_ids[len(arbiter.WORKERS) % len(device_ids)]

View file

@@ -0,0 +1,21 @@
{
"version": 1,
"disable_existing_loggers": false,
"handlers": {
"console": {
"class": "immich_ml.config.CustomRichHandler"
}
},
"loggers": {
"gunicorn.error": {
"handlers": [
"console"
]
}
},
"root": {
"handlers": [
"console"
]
}
}

View file

@@ -0,0 +1,272 @@
import asyncio
import gc
import os
import signal
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager
from functools import partial
from typing import Any, AsyncGenerator, Callable, Iterator
from zipfile import BadZipFile
import orjson
from fastapi import Depends, FastAPI, File, Form, HTTPException
from fastapi.responses import ORJSONResponse, PlainTextResponse
from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile
from PIL.Image import Image
from pydantic import ValidationError
from starlette.formparsers import MultiPartParser
from immich_ml.models import get_model_deps
from immich_ml.models.base import InferenceModel
from immich_ml.models.transforms import decode_pil
from .config import PreloadModelData, log, settings
from .models.cache import ModelCache
from .schemas import (
InferenceEntries,
InferenceEntry,
InferenceResponse,
ModelFormat,
ModelIdentity,
ModelTask,
ModelType,
PipelineRequest,
T,
)
MultiPartParser.spool_max_size = 2**26 # spools to disk if payload is 64 MiB or larger
model_cache = ModelCache(revalidate=settings.model_ttl > 0)
thread_pool: ThreadPoolExecutor | None = None
lock = threading.Lock()
active_requests = 0
last_called: float | None = None
@asynccontextmanager
async def lifespan(_: FastAPI) -> AsyncGenerator[None, None]:
global thread_pool
log.info(
(
"Created in-memory cache with unloading "
f"{f'after {settings.model_ttl}s of inactivity' if settings.model_ttl > 0 else 'disabled'}."
)
)
try:
if settings.request_threads > 0:
# asyncio is a huge bottleneck for performance, so we use a thread pool to run blocking code
thread_pool = ThreadPoolExecutor(settings.request_threads)
log.info(f"Initialized request thread pool with {settings.request_threads} threads.")
if settings.model_ttl > 0 and settings.model_ttl_poll_s > 0:
asyncio.ensure_future(idle_shutdown_task())
if settings.preload is not None:
await preload_models(settings.preload)
yield
finally:
log.handlers.clear()
for model in model_cache.cache._cache.values():
del model
if thread_pool is not None:
thread_pool.shutdown()
gc.collect()
async def preload_models(preload: PreloadModelData) -> None:
log.info(f"Preloading models: clip:{preload.clip} facial_recognition:{preload.facial_recognition}")
async def load_models(model_string: str, model_type: ModelType, model_task: ModelTask) -> None:
for model_name in model_string.split(","):
model_name = model_name.strip()
model = await model_cache.get(model_name, model_type, model_task)
await load(model)
if preload.clip.textual is not None:
await load_models(preload.clip.textual, ModelType.TEXTUAL, ModelTask.SEARCH)
if preload.clip.visual is not None:
await load_models(preload.clip.visual, ModelType.VISUAL, ModelTask.SEARCH)
if preload.facial_recognition.detection is not None:
await load_models(
preload.facial_recognition.detection,
ModelType.DETECTION,
ModelTask.FACIAL_RECOGNITION,
)
if preload.facial_recognition.recognition is not None:
await load_models(
preload.facial_recognition.recognition,
ModelType.RECOGNITION,
ModelTask.FACIAL_RECOGNITION,
)
if preload.ocr.detection is not None:
await load_models(
preload.ocr.detection,
ModelType.DETECTION,
ModelTask.OCR,
)
if preload.ocr.recognition is not None:
await load_models(
preload.ocr.recognition,
ModelType.RECOGNITION,
ModelTask.OCR,
)
if preload.clip_fallback is not None:
log.warning(
"Deprecated env variable: 'MACHINE_LEARNING_PRELOAD__CLIP'. "
"Use 'MACHINE_LEARNING_PRELOAD__CLIP__TEXTUAL' and "
"'MACHINE_LEARNING_PRELOAD__CLIP__VISUAL' instead."
)
if preload.facial_recognition_fallback is not None:
log.warning(
"Deprecated env variable: 'MACHINE_LEARNING_PRELOAD__FACIAL_RECOGNITION'. "
"Use 'MACHINE_LEARNING_PRELOAD__FACIAL_RECOGNITION__DETECTION' and "
"'MACHINE_LEARNING_PRELOAD__FACIAL_RECOGNITION__RECOGNITION' instead."
)
def update_state() -> Iterator[None]:
global active_requests, last_called
active_requests += 1
last_called = time.time()
try:
yield
finally:
active_requests -= 1
def get_entries(entries: str = Form()) -> InferenceEntries:
try:
request: PipelineRequest = orjson.loads(entries)
without_deps: list[InferenceEntry] = []
with_deps: list[InferenceEntry] = []
for task, types in request.items():
for type, entry in types.items():
parsed: InferenceEntry = {
"name": entry["modelName"],
"task": task,
"type": type,
"options": entry.get("options", {}),
}
dep = get_model_deps(parsed["name"], type, task)
(with_deps if dep else without_deps).append(parsed)
return without_deps, with_deps
except (orjson.JSONDecodeError, ValidationError, KeyError, AttributeError) as e:
log.error(f"Invalid request format: {e}")
raise HTTPException(422, "Invalid request format.")
app = FastAPI(lifespan=lifespan)
@app.get("/")
async def root() -> ORJSONResponse:
return ORJSONResponse({"message": "Immich ML"})
@app.get("/ping")
def ping() -> PlainTextResponse:
return PlainTextResponse("pong")
@app.post("/predict", dependencies=[Depends(update_state)])
async def predict(
entries: InferenceEntries = Depends(get_entries),
image: bytes | None = File(default=None),
text: str | None = Form(default=None),
) -> Any:
if image is not None:
inputs: Image | str = await run(lambda: decode_pil(image))
elif text is not None:
inputs = text
else:
raise HTTPException(400, "Either image or text must be provided")
response = await run_inference(inputs, entries)
return ORJSONResponse(response)
async def run_inference(payload: Image | str, entries: InferenceEntries) -> InferenceResponse:
outputs: dict[ModelIdentity, Any] = {}
response: InferenceResponse = {}
async def _run_inference(entry: InferenceEntry) -> None:
model = await model_cache.get(
entry["name"], entry["type"], entry["task"], ttl=settings.model_ttl, **entry["options"]
)
inputs = [payload]
for dep in model.depends:
try:
inputs.append(outputs[dep])
except KeyError:
message = f"Task {entry['task']} of type {entry['type']} depends on output of {dep}"
raise HTTPException(400, message)
model = await load(model)
output = await run(model.predict, *inputs, **entry["options"])
outputs[model.identity] = output
response[entry["task"]] = output
without_deps, with_deps = entries
await asyncio.gather(*[_run_inference(entry) for entry in without_deps])
if with_deps:
await asyncio.gather(*[_run_inference(entry) for entry in with_deps])
if isinstance(payload, Image):
response["imageHeight"], response["imageWidth"] = payload.height, payload.width
return response
async def run(func: Callable[..., T], *args: Any, **kwargs: Any) -> T:
if thread_pool is None:
return func(*args, **kwargs)
partial_func = partial(func, *args, **kwargs)
return await asyncio.get_running_loop().run_in_executor(thread_pool, partial_func)
async def load(model: InferenceModel) -> InferenceModel:
if model.loaded:
return model
def _load(model: InferenceModel) -> InferenceModel:
if model.load_attempts > 1:
raise HTTPException(500, f"Failed to load model '{model.model_name}'")
with lock:
try:
model.load()
except FileNotFoundError as e:
if model.model_format == ModelFormat.ONNX:
raise e
log.warning(
f"{model.model_format.upper()} is available, but model '{model.model_name}' does not support it.",
exc_info=e,
)
model.model_format = ModelFormat.ONNX
model.load()
return model
try:
return await run(_load, model)
except (OSError, InvalidProtobuf, BadZipFile, NoSuchFile):
log.warning(f"Failed to load {model.model_type.replace('_', ' ')} model '{model.model_name}'. Clearing cache.")
model.clear_cache()
return await run(_load, model)
async def idle_shutdown_task() -> None:
while True:
if (
last_called is not None
and not active_requests
and not lock.locked()
and time.time() - last_called > settings.model_ttl
):
log.info("Shutting down due to inactivity.")
os.kill(os.getpid(), signal.SIGINT)
break
await asyncio.sleep(settings.model_ttl_poll_s)
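
A hedged client sketch for the /predict endpoint above, not part of this commit: requests is an assumed stand-in for any multipart-capable HTTP client, and the model name, file, and port are hypothetical.

import orjson
import requests  # assumption: not a dependency of this service

entries = {"clip": {"visual": {"modelName": "ViT-B-32__openai", "options": {}}}}
with open("photo.jpg", "rb") as f:
    resp = requests.post(
        "http://localhost:3003/predict",
        data={"entries": orjson.dumps(entries).decode()},
        files={"image": f},
    )
print(resp.json())  # e.g. {"clip": "[...]", "imageHeight": ..., "imageWidth": ...}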

View file

@@ -0,0 +1,48 @@
from typing import Any
from immich_ml.models.base import InferenceModel
from immich_ml.models.clip.textual import MClipTextualEncoder, OpenClipTextualEncoder
from immich_ml.models.clip.visual import OpenClipVisualEncoder
from immich_ml.models.ocr.detection import TextDetector
from immich_ml.models.ocr.recognition import TextRecognizer
from immich_ml.schemas import ModelSource, ModelTask, ModelType
from .constants import get_model_source
from .facial_recognition.detection import FaceDetector
from .facial_recognition.recognition import FaceRecognizer
def get_model_class(model_name: str, model_type: ModelType, model_task: ModelTask) -> type[InferenceModel]:
source = get_model_source(model_name)
match source, model_type, model_task:
case ModelSource.OPENCLIP | ModelSource.MCLIP, ModelType.VISUAL, ModelTask.SEARCH:
return OpenClipVisualEncoder
case ModelSource.OPENCLIP, ModelType.TEXTUAL, ModelTask.SEARCH:
return OpenClipTextualEncoder
case ModelSource.MCLIP, ModelType.TEXTUAL, ModelTask.SEARCH:
return MClipTextualEncoder
case ModelSource.INSIGHTFACE, ModelType.DETECTION, ModelTask.FACIAL_RECOGNITION:
return FaceDetector
case ModelSource.INSIGHTFACE, ModelType.RECOGNITION, ModelTask.FACIAL_RECOGNITION:
return FaceRecognizer
case ModelSource.PADDLE, ModelType.DETECTION, ModelTask.OCR:
return TextDetector
case ModelSource.PADDLE, ModelType.RECOGNITION, ModelTask.OCR:
return TextRecognizer
case _:
raise ValueError(f"Unknown model combination: {source}, {model_type}, {model_task}")
def from_model_type(model_name: str, model_type: ModelType, model_task: ModelTask, **kwargs: Any) -> InferenceModel:
return get_model_class(model_name, model_type, model_task)(model_name, **kwargs)
def get_model_deps(model_name: str, model_type: ModelType, model_task: ModelTask) -> list[tuple[ModelType, ModelTask]]:
return get_model_class(model_name, model_type, model_task).depends
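
A quick dispatch sketch for get_model_class above (both model names are real entries in constants.py):

assert get_model_class("ViT-B-32__openai", ModelType.VISUAL, ModelTask.SEARCH) is OpenClipVisualEncoder
assert get_model_class("buffalo_l", ModelType.DETECTION, ModelTask.FACIAL_RECOGNITION) is FaceDetector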

View file

@@ -0,0 +1,176 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from pathlib import Path
from shutil import rmtree
from typing import Any, ClassVar
from huggingface_hub import snapshot_download
import immich_ml.sessions.ann.loader
import immich_ml.sessions.rknn as rknn
from immich_ml.sessions.ort import OrtSession
from ..config import clean_name, log, settings
from ..schemas import ModelFormat, ModelIdentity, ModelSession, ModelTask, ModelType
from ..sessions.ann import AnnSession
class InferenceModel(ABC):
depends: ClassVar[list[ModelIdentity]]
identity: ClassVar[ModelIdentity]
def __init__(
self,
model_name: str,
cache_dir: Path | str | None = None,
model_format: ModelFormat | None = None,
session: ModelSession | None = None,
**model_kwargs: Any,
) -> None:
self.loaded = session is not None
self.load_attempts = 0
self.model_name = clean_name(model_name)
self.cache_dir = Path(cache_dir) if cache_dir is not None else self._cache_dir_default
self.model_format = model_format if model_format is not None else self._model_format_default
if session is not None:
self.session = session
def download(self) -> None:
if not self.cached:
model_type = self.model_type.replace("-", " ")
log.info(f"Downloading {model_type} model '{self.model_name}' to {self.model_path}. This may take a while.")
self._download()
def load(self) -> None:
if self.loaded:
return
self.load_attempts += 1
self.download()
attempt = f"Attempt #{self.load_attempts} to load" if self.load_attempts > 1 else "Loading"
log.info(f"{attempt} {self.model_type.replace('-', ' ')} model '{self.model_name}' to memory")
self.session = self._load()
self.loaded = True
def predict(self, *inputs: Any, **model_kwargs: Any) -> Any:
self.load()
if model_kwargs:
self.configure(**model_kwargs)
return self._predict(*inputs)
@abstractmethod
def _predict(self, *inputs: Any, **model_kwargs: Any) -> Any: ...
def configure(self, **kwargs: Any) -> None:
pass
def _download(self) -> None:
ignored_patterns: dict[ModelFormat, list[str]] = {
ModelFormat.ONNX: ["*.armnn", "*.rknn"],
ModelFormat.ARMNN: ["*.rknn"],
ModelFormat.RKNN: ["*.armnn"],
}
snapshot_download(
f"immich-app/{clean_name(self.model_name)}",
cache_dir=self.cache_dir,
local_dir=self.cache_dir,
ignore_patterns=ignored_patterns.get(self.model_format, []),
)
def _load(self) -> ModelSession:
return self._make_session(self.model_path)
def clear_cache(self) -> None:
if not self.cache_dir.exists():
log.warning(
f"Attempted to clear cache for model '{self.model_name}', but cache directory does not exist",
)
return
if not rmtree.avoids_symlink_attacks:
raise RuntimeError("Attempted to clear cache, but rmtree is not safe on this platform")
if self.cache_dir.is_dir():
rmtree(self.cache_dir)
log.info(f"Cleared cache directory for model '{self.model_name}'.")
else:
log.warning(
(
f"Encountered file instead of directory at cache path "
f"for '{self.model_name}'. Removing file and replacing with a directory."
),
)
self.cache_dir.unlink()
self.cache_dir.mkdir(parents=True, exist_ok=True)
def _make_session(self, model_path: Path) -> ModelSession:
if not model_path.is_file():
raise FileNotFoundError(f"Model file not found: {model_path}")
match model_path.suffix:
case ".armnn":
session: ModelSession = AnnSession(model_path)
case ".onnx":
session = OrtSession(model_path)
case ".rknn":
session = rknn.RknnSession(model_path)
case _:
raise ValueError(f"Unsupported model file type: {model_path.suffix}")
return session
def model_path_for_format(self, model_format: ModelFormat) -> Path:
model_path_prefix = rknn.model_prefix if model_format == ModelFormat.RKNN else None
if model_path_prefix:
return self.model_dir / model_path_prefix / f"model.{model_format}"
return self.model_dir / f"model.{model_format}"
@property
def model_dir(self) -> Path:
return self.cache_dir / self.model_type.value
@property
def model_path(self) -> Path:
return self.model_path_for_format(self.model_format)
@property
def model_task(self) -> ModelTask:
return self.identity[1]
@property
def model_type(self) -> ModelType:
return self.identity[0]
@property
def cache_dir(self) -> Path:
return self._cache_dir
@cache_dir.setter
def cache_dir(self, cache_dir: Path) -> None:
self._cache_dir = cache_dir
@property
def _cache_dir_default(self) -> Path:
return settings.cache_folder / self.model_task.value / self.model_name
@property
def cached(self) -> bool:
return self.model_path.is_file()
@property
def model_format(self) -> ModelFormat:
return self._model_format
@model_format.setter
def model_format(self, model_format: ModelFormat) -> None:
log.debug(f"Setting model format to {model_format}")
self._model_format = model_format
@property
def _model_format_default(self) -> ModelFormat:
if rknn.is_available:
return ModelFormat.RKNN
elif immich_ml.sessions.ann.loader.is_available and settings.ann:
return ModelFormat.ARMNN
else:
return ModelFormat.ONNX
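
For orientation, a sketch of the on-disk layout the path properties above produce with default settings (the model name is hypothetical, and ONNX is assumed to be the detected format):

from immich_ml.models.clip.visual import OpenClipVisualEncoder

m = OpenClipVisualEncoder("ViT-B-32__openai")
print(m.model_path)  # ~/.cache/immich_ml/clip/ViT-B-32__openai/visual/model.onnx
# RKNN models additionally live under a runtime-specific prefix directory (rknn.model_prefix).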

View file

@@ -0,0 +1,60 @@
from typing import Any
from aiocache.backends.memory import SimpleMemoryCache
from aiocache.lock import OptimisticLock
from aiocache.plugins import TimingPlugin
from immich_ml.models import from_model_type
from immich_ml.models.base import InferenceModel
from ..schemas import ModelTask, ModelType, has_profiling
class ModelCache:
"""Fetches a model from an in-memory cache, instantiating it if it's missing."""
def __init__(
self,
revalidate: bool = False,
timeout: int | None = None,
profiling: bool = False,
) -> None:
"""
Args:
revalidate: Resets TTL on cache hit. Useful to keep models in memory while active. Defaults to False.
timeout: Maximum allowed time for model to load. Disabled if None. Defaults to None.
profiling: Collects metrics for cache operations, adding slight overhead. Defaults to False.
"""
plugins = []
if profiling:
plugins.append(TimingPlugin())
self.should_revalidate = revalidate
self.cache = SimpleMemoryCache(timeout=timeout, plugins=plugins, namespace=None)
async def get(
self, model_name: str, model_type: ModelType, model_task: ModelTask, **model_kwargs: Any
) -> InferenceModel:
key = f"{model_name}{model_type}{model_task}"
async with OptimisticLock(self.cache, key) as lock:
model: InferenceModel | None = await self.cache.get(key)
if model is None:
model = from_model_type(model_name, model_type, model_task, **model_kwargs)
await lock.cas(model, ttl=model_kwargs.get("ttl", None))
elif self.should_revalidate:
await self.revalidate(key, model_kwargs.get("ttl", None))
return model
async def get_profiling(self) -> dict[str, float] | None:
if not has_profiling(self.cache):
return None
return self.cache.profiling
async def revalidate(self, key: str, ttl: int | None) -> None:
if ttl is not None and key in self.cache._handlers:
await self.cache.expire(key, ttl)
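
A minimal usage sketch (hypothetical model name), assuming an event loop is available; getting a model only instantiates it, and no weights are fetched until load() is called:

import asyncio

cache = ModelCache(revalidate=True)
model = asyncio.run(cache.get("ViT-B-32__openai", ModelType.VISUAL, ModelTask.SEARCH, ttl=300))
assert model.loaded is False  # instantiated, not yet loaded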

View file

@@ -0,0 +1,120 @@
import json
from abc import abstractmethod
from functools import cached_property
from pathlib import Path
from typing import Any
import numpy as np
from numpy.typing import NDArray
from tokenizers import Encoding, Tokenizer
from immich_ml.config import log
from immich_ml.models.base import InferenceModel
from immich_ml.models.constants import WEBLATE_TO_FLORES200
from immich_ml.models.transforms import clean_text, serialize_np_array
from immich_ml.schemas import ModelSession, ModelTask, ModelType
class BaseCLIPTextualEncoder(InferenceModel):
depends = []
identity = (ModelType.TEXTUAL, ModelTask.SEARCH)
def _predict(self, inputs: str, language: str | None = None) -> str:
tokens = self.tokenize(inputs, language=language)
res: NDArray[np.float32] = self.session.run(None, tokens)[0][0]
return serialize_np_array(res)
def _load(self) -> ModelSession:
session = super()._load()
log.debug(f"Loading tokenizer for CLIP model '{self.model_name}'")
self.tokenizer = self._load_tokenizer()
tokenizer_kwargs: dict[str, Any] | None = self.text_cfg.get("tokenizer_kwargs")
self.canonicalize = tokenizer_kwargs is not None and tokenizer_kwargs.get("clean") == "canonicalize"
self.is_nllb = self.model_name.startswith("nllb")
log.debug(f"Loaded tokenizer for CLIP model '{self.model_name}'")
return session
@abstractmethod
def _load_tokenizer(self) -> Tokenizer:
pass
@abstractmethod
def tokenize(self, text: str, language: str | None = None) -> dict[str, NDArray[np.int32]]:
pass
@property
def model_cfg_path(self) -> Path:
return self.cache_dir / "config.json"
@property
def tokenizer_file_path(self) -> Path:
return self.model_dir / "tokenizer.json"
@property
def tokenizer_cfg_path(self) -> Path:
return self.model_dir / "tokenizer_config.json"
@cached_property
def model_cfg(self) -> dict[str, Any]:
log.debug(f"Loading model config for CLIP model '{self.model_name}'")
model_cfg: dict[str, Any] = json.load(self.model_cfg_path.open())
log.debug(f"Loaded model config for CLIP model '{self.model_name}'")
return model_cfg
@property
def text_cfg(self) -> dict[str, Any]:
text_cfg: dict[str, Any] = self.model_cfg["text_cfg"]
return text_cfg
@cached_property
def tokenizer_file(self) -> dict[str, Any]:
log.debug(f"Loading tokenizer file for CLIP model '{self.model_name}'")
tokenizer_file: dict[str, Any] = json.load(self.tokenizer_file_path.open())
log.debug(f"Loaded tokenizer file for CLIP model '{self.model_name}'")
return tokenizer_file
@cached_property
def tokenizer_cfg(self) -> dict[str, Any]:
log.debug(f"Loading tokenizer config for CLIP model '{self.model_name}'")
tokenizer_cfg: dict[str, Any] = json.load(self.tokenizer_cfg_path.open())
log.debug(f"Loaded tokenizer config for CLIP model '{self.model_name}'")
return tokenizer_cfg
class OpenClipTextualEncoder(BaseCLIPTextualEncoder):
def _load_tokenizer(self) -> Tokenizer:
context_length: int = self.text_cfg.get("context_length", 77)
pad_token: str = self.tokenizer_cfg["pad_token"]
tokenizer: Tokenizer = Tokenizer.from_file(self.tokenizer_file_path.as_posix())
pad_id: int = tokenizer.token_to_id(pad_token)
tokenizer.enable_padding(length=context_length, pad_token=pad_token, pad_id=pad_id)
tokenizer.enable_truncation(max_length=context_length)
return tokenizer
def tokenize(self, text: str, language: str | None = None) -> dict[str, NDArray[np.int32]]:
text = clean_text(text, canonicalize=self.canonicalize)
if self.is_nllb and language is not None:
flores_code = WEBLATE_TO_FLORES200.get(language)
if flores_code is None:
no_country = language.split("-")[0]
flores_code = WEBLATE_TO_FLORES200.get(no_country)
if flores_code is None:
log.warning(f"Language '{language}' not found, defaulting to 'en'")
flores_code = "eng_Latn"
text = f"{flores_code}{text}"
tokens: Encoding = self.tokenizer.encode(text)
return {"text": np.array([tokens.ids], dtype=np.int32)}
class MClipTextualEncoder(OpenClipTextualEncoder):
def tokenize(self, text: str, language: str | None = None) -> dict[str, NDArray[np.int32]]:
text = clean_text(text, canonicalize=self.canonicalize)
tokens: Encoding = self.tokenizer.encode(text)
return {
"input_ids": np.array([tokens.ids], dtype=np.int32),
"attention_mask": np.array([tokens.attention_mask], dtype=np.int32),
}

View file

@@ -0,0 +1,77 @@
import json
from abc import abstractmethod
from functools import cached_property
from pathlib import Path
from typing import Any
import numpy as np
from numpy.typing import NDArray
from PIL import Image
from immich_ml.config import log
from immich_ml.models.base import InferenceModel
from immich_ml.models.transforms import (
crop_pil,
decode_pil,
get_pil_resampling,
normalize,
resize_pil,
serialize_np_array,
to_numpy,
)
from immich_ml.schemas import ModelSession, ModelTask, ModelType
class BaseCLIPVisualEncoder(InferenceModel):
depends = []
identity = (ModelType.VISUAL, ModelTask.SEARCH)
def _predict(self, inputs: Image.Image | bytes) -> str:
image = decode_pil(inputs)
res: NDArray[np.float32] = self.session.run(None, self.transform(image))[0][0]
return serialize_np_array(res)
@abstractmethod
def transform(self, image: Image.Image) -> dict[str, NDArray[np.float32]]:
pass
@property
def model_cfg_path(self) -> Path:
return self.cache_dir / "config.json"
@property
def preprocess_cfg_path(self) -> Path:
return self.model_dir / "preprocess_cfg.json"
@cached_property
def model_cfg(self) -> dict[str, Any]:
log.debug(f"Loading model config for CLIP model '{self.model_name}'")
model_cfg: dict[str, Any] = json.load(self.model_cfg_path.open())
log.debug(f"Loaded model config for CLIP model '{self.model_name}'")
return model_cfg
@cached_property
def preprocess_cfg(self) -> dict[str, Any]:
log.debug(f"Loading visual preprocessing config for CLIP model '{self.model_name}'")
preprocess_cfg: dict[str, Any] = json.load(self.preprocess_cfg_path.open())
log.debug(f"Loaded visual preprocessing config for CLIP model '{self.model_name}'")
return preprocess_cfg
class OpenClipVisualEncoder(BaseCLIPVisualEncoder):
def _load(self) -> ModelSession:
size: list[int] | int = self.preprocess_cfg["size"]
self.size = size[0] if isinstance(size, list) else size
self.resampling = get_pil_resampling(self.preprocess_cfg["interpolation"])
self.mean = np.array(self.preprocess_cfg["mean"], dtype=np.float32)
self.std = np.array(self.preprocess_cfg["std"], dtype=np.float32)
return super()._load()
def transform(self, image: Image.Image) -> dict[str, NDArray[np.float32]]:
image = resize_pil(image, self.size)
image = crop_pil(image, self.size)
image_np = to_numpy(image)
image_np = normalize(image_np, self.mean, self.std)
return {"image": np.expand_dims(image_np.transpose(2, 0, 1), 0)}

View file

@@ -0,0 +1,178 @@
from immich_ml.config import clean_name
from immich_ml.schemas import ModelSource
_OPENCLIP_MODELS = {
"RN101__openai",
"RN101__yfcc15m",
"RN50__cc12m",
"RN50__openai",
"RN50__yfcc15m",
"RN50x16__openai",
"RN50x4__openai",
"RN50x64__openai",
"ViT-B-16-SigLIP-256__webli",
"ViT-B-16-SigLIP-384__webli",
"ViT-B-16-SigLIP-512__webli",
"ViT-B-16-SigLIP-i18n-256__webli",
"ViT-B-16-SigLIP__webli",
"ViT-B-16-plus-240__laion400m_e31",
"ViT-B-16-plus-240__laion400m_e32",
"ViT-B-16__laion400m_e31",
"ViT-B-16__laion400m_e32",
"ViT-B-16__openai",
"ViT-B-32__laion2b-s34b-b79k",
"ViT-B-32__laion2b_e16",
"ViT-B-32__laion400m_e31",
"ViT-B-32__laion400m_e32",
"ViT-B-32__openai",
"ViT-H-14-378-quickgelu__dfn5b",
"ViT-H-14-quickgelu__dfn5b",
"ViT-H-14__laion2b-s32b-b79k",
"ViT-L-14-336__openai",
"ViT-L-14-quickgelu__dfn2b",
"ViT-L-14__laion2b-s32b-b82k",
"ViT-L-14__laion400m_e31",
"ViT-L-14__laion400m_e32",
"ViT-L-14__openai",
"ViT-L-16-SigLIP-256__webli",
"ViT-L-16-SigLIP-384__webli",
"ViT-SO400M-14-SigLIP-384__webli",
"ViT-g-14__laion2b-s12b-b42k",
"XLM-Roberta-Base-ViT-B-32__laion5b_s13b_b90k",
"XLM-Roberta-Large-ViT-H-14__frozen_laion5b_s13b_b90k",
"nllb-clip-base-siglip__mrl",
"nllb-clip-base-siglip__v1",
"nllb-clip-large-siglip__mrl",
"nllb-clip-large-siglip__v1",
"ViT-B-16-SigLIP2__webli",
"ViT-B-32-SigLIP2-256__webli",
"ViT-L-16-SigLIP2-256__webli",
"ViT-L-16-SigLIP2-384__webli",
"ViT-L-16-SigLIP2-512__webli",
"ViT-SO400M-14-SigLIP2-378__webli",
"ViT-SO400M-14-SigLIP2__webli",
"ViT-SO400M-16-SigLIP2-256__webli",
"ViT-SO400M-16-SigLIP2-384__webli",
"ViT-SO400M-16-SigLIP2-512__webli",
"ViT-gopt-16-SigLIP2-256__webli",
"ViT-gopt-16-SigLIP2-384__webli",
}
_MCLIP_MODELS = {
"LABSE-Vit-L-14",
"XLM-Roberta-Large-Vit-B-16Plus",
"XLM-Roberta-Large-Vit-B-32",
"XLM-Roberta-Large-Vit-L-14",
}
_INSIGHTFACE_MODELS = {
"antelopev2",
"buffalo_s",
"buffalo_m",
"buffalo_l",
}
_PADDLE_MODELS = {
"PP-OCRv5_server",
"PP-OCRv5_mobile",
"CH__PP-OCRv5_server",
"CH__PP-OCRv5_mobile",
"EL__PP-OCRv5_mobile",
"EN__PP-OCRv5_mobile",
"ESLAV__PP-OCRv5_mobile",
"KOREAN__PP-OCRv5_mobile",
"LATIN__PP-OCRv5_mobile",
"TH__PP-OCRv5_mobile",
}
SUPPORTED_PROVIDERS = [
"CUDAExecutionProvider",
"ROCMExecutionProvider",
"OpenVINOExecutionProvider",
"CoreMLExecutionProvider",
"CPUExecutionProvider",
]
RKNN_SUPPORTED_SOCS = ["rk3566", "rk3568", "rk3576", "rk3588"]
RKNN_COREMASK_SUPPORTED_SOCS = ["rk3576", "rk3588"]
WEBLATE_TO_FLORES200 = {
"af": "afr_Latn",
"ar": "arb_Arab",
"az": "azj_Latn",
"be": "bel_Cyrl",
"bg": "bul_Cyrl",
"ca": "cat_Latn",
"cs": "ces_Latn",
"da": "dan_Latn",
"de": "deu_Latn",
"el": "ell_Grek",
"en": "eng_Latn",
"es": "spa_Latn",
"et": "est_Latn",
"fa": "pes_Arab",
"fi": "fin_Latn",
"fr": "fra_Latn",
"he": "heb_Hebr",
"hi": "hin_Deva",
"hr": "hrv_Latn",
"hu": "hun_Latn",
"hy": "hye_Armn",
"id": "ind_Latn",
"it": "ita_Latn",
"ja": "jpn_Hira",
"kmr": "kmr_Latn",
"ko": "kor_Hang",
"lb": "ltz_Latn",
"lt": "lit_Latn",
"lv": "lav_Latn",
"mfa": "zsm_Latn",
"mk": "mkd_Cyrl",
"mn": "khk_Cyrl",
"mr": "mar_Deva",
"ms": "zsm_Latn",
"nb-NO": "nob_Latn",
"nn": "nno_Latn",
"nl": "nld_Latn",
"pl": "pol_Latn",
"pt-BR": "por_Latn",
"pt": "por_Latn",
"ro": "ron_Latn",
"ru": "rus_Cyrl",
"sk": "slk_Latn",
"sl": "slv_Latn",
"sr-Cyrl": "srp_Cyrl",
"sv": "swe_Latn",
"ta": "tam_Taml",
"te": "tel_Telu",
"th": "tha_Thai",
"tr": "tur_Latn",
"uk": "ukr_Cyrl",
"ur": "urd_Arab",
"vi": "vie_Latn",
"zh-CN": "zho_Hans",
"zh-Hans": "zho_Hans",
"zh-TW": "zho_Hant",
}
def get_model_source(model_name: str) -> ModelSource | None:
cleaned_name = clean_name(model_name)
if cleaned_name in _INSIGHTFACE_MODELS:
return ModelSource.INSIGHTFACE
if cleaned_name in _MCLIP_MODELS:
return ModelSource.MCLIP
if cleaned_name in _OPENCLIP_MODELS:
return ModelSource.OPENCLIP
if cleaned_name in _PADDLE_MODELS:
return ModelSource.PADDLE
return None
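
A lookup sketch for get_model_source above; the first two names are real entries in the sets, the last is hypothetical:

assert get_model_source("immich-app/ViT-B-32__openai") is ModelSource.OPENCLIP
assert get_model_source("buffalo_l") is ModelSource.INSIGHTFACE
assert get_model_source("not-a-known-model") is None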

View file

@@ -0,0 +1,41 @@
from typing import Any
import numpy as np
from insightface.model_zoo import RetinaFace
from numpy.typing import NDArray
from immich_ml.models.base import InferenceModel
from immich_ml.models.transforms import decode_cv2
from immich_ml.schemas import FaceDetectionOutput, ModelSession, ModelTask, ModelType
class FaceDetector(InferenceModel):
depends = []
identity = (ModelType.DETECTION, ModelTask.FACIAL_RECOGNITION)
def __init__(self, model_name: str, min_score: float = 0.7, **model_kwargs: Any) -> None:
self.min_score = model_kwargs.pop("minScore", min_score)
super().__init__(model_name, **model_kwargs)
def _load(self) -> ModelSession:
session = self._make_session(self.model_path)
self.model = RetinaFace(session=session)
self.model.prepare(ctx_id=0, det_thresh=self.min_score, input_size=(640, 640))
return session
def _predict(self, inputs: NDArray[np.uint8] | bytes) -> FaceDetectionOutput:
inputs = decode_cv2(inputs)
bboxes, landmarks = self._detect(inputs)
return {
"boxes": bboxes[:, :4].round(),
"scores": bboxes[:, 4],
"landmarks": landmarks,
}
def _detect(self, inputs: NDArray[np.uint8] | bytes) -> tuple[NDArray[np.float32], NDArray[np.float32]]:
return self.model.detect(inputs) # type: ignore
def configure(self, **kwargs: Any) -> None:
self.model.det_thresh = kwargs.pop("minScore", self.model.det_thresh)

View file

@@ -0,0 +1,92 @@
from pathlib import Path
from typing import Any
import numpy as np
import onnx
import onnxruntime as ort
from insightface.model_zoo import ArcFaceONNX
from insightface.utils.face_align import norm_crop
from numpy.typing import NDArray
from onnx.tools.update_model_dims import update_inputs_outputs_dims
from PIL import Image
from immich_ml.config import log, settings
from immich_ml.models.base import InferenceModel
from immich_ml.models.transforms import decode_cv2, serialize_np_array
from immich_ml.schemas import (
FaceDetectionOutput,
FacialRecognitionOutput,
ModelFormat,
ModelSession,
ModelTask,
ModelType,
)
class FaceRecognizer(InferenceModel):
depends = [(ModelType.DETECTION, ModelTask.FACIAL_RECOGNITION)]
identity = (ModelType.RECOGNITION, ModelTask.FACIAL_RECOGNITION)
def __init__(self, model_name: str, **model_kwargs: Any) -> None:
super().__init__(model_name, **model_kwargs)
max_batch_size = settings.max_batch_size.facial_recognition if settings.max_batch_size else None
self.batch_size = max_batch_size if max_batch_size else self._batch_size_default
def _load(self) -> ModelSession:
session = self._make_session(self.model_path)
if (not self.batch_size or self.batch_size > 1) and str(session.get_inputs()[0].shape[0]) != "batch":
self._add_batch_axis(self.model_path)
session = self._make_session(self.model_path)
self.model = ArcFaceONNX(
self.model_path_for_format(ModelFormat.ONNX).as_posix(),
session=session,
)
return session
def _predict(
self, inputs: NDArray[np.uint8] | bytes | Image.Image, faces: FaceDetectionOutput
) -> FacialRecognitionOutput:
if faces["boxes"].shape[0] == 0:
return []
inputs = decode_cv2(inputs)
cropped_faces = self._crop(inputs, faces)
embeddings = self._predict_batch(cropped_faces)
return self.postprocess(faces, embeddings)
def _predict_batch(self, cropped_faces: list[NDArray[np.uint8]]) -> NDArray[np.float32]:
if not self.batch_size or len(cropped_faces) <= self.batch_size:
embeddings: NDArray[np.float32] = self.model.get_feat(cropped_faces)
return embeddings
batch_embeddings: list[NDArray[np.float32]] = []
for i in range(0, len(cropped_faces), self.batch_size):
batch_embeddings.append(self.model.get_feat(cropped_faces[i : i + self.batch_size]))
return np.concatenate(batch_embeddings, axis=0)
def postprocess(self, faces: FaceDetectionOutput, embeddings: NDArray[np.float32]) -> FacialRecognitionOutput:
return [
{
"boundingBox": {"x1": x1, "y1": y1, "x2": x2, "y2": y2},
"embedding": serialize_np_array(embedding),
"score": score,
}
for (x1, y1, x2, y2), embedding, score in zip(faces["boxes"], embeddings, faces["scores"])
]
def _crop(self, image: NDArray[np.uint8], faces: FaceDetectionOutput) -> list[NDArray[np.uint8]]:
return [norm_crop(image, landmark) for landmark in faces["landmarks"]]
def _add_batch_axis(self, model_path: Path) -> None:
log.debug(f"Adding batch axis to model {model_path}")
proto = onnx.load(model_path)
static_input_dims = [shape.dim_value for shape in proto.graph.input[0].type.tensor_type.shape.dim[1:]]
static_output_dims = [shape.dim_value for shape in proto.graph.output[0].type.tensor_type.shape.dim[1:]]
input_dims = {proto.graph.input[0].name: ["batch"] + static_input_dims}
output_dims = {proto.graph.output[0].name: ["batch"] + static_output_dims}
updated_proto = update_inputs_outputs_dims(proto, input_dims, output_dims)
onnx.save(updated_proto, model_path)
@property
def _batch_size_default(self) -> int | None:
providers = ort.get_available_providers()
return None if self.model_format == ModelFormat.ONNX and "OpenVINOExecutionProvider" not in providers else 1
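
A chunking sketch for _predict_batch above (hypothetical sizes): ten crops with a batch size of four are processed as 4 + 4 + 2 and the embeddings re-concatenated.

crops = list(range(10))  # stand-ins for cropped face arrays
batch_size = 4
chunks = [crops[i : i + batch_size] for i in range(0, len(crops), batch_size)]
assert [len(c) for c in chunks] == [4, 4, 2]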

View file

@@ -0,0 +1,125 @@
from typing import Any
import cv2
import numpy as np
from numpy.typing import NDArray
from PIL import Image
from rapidocr.ch_ppocr_det.utils import DBPostProcess
from rapidocr.inference_engine.base import FileInfo, InferSession
from rapidocr.utils.download_file import DownloadFile, DownloadFileInput
from rapidocr.utils.typings import EngineType, LangDet, OCRVersion, TaskType
from rapidocr.utils.typings import ModelType as RapidModelType
from immich_ml.config import log
from immich_ml.models.base import InferenceModel
from immich_ml.schemas import ModelFormat, ModelSession, ModelTask, ModelType
from immich_ml.sessions.ort import OrtSession
from .schemas import TextDetectionOutput
class TextDetector(InferenceModel):
depends = []
identity = (ModelType.DETECTION, ModelTask.OCR)
def __init__(self, model_name: str, **model_kwargs: Any) -> None:
super().__init__(model_name.split("__")[-1], **model_kwargs, model_format=ModelFormat.ONNX)
self.max_resolution = 736
self.mean = np.array([0.5, 0.5, 0.5], dtype=np.float32)
self.std_inv = np.float32(1.0) / (np.array([0.5, 0.5, 0.5], dtype=np.float32) * 255.0)
self._empty: TextDetectionOutput = {
"boxes": np.empty(0, dtype=np.float32),
"scores": np.empty(0, dtype=np.float32),
}
self.postprocess = DBPostProcess(
thresh=0.3,
box_thresh=model_kwargs.get("minScore", 0.5),
max_candidates=1000,
unclip_ratio=1.6,
use_dilation=True,
score_mode="fast",
)
def _download(self) -> None:
model_info = InferSession.get_model_url(
FileInfo(
engine_type=EngineType.ONNXRUNTIME,
ocr_version=OCRVersion.PPOCRV5,
task_type=TaskType.DET,
lang_type=LangDet.CH,
model_type=RapidModelType.MOBILE if "mobile" in self.model_name else RapidModelType.SERVER,
)
)
download_params = DownloadFileInput(
file_url=model_info["model_dir"],
sha256=model_info["SHA256"],
save_path=self.model_path,
logger=log,
)
DownloadFile.run(download_params)
def _load(self) -> ModelSession:
# TODO: support other runtime sessions
return OrtSession(self.model_path)
# partly adapted from RapidOCR
def _predict(self, inputs: Image.Image) -> TextDetectionOutput:
w, h = inputs.size
if w < 32 or h < 32:
return self._empty
out = self.session.run(None, {"x": self._transform(inputs)})[0]
boxes, scores = self.postprocess(out, (h, w))
if len(boxes) == 0:
return self._empty
return {
"boxes": self.sorted_boxes(boxes),
"scores": np.array(scores, dtype=np.float32),
}
# adapted from RapidOCR
def _transform(self, img: Image.Image) -> NDArray[np.float32]:
if img.height < img.width:
ratio = float(self.max_resolution) / img.height
else:
ratio = float(self.max_resolution) / img.width
ratio = min(ratio, 1.0)
resize_h = int(img.height * ratio)
resize_w = int(img.width * ratio)
resize_h = int(round(resize_h / 32) * 32)
resize_w = int(round(resize_w / 32) * 32)
resized_img = img.resize((int(resize_w), int(resize_h)), resample=Image.Resampling.LANCZOS)
img_np: NDArray[np.float32] = cv2.cvtColor(np.array(resized_img, dtype=np.float32), cv2.COLOR_RGB2BGR) # type: ignore
img_np -= self.mean
img_np *= self.std_inv
img_np = np.transpose(img_np, (2, 0, 1))
return np.expand_dims(img_np, axis=0)
def sorted_boxes(self, dt_boxes: NDArray[np.float32]) -> NDArray[np.float32]:
if len(dt_boxes) == 0:
return dt_boxes
# Sort by y, then identify lines, then sort by (line, x)
y_order = np.argsort(dt_boxes[:, 0, 1], kind="stable")
sorted_y = dt_boxes[y_order, 0, 1]
line_ids = np.empty(len(dt_boxes), dtype=np.int32)
line_ids[0] = 0
np.cumsum(np.abs(np.diff(sorted_y)) >= 10, out=line_ids[1:])
# Create composite sort key for final ordering
# Shift line_ids by large factor, add x for tie-breaking
sort_key = line_ids[y_order] * 1e6 + dt_boxes[y_order, 0, 0]
final_order = np.argsort(sort_key, kind="stable")
sorted_boxes: NDArray[np.float32] = dt_boxes[y_order[final_order]]
return sorted_boxes
def configure(self, **kwargs: Any) -> None:
if (max_resolution := kwargs.get("maxResolution")) is not None:
self.max_resolution = max_resolution
if (min_score := kwargs.get("minScore")) is not None:
self.postprocess.box_thresh = min_score
if (score_mode := kwargs.get("scoreMode")) is not None:
self.postprocess.score_mode = score_mode
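
A sizing sketch for _transform above (hypothetical 1280x720 input, the default max_resolution of 736): the scale is capped at 1.0 and both sides are rounded to multiples of 32.

w, h, cap = 1280, 720, 736
ratio = min(cap / h if h < w else cap / w, 1.0)
rw = int(round(int(w * ratio) / 32) * 32)
rh = int(round(int(h * ratio) / 32) * 32)
assert (rw, rh) == (1280, 704)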

View file

@@ -0,0 +1,153 @@
from typing import Any
import numpy as np
from numpy.typing import NDArray
from PIL import Image
from rapidocr.ch_ppocr_rec import TextRecInput
from rapidocr.ch_ppocr_rec import TextRecognizer as RapidTextRecognizer
from rapidocr.inference_engine.base import FileInfo, InferSession
from rapidocr.utils.download_file import DownloadFile, DownloadFileInput
from rapidocr.utils.typings import EngineType, LangRec, OCRVersion, TaskType
from rapidocr.utils.typings import ModelType as RapidModelType
from rapidocr.utils.vis_res import VisRes
from immich_ml.config import log, settings
from immich_ml.models.base import InferenceModel
from immich_ml.models.transforms import pil_to_cv2
from immich_ml.schemas import ModelFormat, ModelSession, ModelTask, ModelType
from immich_ml.sessions.ort import OrtSession
from .schemas import OcrOptions, TextDetectionOutput, TextRecognitionOutput
class TextRecognizer(InferenceModel):
depends = [(ModelType.DETECTION, ModelTask.OCR)]
identity = (ModelType.RECOGNITION, ModelTask.OCR)
def __init__(self, model_name: str, **model_kwargs: Any) -> None:
self.language = LangRec[model_name.split("__")[0]] if "__" in model_name else LangRec.CH
self.min_score = model_kwargs.get("minScore", 0.9)
self._empty: TextRecognitionOutput = {
"box": np.empty(0, dtype=np.float32),
"boxScore": np.empty(0, dtype=np.float32),
"text": [],
"textScore": np.empty(0, dtype=np.float32),
}
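# Stub out VisRes before constructing RapidTextRecognizer: its initializer pulls in visualization assets (fonts) that are never used here.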
VisRes.__init__ = lambda self, **kwargs: None # pyright: ignore[reportAttributeAccessIssue]
super().__init__(model_name, **model_kwargs, model_format=ModelFormat.ONNX)
def _download(self) -> None:
model_info = InferSession.get_model_url(
FileInfo(
engine_type=EngineType.ONNXRUNTIME,
ocr_version=OCRVersion.PPOCRV5,
task_type=TaskType.REC,
lang_type=self.language,
model_type=RapidModelType.MOBILE if "mobile" in self.model_name else RapidModelType.SERVER,
)
)
download_params = DownloadFileInput(
file_url=model_info["model_dir"],
sha256=model_info["SHA256"],
save_path=self.model_path,
logger=log,
)
DownloadFile.run(download_params)
def _load(self) -> ModelSession:
# TODO: support other runtimes
session = OrtSession(self.model_path)
self.model = RapidTextRecognizer(
OcrOptions(
session=session.session,
rec_batch_num=settings.max_batch_size.text_recognition if settings.max_batch_size is not None else 6,
rec_img_shape=(3, 48, 320),
lang_type=self.language,
)
)
return session
def _predict(self, img: Image.Image, texts: TextDetectionOutput) -> TextRecognitionOutput:
boxes, box_scores = texts["boxes"], texts["scores"]
if boxes.shape[0] == 0:
return self._empty
rec = self.model(TextRecInput(img=self.get_crop_img_list(img, boxes)))
if rec.txts is None:
return self._empty
boxes[:, :, 0] /= img.width
boxes[:, :, 1] /= img.height
text_scores = np.array(rec.scores)
valid_text_score_idx = text_scores > self.min_score
valid_score_idx_list = valid_text_score_idx.tolist()
return {
"box": boxes.reshape(-1, 8)[valid_text_score_idx].reshape(-1),
"text": [rec.txts[i] for i in range(len(rec.txts)) if valid_score_idx_list[i]],
"boxScore": box_scores[valid_text_score_idx],
"textScore": text_scores[valid_text_score_idx],
}
def get_crop_img_list(self, img: Image.Image, boxes: NDArray[np.float32]) -> list[NDArray[np.uint8]]:
img_crop_width = np.maximum(
np.linalg.norm(boxes[:, 1] - boxes[:, 0], axis=1), np.linalg.norm(boxes[:, 2] - boxes[:, 3], axis=1)
).astype(np.int32)
img_crop_height = np.maximum(
np.linalg.norm(boxes[:, 0] - boxes[:, 3], axis=1), np.linalg.norm(boxes[:, 1] - boxes[:, 2], axis=1)
).astype(np.int32)
pts_std = np.zeros((img_crop_width.shape[0], 4, 2), dtype=np.float32)
pts_std[:, 1:3, 0] = img_crop_width[:, None]
pts_std[:, 2:4, 1] = img_crop_height[:, None]
img_crop_sizes = np.stack([img_crop_width, img_crop_height], axis=1)
all_coeffs = self._get_perspective_transform(pts_std, boxes)
imgs: list[NDArray[np.uint8]] = []
for coeffs, dst_size in zip(all_coeffs, img_crop_sizes):
dst_img = img.transform(
size=tuple(dst_size),
method=Image.Transform.PERSPECTIVE,
data=tuple(coeffs),
resample=Image.Resampling.BICUBIC,
)
dst_width, dst_height = dst_img.size
if dst_height * 1.0 / dst_width >= 1.5:
dst_img = dst_img.rotate(90, expand=True)
imgs.append(pil_to_cv2(dst_img))
return imgs
def _get_perspective_transform(self, src: NDArray[np.float32], dst: NDArray[np.float32]) -> NDArray[np.float32]:
N = src.shape[0]
x, y = src[:, :, 0], src[:, :, 1]
u, v = dst[:, :, 0], dst[:, :, 1]
A = np.zeros((N, 8, 9), dtype=np.float32)
# Fill even rows (0, 2, 4, 6): [x, y, 1, 0, 0, 0, -u*x, -u*y, -u]
A[:, ::2, 0] = x
A[:, ::2, 1] = y
A[:, ::2, 2] = 1
A[:, ::2, 6] = -u * x
A[:, ::2, 7] = -u * y
A[:, ::2, 8] = -u
# Fill odd rows (1, 3, 5, 7): [0, 0, 0, x, y, 1, -v*x, -v*y, -v]
A[:, 1::2, 3] = x
A[:, 1::2, 4] = y
A[:, 1::2, 5] = 1
A[:, 1::2, 6] = -v * x
A[:, 1::2, 7] = -v * y
A[:, 1::2, 8] = -v
# Solve using SVD for all matrices at once
_, _, Vt = np.linalg.svd(A)
H = Vt[:, -1, :].reshape(N, 3, 3)
H = H / H[:, 2:3, 2:3]
# Extract the 8 coefficients for each transformation
return np.column_stack(
[H[:, 0, 0], H[:, 0, 1], H[:, 0, 2], H[:, 1, 0], H[:, 1, 1], H[:, 1, 2], H[:, 2, 0], H[:, 2, 1]]
) # pyright: ignore[reportReturnType]
def configure(self, **kwargs: Any) -> None:
self.min_score = kwargs.get("minScore", self.min_score)

View file

@@ -0,0 +1,27 @@
from typing import Any, Iterable
import numpy as np
import numpy.typing as npt
from rapidocr.utils.typings import EngineType, LangRec
from typing_extensions import TypedDict
class TextDetectionOutput(TypedDict):
boxes: npt.NDArray[np.float32]
scores: npt.NDArray[np.float32]
class TextRecognitionOutput(TypedDict):
box: npt.NDArray[np.float32]
boxScore: npt.NDArray[np.float32]
text: Iterable[str]
textScore: npt.NDArray[np.float32]
# RapidOCR expects `engine_type`, `lang_type`, and `font_path` to be attributes
class OcrOptions(dict[str, Any]):
def __init__(self, lang_type: LangRec | None = None, **options: Any) -> None:
super().__init__(**options)
self.engine_type = EngineType.ONNXRUNTIME
self.lang_type = lang_type
self.font_path = None

View file

@@ -0,0 +1,80 @@
import string
from io import BytesIO
from typing import IO
import cv2
import numpy as np
import orjson
from numpy.typing import NDArray
from PIL import Image
_PIL_RESAMPLING_METHODS = {resampling.name.lower(): resampling for resampling in Image.Resampling}
_PUNCTUATION_TRANS = str.maketrans("", "", string.punctuation)
def resize_pil(img: Image.Image, size: int) -> Image.Image:
if img.width < img.height:
return img.resize((size, int((img.height / img.width) * size)), resample=Image.Resampling.BICUBIC)
else:
return img.resize((int((img.width / img.height) * size), size), resample=Image.Resampling.BICUBIC)
# https://stackoverflow.com/a/60883103
def crop_pil(img: Image.Image, size: int) -> Image.Image:
left = int((img.size[0] / 2) - (size / 2))
upper = int((img.size[1] / 2) - (size / 2))
right = left + size
lower = upper + size
return img.crop((left, upper, right, lower))
def to_numpy(img: Image.Image) -> NDArray[np.float32]:
return np.asarray(img if img.mode == "RGB" else img.convert("RGB"), dtype=np.float32) / 255.0
def normalize(
img: NDArray[np.float32], mean: float | NDArray[np.float32], std: float | NDArray[np.float32]
) -> NDArray[np.float32]:
return (img - mean) / std
def get_pil_resampling(resample: str) -> Image.Resampling:
return _PIL_RESAMPLING_METHODS[resample.lower()]
def pil_to_cv2(image: Image.Image) -> NDArray[np.uint8]:
return cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR) # type: ignore
def decode_pil(image_bytes: bytes | IO[bytes] | Image.Image) -> Image.Image:
if isinstance(image_bytes, Image.Image):
return image_bytes
image: Image.Image = Image.open(BytesIO(image_bytes) if isinstance(image_bytes, bytes) else image_bytes)
image.load()
if not image.mode == "RGB":
image = image.convert("RGB")
return image
def decode_cv2(image_bytes: NDArray[np.uint8] | bytes | Image.Image) -> NDArray[np.uint8]:
match image_bytes:
case bytes() | memoryview() | bytearray():
return pil_to_cv2(decode_pil(image_bytes)) # pillow is much faster than cv2
case Image.Image():
return pil_to_cv2(image_bytes)
case _:
return image_bytes
def clean_text(text: str, canonicalize: bool = False) -> str:
text = " ".join(text.split())
if canonicalize:
text = text.translate(_PUNCTUATION_TRANS).lower()
return text
# serializing here lets the client pass the embedding through as a JSON string instead of deserializing it only to serialize it back
# TODO: use this in a less invasive way
def serialize_np_array(arr: NDArray[np.float32]) -> str:
return orjson.dumps(arr, option=orjson.OPT_SERIALIZE_NUMPY).decode()
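
A round-trip sketch for serialize_np_array above, using values that are exactly representable in float32:

import numpy as np

assert serialize_np_array(np.array([0.5, -1.0], dtype=np.float32)) == "[0.5,-1.0]"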

View file

@@ -0,0 +1,122 @@
from enum import Enum
from typing import Any, Literal, Protocol, TypeGuard, TypeVar
import numpy as np
import numpy.typing as npt
from typing_extensions import TypedDict
class StrEnum(str, Enum):
value: str
def __str__(self) -> str:
return self.value
class BoundingBox(TypedDict):
x1: int
y1: int
x2: int
y2: int
class ModelTask(StrEnum):
FACIAL_RECOGNITION = "facial-recognition"
SEARCH = "clip"
OCR = "ocr"
class ModelType(StrEnum):
DETECTION = "detection"
RECOGNITION = "recognition"
TEXTUAL = "textual"
VISUAL = "visual"
class ModelFormat(StrEnum):
ARMNN = "armnn"
ONNX = "onnx"
RKNN = "rknn"
class ModelSource(StrEnum):
INSIGHTFACE = "insightface"
MCLIP = "mclip"
OPENCLIP = "openclip"
PADDLE = "paddle"
class ModelPrecision(StrEnum):
FP16 = "FP16"
FP32 = "FP32"
ModelIdentity = tuple[ModelType, ModelTask]
class SessionNode(Protocol):
@property
def name(self) -> str | None: ...
@property
def shape(self) -> tuple[int, ...]: ...
class ModelSession(Protocol):
def run(
self,
output_names: list[str] | None,
input_feed: dict[str, npt.NDArray[np.float32]] | dict[str, npt.NDArray[np.int32]],
run_options: Any = None,
) -> list[npt.NDArray[np.float32]]: ...
def get_inputs(self) -> list[SessionNode]: ...
def get_outputs(self) -> list[SessionNode]: ...
class HasProfiling(Protocol):
profiling: dict[str, float]
class FaceDetectionOutput(TypedDict):
boxes: npt.NDArray[np.float32]
scores: npt.NDArray[np.float32]
landmarks: npt.NDArray[np.float32]
class DetectedFace(TypedDict):
boundingBox: BoundingBox
embedding: str
score: float
FacialRecognitionOutput = list[DetectedFace]
class PipelineEntry(TypedDict):
modelName: str
options: dict[str, Any]
PipelineRequest = dict[ModelTask, dict[ModelType, PipelineEntry]]
class InferenceEntry(TypedDict):
name: str
task: ModelTask
type: ModelType
options: dict[str, Any]
InferenceEntries = tuple[list[InferenceEntry], list[InferenceEntry]]
InferenceResponse = dict[ModelTask | Literal["imageHeight"] | Literal["imageWidth"], Any]
def has_profiling(obj: Any) -> TypeGuard[HasProfiling]:
return hasattr(obj, "profiling") and isinstance(obj.profiling, dict)
T = TypeVar("T")

View file

@@ -0,0 +1,58 @@
from __future__ import annotations
from pathlib import Path
from typing import Any, NamedTuple
import numpy as np
from numpy.typing import NDArray
from immich_ml.config import log, settings
from immich_ml.schemas import SessionNode
from .loader import Ann
class AnnSession:
"""
Wrapper around ANN that acts as a drop-in replacement for an ONNX session.
"""
def __init__(self, model_path: Path, cache_dir: Path = settings.cache_folder) -> None:
self.model_path = model_path
self.cache_dir = cache_dir
self.ann = Ann(tuning_level=settings.ann_tuning_level, tuning_file=(cache_dir / "gpu-tuning.ann").as_posix())
log.info("Loading ANN model %s ...", model_path)
self.model = self.ann.load(
model_path.as_posix(),
cached_network_path=model_path.with_suffix(".anncache").as_posix(),
fp16=settings.ann_fp16_turbo,
)
log.info("Loaded ANN model with ID %d", self.model)
def __del__(self) -> None:
self.ann.unload(self.model)
log.info("Unloaded ANN model %d", self.model)
self.ann.destroy()
def get_inputs(self) -> list[SessionNode]:
shapes = self.ann.input_shapes[self.model]
return [AnnNode(None, s) for s in shapes]
def get_outputs(self) -> list[SessionNode]:
shapes = self.ann.output_shapes[self.model]
return [AnnNode(None, s) for s in shapes]
def run(
self,
output_names: list[str] | None,
input_feed: dict[str, NDArray[np.float32]] | dict[str, NDArray[np.int32]],
run_options: Any = None,
) -> list[NDArray[np.float32]]:
inputs: list[NDArray[np.float32]] = [np.ascontiguousarray(v) for v in input_feed.values()]
return self.ann.execute(self.model, inputs)
class AnnNode(NamedTuple):
name: str | None
shape: tuple[int, ...]

View file

@@ -0,0 +1,169 @@
from __future__ import annotations
from ctypes import CDLL, Array, c_bool, c_char_p, c_int, c_ulong, c_void_p
from os.path import exists
from typing import Any, Protocol, TypeVar
import numpy as np
from numpy.typing import NDArray
from immich_ml.config import log
try:
CDLL("libmali.so") # fail if libmali.so is not mounted into container
libann = CDLL("libann.so")
libann.init.argtypes = c_int, c_int, c_char_p
libann.init.restype = c_void_p
libann.load.argtypes = c_void_p, c_char_p, c_bool, c_bool, c_bool, c_char_p
libann.load.restype = c_int
libann.execute.argtypes = c_void_p, c_int, Array[c_void_p], Array[c_void_p]
libann.unload.argtypes = c_void_p, c_int
libann.destroy.argtypes = (c_void_p,)
libann.shape.argtypes = c_void_p, c_int, c_bool, c_int
libann.shape.restype = c_ulong
libann.tensors.argtypes = c_void_p, c_int, c_bool
libann.tensors.restype = c_int
is_available = True
except OSError as e:
log.debug("Could not load ANN shared libraries, using ONNX: %s", e)
is_available = False
T = TypeVar("T", covariant=True)
class Newable(Protocol[T]):
def new(self) -> None: ...
class _Singleton(type, Newable[T]):
_instances: dict[_Singleton[T], Newable[T]] = {}
def __call__(cls, *args: Any, **kwargs: Any) -> Newable[T]:
if cls not in cls._instances:
obj: Newable[T] = super(_Singleton, cls).__call__(*args, **kwargs)
cls._instances[cls] = obj
else:
obj = cls._instances[cls]
obj.new()
return obj
class Ann(metaclass=_Singleton):
def __init__(self, log_level: int = 3, tuning_level: int = 1, tuning_file: str | None = None) -> None:
if not is_available:
raise RuntimeError("libann is not available!")
if tuning_level == 0 and tuning_file is None:
raise ValueError("tuning_level == 0 reads existing tuning information and requires a tuning_file")
if tuning_level < 0 or tuning_level > 3:
raise ValueError("tuning_level must be 0 (load from tuning_file), 1, 2 or 3.")
if log_level < 0 or log_level > 5:
raise ValueError("log_level must be 0 (trace), 1 (debug), 2 (info), 3 (warning), 4 (error) or 5 (fatal)")
self.log_level = log_level
self.tuning_level = tuning_level
self.tuning_file = tuning_file
        self.output_shapes: dict[int, tuple[tuple[int, ...], ...]] = {}
        self.input_shapes: dict[int, tuple[tuple[int, ...], ...]] = {}
self.ann: int | None = None
self.new()
if self.tuning_file is not None:
            # Ensure the tuning file exists without clearing its contents;
            # once populated, it cuts the cost of the first inference after
            # a model load by tens of seconds.
open(self.tuning_file, "a").close()
def new(self) -> None:
if self.ann is None:
self.ann = libann.init(
self.log_level,
self.tuning_level,
self.tuning_file.encode() if self.tuning_file is not None else None,
)
self.ref_count = 0
self.ref_count += 1
def destroy(self) -> None:
self.ref_count -= 1
if self.ref_count <= 0 and self.ann is not None:
libann.destroy(self.ann)
self.ann = None
def __del__(self) -> None:
if self.ann is not None:
libann.destroy(self.ann)
self.ann = None
def load(
self,
model_path: str,
fast_math: bool = True,
fp16: bool = False,
cached_network_path: str | None = None,
) -> int:
if not model_path.endswith((".armnn", ".tflite", ".onnx")):
raise ValueError("model_path must be a file with extension .armnn, .tflite or .onnx")
if not exists(model_path):
raise ValueError("model_path must point to an existing file!")
save_cached_network = False
if cached_network_path is not None and not exists(cached_network_path):
save_cached_network = True
# create empty model cache file
open(cached_network_path, "a").close()
net_id: int = libann.load(
self.ann,
model_path.encode(),
fast_math,
fp16,
save_cached_network,
cached_network_path.encode() if cached_network_path is not None else None,
)
if net_id < 0:
raise ValueError("Cannot load model!")
self.input_shapes[net_id] = tuple(
self.shape(net_id, input=True, index=i) for i in range(self.tensors(net_id, input=True))
)
self.output_shapes[net_id] = tuple(
self.shape(net_id, input=False, index=i) for i in range(self.tensors(net_id, input=False))
)
return net_id
    def unload(self, network_id: int) -> None:
        libann.unload(self.ann, network_id)
        del self.output_shapes[network_id]
        del self.input_shapes[network_id]  # input shapes are cached in load() as well
def execute(self, network_id: int, input_tensors: list[NDArray[np.float32]]) -> list[NDArray[np.float32]]:
if not isinstance(input_tensors, list):
raise ValueError("input_tensors needs to be a list!")
net_input_shapes = self.input_shapes[network_id]
if len(input_tensors) != len(net_input_shapes):
raise ValueError(f"input_tensors lengths {len(input_tensors)} != network inputs {len(net_input_shapes)}")
for net_input_shape, input_tensor in zip(net_input_shapes, input_tensors):
if net_input_shape != input_tensor.shape:
raise ValueError(f"input_tensor shape {input_tensor.shape} != network input shape {net_input_shape}")
if not input_tensor.flags.c_contiguous:
raise ValueError("input_tensors must be c_contiguous numpy ndarrays")
output_tensors: list[NDArray[np.float32]] = [
np.ndarray(s, dtype=np.float32) for s in self.output_shapes[network_id]
]
input_type = c_void_p * len(input_tensors)
inputs = input_type(*[t.ctypes.data_as(c_void_p) for t in input_tensors])
output_type = c_void_p * len(output_tensors)
outputs = output_type(*[t.ctypes.data_as(c_void_p) for t in output_tensors])
libann.execute(self.ann, network_id, inputs, outputs)
return output_tensors
    def shape(self, network_id: int, input: bool = False, index: int = 0) -> tuple[int, ...]:
        # libann packs a tensor shape into a single integer,
        # 16 bits per dimension, least-significant dimension first
        s = libann.shape(self.ann, network_id, input, index)
        a = []
        while s != 0:
            a.append(s & 0xFFFF)
            s >>= 16
        return tuple(a)
def tensors(self, network_id: int, input: bool = False) -> int:
tensors: int = libann.tensors(self.ann, network_id, input)
return tensors
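# Worked example (added for illustration) of the 16-bit packing that shape()
# decodes above: libann returns up to four dimensions in one integer, least
# significant dimension first.
if __name__ == "__main__":
    packed = 1 | (3 << 16) | (224 << 32) | (224 << 48)  # encodes (1, 3, 224, 224)
    dims = []
    while packed != 0:
        dims.append(packed & 0xFFFF)
        packed >>= 16
    assert tuple(dims) == (1, 3, 224, 224)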

View file

@ -0,0 +1,147 @@
from __future__ import annotations
from pathlib import Path
from typing import Any
import numpy as np
import onnxruntime as ort
from numpy.typing import NDArray
from immich_ml.models.constants import SUPPORTED_PROVIDERS
from immich_ml.schemas import SessionNode
from ..config import log, settings
class OrtSession:
session: ort.InferenceSession
def __init__(
self,
model_path: Path | str,
providers: list[str] | None = None,
provider_options: list[dict[str, Any]] | None = None,
sess_options: ort.SessionOptions | None = None,
):
self.model_path = Path(model_path)
self.providers = providers if providers is not None else self._providers_default
self.provider_options = provider_options if provider_options is not None else self._provider_options_default
self.sess_options = sess_options if sess_options is not None else self._sess_options_default
self.session = ort.InferenceSession(
self.model_path.as_posix(),
providers=self.providers,
provider_options=self.provider_options,
sess_options=self.sess_options,
)
def get_inputs(self) -> list[SessionNode]:
inputs: list[SessionNode] = self.session.get_inputs()
return inputs
def get_outputs(self) -> list[SessionNode]:
outputs: list[SessionNode] = self.session.get_outputs()
return outputs
def run(
self,
output_names: list[str] | None,
input_feed: dict[str, NDArray[np.float32]] | dict[str, NDArray[np.int32]],
run_options: Any = None,
) -> list[NDArray[np.float32]]:
outputs: list[NDArray[np.float32]] = self.session.run(output_names, input_feed, run_options)
return outputs
@property
def providers(self) -> list[str]:
return self._providers
@providers.setter
def providers(self, providers: list[str]) -> None:
log.info(f"Setting execution providers to {providers}, in descending order of preference")
self._providers = providers
@property
def _providers_default(self) -> list[str]:
available_providers = set(ort.get_available_providers())
log.debug(f"Available ORT providers: {available_providers}")
if (openvino := "OpenVINOExecutionProvider") in available_providers:
device_ids: list[str] = ort.capi._pybind_state.get_available_openvino_device_ids()
log.debug(f"Available OpenVINO devices: {device_ids}")
gpu_devices = [device_id for device_id in device_ids if device_id.startswith("GPU")]
if not gpu_devices:
log.warning("No GPU device found in OpenVINO. Falling back to CPU.")
available_providers.remove(openvino)
return [provider for provider in SUPPORTED_PROVIDERS if provider in available_providers]
@property
def provider_options(self) -> list[dict[str, Any]]:
return self._provider_options
@provider_options.setter
def provider_options(self, provider_options: list[dict[str, Any]]) -> None:
log.debug(f"Setting execution provider options to {provider_options}")
self._provider_options = provider_options
@property
def _provider_options_default(self) -> list[dict[str, Any]]:
provider_options = []
for provider in self.providers:
match provider:
case "CPUExecutionProvider":
options = {"arena_extend_strategy": "kSameAsRequested"}
case "CUDAExecutionProvider" | "ROCMExecutionProvider":
options = {"arena_extend_strategy": "kSameAsRequested", "device_id": settings.device_id}
case "OpenVINOExecutionProvider":
openvino_dir = self.model_path.parent / "openvino"
device = f"GPU.{settings.device_id}"
options = {
"device_type": device,
"precision": settings.openvino_precision.value,
"cache_dir": openvino_dir.as_posix(),
}
case "CoreMLExecutionProvider":
options = {
"ModelFormat": "MLProgram",
"MLComputeUnits": "ALL",
"SpecializationStrategy": "FastPrediction",
"AllowLowPrecisionAccumulationOnGPU": "1",
"ModelCacheDirectory": (self.model_path.parent / "coreml").as_posix(),
}
case _:
options = {}
provider_options.append(options)
return provider_options
@property
def sess_options(self) -> ort.SessionOptions:
return self._sess_options
@sess_options.setter
def sess_options(self, sess_options: ort.SessionOptions) -> None:
log.debug(f"Setting execution_mode to {sess_options.execution_mode.name}")
log.debug(f"Setting inter_op_num_threads to {sess_options.inter_op_num_threads}")
log.debug(f"Setting intra_op_num_threads to {sess_options.intra_op_num_threads}")
self._sess_options = sess_options
@property
def _sess_options_default(self) -> ort.SessionOptions:
sess_options = ort.SessionOptions()
sess_options.enable_cpu_mem_arena = settings.model_arena
# avoid thread contention between models
if settings.model_inter_op_threads > 0:
sess_options.inter_op_num_threads = settings.model_inter_op_threads
# these defaults work well for CPU, but bottleneck GPU
elif settings.model_inter_op_threads == 0 and self.providers == ["CPUExecutionProvider"]:
sess_options.inter_op_num_threads = 1
if settings.model_intra_op_threads > 0:
sess_options.intra_op_num_threads = settings.model_intra_op_threads
elif settings.model_intra_op_threads == 0 and self.providers == ["CPUExecutionProvider"]:
sess_options.intra_op_num_threads = 2
if sess_options.inter_op_num_threads > 1:
sess_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
return sess_options
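# Hedged usage sketch (added for illustration; the model path is hypothetical).
# Passing providers explicitly bypasses the _providers_default detection, while
# provider_options and sess_options still fall back to their defaults.
if __name__ == "__main__":
    session = OrtSession(
        "/cache/clip/visual/model.onnx",
        providers=["CPUExecutionProvider"],
    )
    print(session.provider_options)  # [{'arena_extend_strategy': 'kSameAsRequested'}]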

View file

@ -0,0 +1,76 @@
from __future__ import annotations
from pathlib import Path
from typing import Any, NamedTuple
import numpy as np
from numpy.typing import NDArray
from immich_ml.config import log, settings
from immich_ml.schemas import SessionNode
from .rknnpool import RknnPoolExecutor, is_available, soc_name
is_available = is_available and settings.rknn
model_prefix = Path("rknpu") / soc_name if is_available and soc_name is not None else None
def run_inference(rknn_lite: Any, input: list[NDArray[np.float32]]) -> list[NDArray[np.float32]]:
outputs: list[NDArray[np.float32]] = rknn_lite.inference(inputs=input, data_format="nchw")
return outputs
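# Note (added for clarity; the anchor layout is an assumption inferred from the
# shapes): the detection entries below match three stride levels on a 640x640
# input with two anchors per cell: 640/8=80 (80*80*2=12800), 640/16=40 (3200)
# and 640/32=20 (800), each with 1 score, 4 box and 10 landmark values.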
input_output_mapping: dict[str, dict[str, Any]] = {
"detection": {
"input": {"norm_tensor:0": (1, 3, 640, 640)},
"output": {
"norm_tensor:1": (12800, 1),
"norm_tensor:2": (3200, 1),
"norm_tensor:3": (800, 1),
"norm_tensor:4": (12800, 4),
"norm_tensor:5": (3200, 4),
"norm_tensor:6": (800, 4),
"norm_tensor:7": (12800, 10),
"norm_tensor:8": (3200, 10),
"norm_tensor:9": (800, 10),
},
},
"recognition": {"input": {"norm_tensor:0": (1, 3, 112, 112)}, "output": {"norm_tensor:1": (1, 512)}},
}
class RknnSession:
def __init__(self, model_path: Path) -> None:
self.model_type = "detection" if "detection" in model_path.parts else "recognition"
self.tpe = settings.rknn_threads
log.info(f"Loading RKNN model from {model_path} with {self.tpe} threads.")
self.rknnpool = RknnPoolExecutor(model_path=model_path.as_posix(), tpes=self.tpe, func=run_inference)
log.info(f"Loaded RKNN model from {model_path} with {self.tpe} threads.")
def get_inputs(self) -> list[SessionNode]:
return [RknnNode(name=k, shape=v) for k, v in input_output_mapping[self.model_type]["input"].items()]
def get_outputs(self) -> list[SessionNode]:
return [RknnNode(name=k, shape=v) for k, v in input_output_mapping[self.model_type]["output"].items()]
def run(
self,
output_names: list[str] | None,
input_feed: dict[str, NDArray[np.float32]] | dict[str, NDArray[np.int32]],
run_options: Any = None,
) -> list[NDArray[np.float32]]:
input_data: list[NDArray[np.float32]] = [np.ascontiguousarray(v) for v in input_feed.values()]
self.rknnpool.put(input_data)
res = self.rknnpool.get()
if res is None:
raise RuntimeError("RKNN inference failed!")
return res
class RknnNode(NamedTuple):
name: str | None
shape: tuple[int, ...]
__all__ = ["RknnSession", "RknnNode", "is_available", "soc_name", "model_prefix"]
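# Minimal usage sketch (added for illustration; the path is hypothetical). The
# model type is inferred from the path, so a file under a "detection" directory
# selects the 640x640 detection shapes above. Like AnnSession, run() ignores
# input names and feeds tensors in dict order.
if __name__ == "__main__" and is_available:
    session = RknnSession(Path("detection/model.rknn"))
    dummy = np.zeros((1, 3, 640, 640), dtype=np.float32)
    print([o.shape for o in session.run(None, {"norm_tensor:0": dummy})])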

View file

@ -0,0 +1,91 @@
# This code is from leafqycc/rknn-multi-threaded
# Following Apache License 2.0
import logging
from concurrent.futures import Future, ThreadPoolExecutor
from pathlib import Path
from queue import Queue
from typing import Callable
import numpy as np
from numpy.typing import NDArray
from immich_ml.config import log
from immich_ml.models.constants import RKNN_COREMASK_SUPPORTED_SOCS, RKNN_SUPPORTED_SOCS
def get_soc(device_tree_path: Path | str) -> str | None:
try:
with Path(device_tree_path).open() as f:
device_compatible_str = f.read()
for soc in RKNN_SUPPORTED_SOCS:
if soc in device_compatible_str:
return soc
log.warning("Device is not supported for RKNN")
except OSError as e:
log.warning(f"Could not read {device_tree_path}. Reason: %s", e)
return None
soc_name = None
is_available = False
try:
from rknnlite.api import RKNNLite
soc_name = get_soc("/proc/device-tree/compatible")
is_available = soc_name is not None
except ImportError:
log.debug("RKNN is not available")
def init_rknn(model_path: str) -> "RKNNLite":
if not is_available:
raise RuntimeError("rknn is not available!")
rknn_lite = RKNNLite()
rknn_lite.rknn_log.logger.setLevel(logging.ERROR)
ret = rknn_lite.load_rknn(model_path)
if ret != 0:
raise RuntimeError("Failed to load RKNN model")
if soc_name in RKNN_COREMASK_SUPPORTED_SOCS:
ret = rknn_lite.init_runtime(core_mask=RKNNLite.NPU_CORE_AUTO)
else:
        ret = rknn_lite.init_runtime()  # core_mask is only supported on some SoCs; leave it unset elsewhere
if ret != 0:
raise RuntimeError("Failed to initialize RKNN runtime environment")
return rknn_lite
class RknnPoolExecutor:
def __init__(
self,
model_path: str,
tpes: int,
func: Callable[["RKNNLite", list[NDArray[np.float32]]], list[NDArray[np.float32]]],
) -> None:
self.tpes = tpes
self.queue: Queue[Future[list[NDArray[np.float32]]]] = Queue()
self.rknn_pool = [init_rknn(model_path) for _ in range(tpes)]
self.pool = ThreadPoolExecutor(max_workers=tpes)
self.func = func
self.num = 0
def put(self, inputs: list[NDArray[np.float32]]) -> None:
self.queue.put(self.pool.submit(self.func, self.rknn_pool[self.num % self.tpes], inputs))
self.num += 1
def get(self) -> list[NDArray[np.float32]] | None:
if self.queue.empty():
return None
fut = self.queue.get()
return fut.result()
def release(self) -> None:
self.pool.shutdown()
for rknn_lite in self.rknn_pool:
rknn_lite.release()
def __del__(self) -> None:
self.release()
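# Hedged usage sketch (added for illustration; the model path is hypothetical).
# put() submits work round-robin across the RKNNLite instances and get()
# resolves futures in submission order, so results come back FIFO.
if __name__ == "__main__" and is_available:
    def _infer(rknn_lite: "RKNNLite", inputs: list[NDArray[np.float32]]) -> list[NDArray[np.float32]]:
        return rknn_lite.inference(inputs=inputs)

    pool = RknnPoolExecutor("model.rknn", tpes=2, func=_infer)
    frames = [np.zeros((1, 3, 640, 640), dtype=np.float32) for _ in range(4)]
    for frame in frames:
        pool.put([frame])
    results = [pool.get() for _ in frames]
    pool.release()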