Source Code added
This commit is contained in:
parent
800376eafd
commit
9efa9bc6dd
3912 changed files with 754770 additions and 2 deletions
0
machine-learning/immich_ml/__init__.py
Normal file
0
machine-learning/immich_ml/__init__.py
Normal file
57
machine-learning/immich_ml/__main__.py
Normal file
57
machine-learning/immich_ml/__main__.py
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
from ipaddress import ip_address
|
||||
from pathlib import Path
|
||||
|
||||
from .config import log, non_prefixed_settings, settings
|
||||
|
||||
# Announce startup, tagging the message with the build's source ref when the
# environment provides a non-empty one.
source_ref = os.getenv("IMMICH_SOURCE_REF")
log.info(f"Initializing Immich ML [{source_ref}]" if source_ref else "Initializing Immich ML")

# Directory of this package; the gunicorn and logging config files are resolved relative to it.
module_dir = Path(__file__).parent
|
||||
|
||||
|
||||
def is_ipv6(host: str) -> bool:
    """Return True only when ``host`` parses as a bare IPv6 address literal.

    Anything ``ip_address`` rejects (hostnames, bracketed literals such as
    ``[::]``, the empty string) counts as "not IPv6".
    """
    try:
        parsed = ip_address(host)
    except ValueError:
        return False
    return parsed.version == 6
|
||||
|
||||
|
||||
# Build the address gunicorn binds to. A bare IPv6 literal is wrapped in brackets so
# the trailing ":<port>" stays unambiguous; values that are not bare IPv6 literals
# (including the already-bracketed default "[::]") are used as given.
bind_host = non_prefixed_settings.immich_host
if is_ipv6(bind_host):
    bind_host = f"[{bind_host}]"
bind_address = f"{bind_host}:{non_prefixed_settings.immich_port}"

# Command line for the gunicorn server that hosts the FastAPI app.
gunicorn_command = [
    "python",
    "-m",
    "gunicorn",
    "immich_ml.main:app",
    "-k",
    "immich_ml.config.CustomUvicornWorker",
    "-c",
    module_dir / "gunicorn_conf.py",
    "-b",
    bind_address,
    "-w",
    str(settings.workers),
    "-t",
    str(settings.worker_timeout),
    "--log-config-json",
    module_dir / "log_conf.json",
    "--keep-alive",
    str(settings.http_keepalive_timeout_s),
    "--graceful-timeout",
    "10",
]

# The interrupt is handled inside the ``with`` block so that ``cmd`` is always bound
# when it is referenced, and so that the child has been waited for (and its exit
# status collected) before that status is used as this process's exit code.
with subprocess.Popen(gunicorn_command) as cmd:
    try:
        cmd.wait()
    except KeyboardInterrupt:
        # Forward the interrupt, then wait for the server to finish shutting down.
        cmd.send_signal(signal.SIGINT)
        cmd.wait()

# Equivalent to exit(), without depending on the site-provided builtin.
raise SystemExit(cmd.returncode)
|
||||
165
machine-learning/immich_ml/config.py
Normal file
165
machine-learning/immich_ml/config.py
Normal file
|
|
@ -0,0 +1,165 @@
|
|||
import concurrent.futures
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from socket import socket
|
||||
|
||||
from gunicorn.arbiter import Arbiter
|
||||
from pydantic import BaseModel
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from rich.console import Console
|
||||
from rich.logging import RichHandler
|
||||
from uvicorn import Server
|
||||
from uvicorn.workers import UvicornWorker
|
||||
|
||||
from .schemas import ModelPrecision
|
||||
|
||||
|
||||
class ClipSettings(BaseModel):
    """CLIP models to preload, configured separately per encoder."""

    # Comma-separated textual encoder model name(s); None (the default) preloads nothing.
    textual: str | None = None
    # Comma-separated visual encoder model name(s); None (the default) preloads nothing.
    visual: str | None = None
|
||||
|
||||
|
||||
class FacialRecognitionSettings(BaseModel):
    """Facial recognition models to preload, configured separately per stage."""

    # Comma-separated recognition model name(s); None (the default) preloads nothing.
    recognition: str | None = None
    # Comma-separated detection model name(s); None (the default) preloads nothing.
    detection: str | None = None
|
||||
|
||||
|
||||
class OcrSettings(BaseModel):
    """OCR models to preload, configured separately per stage."""

    # Comma-separated text recognition model name(s); None (the default) preloads nothing.
    recognition: str | None = None
    # Comma-separated text detection model name(s); None (the default) preloads nothing.
    detection: str | None = None
|
||||
|
||||
|
||||
class PreloadModelData(BaseModel):
    """Models to load at startup, grouped by task.

    The statements in this class body run once, when the class is defined. They
    translate the deprecated flat environment variables into the nested
    per-model-type variables and then remove the flat ones from the environment.
    """

    # Values of the deprecated flat variables, captured before they are deleted below.
    # They stay available on the model so callers can tell the legacy variables were used.
    clip_fallback: str | None = os.getenv("MACHINE_LEARNING_PRELOAD__CLIP", None)
    facial_recognition_fallback: str | None = os.getenv("MACHINE_LEARNING_PRELOAD__FACIAL_RECOGNITION", None)
    if clip_fallback is not None:
        # The single legacy value is applied to both CLIP encoders, overwriting any value already set for them.
        os.environ["MACHINE_LEARNING_PRELOAD__CLIP__TEXTUAL"] = clip_fallback
        os.environ["MACHINE_LEARNING_PRELOAD__CLIP__VISUAL"] = clip_fallback
        # NOTE(review): presumably removed so the flat string does not clash with nested parsing of `clip` - confirm.
        del os.environ["MACHINE_LEARNING_PRELOAD__CLIP"]
    if facial_recognition_fallback is not None:
        # Same translation for facial recognition: one legacy value feeds both stages.
        os.environ["MACHINE_LEARNING_PRELOAD__FACIAL_RECOGNITION__RECOGNITION"] = facial_recognition_fallback
        os.environ["MACHINE_LEARNING_PRELOAD__FACIAL_RECOGNITION__DETECTION"] = facial_recognition_fallback
        del os.environ["MACHINE_LEARNING_PRELOAD__FACIAL_RECOGNITION"]
    # Per-task preload settings; each defaults to "nothing to preload".
    clip: ClipSettings = ClipSettings()
    facial_recognition: FacialRecognitionSettings = FacialRecognitionSettings()
    ocr: OcrSettings = OcrSettings()
|
||||
|
||||
|
||||
class MaxBatchSize(BaseModel):
    """Optional per-task maximum batch sizes; None (the default) leaves a task unconfigured."""

    # NOTE(review): the consumers of these limits are outside this view - confirm how None is interpreted.
    facial_recognition: int | None = None
    text_recognition: int | None = None
|
||||
|
||||
|
||||
class Settings(BaseSettings):
    """Service settings read from ``MACHINE_LEARNING_``-prefixed environment variables."""

    model_config = SettingsConfigDict(
        env_prefix="MACHINE_LEARNING_",
        case_sensitive=False,
        # Nested models are addressed with a double underscore, e.g. MACHINE_LEARNING_PRELOAD__CLIP__TEXTUAL.
        env_nested_delimiter="__",
        # Several fields below start with "model_", which pydantic's default protected namespace would flag.
        protected_namespaces=("settings_",),
    )

    # Root directory for cached model files; defaults to ~/.cache/immich_ml.
    cache_folder: Path = (Path.home() / ".cache" / "immich_ml").resolve()
    # Seconds of inactivity used as the cache TTL and idle threshold; values <= 0 disable unloading.
    model_ttl: int = 300
    # Seconds between idle checks; values <= 0 disable the idle check.
    model_ttl_poll_s: int = 10
    # Passed to gunicorn as -w, -t and --keep-alive respectively.
    workers: int = 1
    worker_timeout: int = 300
    http_keepalive_timeout_s: int = 2
    test_full: bool = False
    # Size of the request thread pool; falls back to 4 when the CPU count is unknown.
    # Values <= 0 mean no pool is created and blocking work runs inline.
    request_threads: int = os.cpu_count() or 4
    model_inter_op_threads: int = 0
    model_intra_op_threads: int = 0
    model_arena: bool = True
    # When False, the ARM NN format is never chosen as a model's default format.
    ann: bool = True
    ann_fp16_turbo: bool = False
    ann_tuning_level: int = 2
    rknn: bool = True
    rknn_threads: int = 1
    # Models to load at startup; None skips preloading.
    preload: PreloadModelData | None = None
    max_batch_size: MaxBatchSize | None = None
    openvino_precision: ModelPrecision = ModelPrecision.FP32

    @property
    def device_id(self) -> str:
        """Return the device ID for this process, read from the environment on every access.

        This is deliberately not a pydantic field: the variable is looked up at
        call time (defaulting to "0"), so a value set after this object was
        created - e.g. by the gunicorn ``pre_fork`` hook - is still observed.
        """
        return os.environ.get("MACHINE_LEARNING_DEVICE_ID", "0")
|
||||
|
||||
|
||||
class NonPrefixedSettings(BaseSettings):
    """Settings read from environment variables that do not carry the ``MACHINE_LEARNING_`` prefix."""

    model_config = SettingsConfigDict(case_sensitive=False)

    # Host and port the server binds to; the default host is the bracketed IPv6 wildcard.
    immich_host: str = "[::]"
    immich_port: int = 3003
    # Looked up case-insensitively in LOG_LEVELS; unknown names fall back to INFO.
    immich_log_level: str = "info"
    # Disables colored console output when True.
    no_color: bool = False
|
||||
|
||||
|
||||
# Translation table for model names: ":", "\" and "/" become "_", and "." is dropped.
_clean_name = str.maketrans({":": "_", "\\": "_", "/": "_", ".": None})


def clean_name(model_name: str) -> str:
    """Return the last "/"-separated segment of ``model_name`` with unsafe characters normalized."""
    base_name = model_name.rpartition("/")[2]
    return base_name.translate(_clean_name)
|
||||
|
||||
|
||||
# Maps accepted log level names to stdlib logging levels. Several names are aliases
# for the same level ("warn"/"warning", "log"/"info", "verbose"/"debug").
LOG_LEVELS: dict[str, int] = {
    # NOTE(review): "critical" maps to ERROR rather than logging.CRITICAL - looks intentional, confirm.
    "critical": logging.ERROR,
    "error": logging.ERROR,
    "warning": logging.WARNING,
    "warn": logging.WARNING,
    "info": logging.INFO,
    "log": logging.INFO,
    "debug": logging.DEBUG,
    "verbose": logging.DEBUG,
}

# Both settings objects are created at import time and shared by the rest of the package.
settings = Settings()
non_prefixed_settings = NonPrefixedSettings()

# Effective log level; unrecognized names fall back to INFO.
LOG_LEVEL = LOG_LEVELS.get(non_prefixed_settings.immich_log_level.lower(), logging.INFO)
|
||||
|
||||
|
||||
class CustomRichHandler(RichHandler):
    """Rich logging handler configured from the service settings.

    Tracebacks are rendered by rich, with frames from the modules listed in
    ``self.excluded`` (and ``concurrent.futures``) suppressed.
    """

    def __init__(self) -> None:
        console = Console(color_system="standard", no_color=non_prefixed_settings.no_color)
        # Matched as substrings against frame file names in emit(); also passed to rich as suppressed modules.
        self.excluded = ["uvicorn", "starlette", "fastapi"]
        super().__init__(
            show_path=False,
            omit_repeated_times=False,
            console=console,
            rich_tracebacks=True,
            tracebacks_suppress=[*self.excluded, concurrent.futures],
            # Local variables are only included in tracebacks at debug verbosity.
            tracebacks_show_locals=LOG_LEVEL == logging.DEBUG,
        )

    # hack to exclude certain modules from rich tracebacks
    def emit(self, record: logging.LogRecord) -> None:
        """Flag frames from excluded modules, then emit ``record`` as usual."""
        if record.exc_info is not None:
            tb = record.exc_info[2]
            # Walk the whole traceback chain and mark every frame whose file path contains an excluded name.
            while tb is not None:
                if any(excluded in tb.tb_frame.f_code.co_filename for excluded in self.excluded):
                    # NOTE(review): presumably rich skips frames whose locals contain this marker - confirm.
                    tb.tb_frame.f_locals["_rich_traceback_omit"] = True
                tb = tb.tb_next

        return super().emit(record)
|
||||
|
||||
|
||||
# Package-wide logger, set to the level resolved from the settings above.
log = logging.getLogger("ml.log")
log.setLevel(LOG_LEVEL)
|
||||
|
||||
|
||||
# patches this issue https://github.com/encode/uvicorn/discussions/1803
class CustomUvicornServer(Server):
    """uvicorn server that closes the supplied sockets itself before shutting down."""

    async def shutdown(self, sockets: list[socket] | None = None) -> None:
        """Close every socket in ``sockets`` (if any), then run the parent shutdown without them."""
        if sockets:
            for listening_socket in sockets:
                listening_socket.close()
        await super().shutdown()
|
||||
|
||||
|
||||
class CustomUvicornWorker(UvicornWorker):
    """gunicorn worker that serves the app through ``CustomUvicornServer``."""

    async def _serve(self) -> None:
        """Serve the app on the worker's sockets; exit with the boot-error code if the server never started."""
        self.config.app = self.wsgi
        # Use the patched server so sockets are closed on shutdown.
        server = CustomUvicornServer(config=self.config)
        self._install_sigquit_handler()
        await server.serve(sockets=self.sockets)
        # serve() returned without the server having started: report a boot failure to the arbiter.
        if not server.started:
            sys.exit(Arbiter.WORKER_BOOT_ERROR)
|
||||
12
machine-learning/immich_ml/gunicorn_conf.py
Normal file
12
machine-learning/immich_ml/gunicorn_conf.py
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
import os
|
||||
|
||||
from gunicorn.arbiter import Arbiter
|
||||
from gunicorn.workers.base import Worker
|
||||
|
||||
# Device IDs available to workers: a comma-separated list (spaces ignored), defaulting to "0".
device_ids = os.environ.get("MACHINE_LEARNING_DEVICE_IDS", "0").replace(" ", "").split(",")
env = os.environ


def pre_fork(arbiter: Arbiter, _: Worker) -> None:
    """Assign the next worker a device ID, cycling through ``device_ids`` round-robin.

    The number of workers the arbiter currently tracks is used as the index, and the
    chosen ID is published through the environment for the forked worker to read.
    """
    worker_index = len(arbiter.WORKERS)
    assigned_device = device_ids[worker_index % len(device_ids)]
    env["MACHINE_LEARNING_DEVICE_ID"] = assigned_device
|
||||
21
machine-learning/immich_ml/log_conf.json
Normal file
21
machine-learning/immich_ml/log_conf.json
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
{
|
||||
"version": 1,
|
||||
"disable_existing_loggers": false,
|
||||
"handlers": {
|
||||
"console": {
|
||||
"class": "immich_ml.config.CustomRichHandler"
|
||||
}
|
||||
},
|
||||
"loggers": {
|
||||
"gunicorn.error": {
|
||||
"handlers": [
|
||||
"console"
|
||||
]
|
||||
}
|
||||
},
|
||||
"root": {
|
||||
"handlers": [
|
||||
"console"
|
||||
]
|
||||
}
|
||||
}
|
||||
272
machine-learning/immich_ml/main.py
Normal file
272
machine-learning/immich_ml/main.py
Normal file
|
|
@ -0,0 +1,272 @@
|
|||
import asyncio
|
||||
import gc
|
||||
import os
|
||||
import signal
|
||||
import threading
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextlib import asynccontextmanager
|
||||
from functools import partial
|
||||
from typing import Any, AsyncGenerator, Callable, Iterator
|
||||
from zipfile import BadZipFile
|
||||
|
||||
import orjson
|
||||
from fastapi import Depends, FastAPI, File, Form, HTTPException
|
||||
from fastapi.responses import ORJSONResponse, PlainTextResponse
|
||||
from onnxruntime.capi.onnxruntime_pybind11_state import InvalidProtobuf, NoSuchFile
|
||||
from PIL.Image import Image
|
||||
from pydantic import ValidationError
|
||||
from starlette.formparsers import MultiPartParser
|
||||
|
||||
from immich_ml.models import get_model_deps
|
||||
from immich_ml.models.base import InferenceModel
|
||||
from immich_ml.models.transforms import decode_pil
|
||||
|
||||
from .config import PreloadModelData, log, settings
|
||||
from .models.cache import ModelCache
|
||||
from .schemas import (
|
||||
InferenceEntries,
|
||||
InferenceEntry,
|
||||
InferenceResponse,
|
||||
ModelFormat,
|
||||
ModelIdentity,
|
||||
ModelTask,
|
||||
ModelType,
|
||||
PipelineRequest,
|
||||
T,
|
||||
)
|
||||
|
||||
# Raises the multipart parser's in-memory threshold for uploads.
MultiPartParser.spool_max_size = 2**26  # spools to disk if payload is 64 MiB or larger

# Shared model cache; TTLs are refreshed on cache hits only when unloading is enabled.
model_cache = ModelCache(revalidate=settings.model_ttl > 0)
# Created in lifespan(); stays None when request threads are disabled.
thread_pool: ThreadPoolExecutor | None = None
# Held while a model is being loaded; the idle-shutdown check also consults it.
lock = threading.Lock()
# Number of /predict requests currently in flight, maintained by update_state().
active_requests = 0
# Timestamp of the most recent /predict request; None until the first request arrives.
last_called: float | None = None
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(_: FastAPI) -> AsyncGenerator[None, None]:
    """Set up process-wide resources before the app serves requests and release them afterwards.

    Startup creates the request thread pool (when enabled), schedules the idle
    shutdown task (when unloading is enabled) and preloads configured models.
    Shutdown cancels the idle task, drops cached models, stops the thread pool
    and forces a garbage collection pass.
    """
    global thread_pool
    log.info(
        (
            "Created in-memory cache with unloading "
            f"{f'after {settings.model_ttl}s of inactivity' if settings.model_ttl > 0 else 'disabled'}."
        )
    )

    # Keep a reference to the background task: the event loop only holds weak
    # references to tasks, so an unreferenced task may be garbage collected mid-run.
    idle_task: asyncio.Future[None] | None = None
    try:
        if settings.request_threads > 0:
            # asyncio is a huge bottleneck for performance, so we use a thread pool to run blocking code
            thread_pool = ThreadPoolExecutor(settings.request_threads)
            log.info(f"Initialized request thread pool with {settings.request_threads} threads.")
        if settings.model_ttl > 0 and settings.model_ttl_poll_s > 0:
            idle_task = asyncio.ensure_future(idle_shutdown_task())
        if settings.preload is not None:
            await preload_models(settings.preload)
        yield
    finally:
        if idle_task is not None:
            idle_task.cancel()
        log.handlers.clear()
        # Drop the cache's references so the models can actually be collected below;
        # deleting a loop variable would only unbind the local name.
        model_cache.cache._cache.clear()
        if thread_pool is not None:
            thread_pool.shutdown()
        gc.collect()
|
||||
|
||||
|
||||
async def preload_models(preload: PreloadModelData) -> None:
    """Load every model named in ``preload`` into the cache before serving requests.

    Each configured value is a comma-separated list of model names. Models are
    loaded sequentially in a fixed order: CLIP textual, CLIP visual, face
    detection, face recognition, OCR detection, OCR recognition. Afterwards a
    deprecation warning is logged for each legacy flat variable that was used.

    Args:
        preload: The preload configuration from the service settings.
    """
    log.info(
        f"Preloading models: clip:{preload.clip} facial_recognition:{preload.facial_recognition} ocr:{preload.ocr}"
    )

    async def load_models(model_string: str, model_type: ModelType, model_task: ModelTask) -> None:
        for model_name in model_string.split(","):
            model_name = model_name.strip()
            # A stray or trailing comma yields an empty name; there is nothing to load for it.
            if not model_name:
                continue
            model = await model_cache.get(model_name, model_type, model_task)
            await load(model)

    preload_plan: list[tuple[str | None, ModelType, ModelTask]] = [
        (preload.clip.textual, ModelType.TEXTUAL, ModelTask.SEARCH),
        (preload.clip.visual, ModelType.VISUAL, ModelTask.SEARCH),
        (preload.facial_recognition.detection, ModelType.DETECTION, ModelTask.FACIAL_RECOGNITION),
        (preload.facial_recognition.recognition, ModelType.RECOGNITION, ModelTask.FACIAL_RECOGNITION),
        (preload.ocr.detection, ModelType.DETECTION, ModelTask.OCR),
        (preload.ocr.recognition, ModelType.RECOGNITION, ModelTask.OCR),
    ]
    for model_string, model_type, model_task in preload_plan:
        if model_string is not None:
            await load_models(model_string, model_type, model_task)

    if preload.clip_fallback is not None:
        log.warning(
            "Deprecated env variable: 'MACHINE_LEARNING_PRELOAD__CLIP'. "
            "Use 'MACHINE_LEARNING_PRELOAD__CLIP__TEXTUAL' and "
            "'MACHINE_LEARNING_PRELOAD__CLIP__VISUAL' instead."
        )

    if preload.facial_recognition_fallback is not None:
        log.warning(
            "Deprecated env variable: 'MACHINE_LEARNING_PRELOAD__FACIAL_RECOGNITION'. "
            "Use 'MACHINE_LEARNING_PRELOAD__FACIAL_RECOGNITION__DETECTION' and "
            "'MACHINE_LEARNING_PRELOAD__FACIAL_RECOGNITION__RECOGNITION' instead."
        )
|
||||
|
||||
|
||||
def update_state() -> Iterator[None]:
    """Request-scoped dependency that tracks activity for the idle-shutdown check.

    On entry it records the current time as the last call and counts the request as
    active; once the request is finished (successfully or not) the count is released.
    """
    global active_requests, last_called
    last_called = time.time()
    active_requests += 1
    try:
        yield
    finally:
        active_requests -= 1
|
||||
|
||||
|
||||
def get_entries(entries: str = Form()) -> InferenceEntries:
    """Parse the ``entries`` form field into inference entries, split by whether they have dependencies.

    The field is a JSON object of the form ``{task: {type: {"modelName": ..., "options": {...}}}}``.

    Args:
        entries: JSON-encoded pipeline request taken from the form data.

    Returns:
        A ``(without_deps, with_deps)`` pair: entries whose model class declares no
        dependencies, and entries that depend on another model's output.

    Raises:
        HTTPException: 422 when the payload is not valid JSON or does not have the expected shape.
    """
    try:
        request: PipelineRequest = orjson.loads(entries)
        without_deps: list[InferenceEntry] = []
        with_deps: list[InferenceEntry] = []
        for task, types in request.items():
            # Named model_type (not `type`) so the builtin is not shadowed.
            for model_type, entry in types.items():
                parsed: InferenceEntry = {
                    "name": entry["modelName"],
                    "task": task,
                    "type": model_type,
                    "options": entry.get("options", {}),
                }
                dep = get_model_deps(parsed["name"], model_type, task)
                (with_deps if dep else without_deps).append(parsed)
        return without_deps, with_deps
    # TypeError covers entries that are not objects (e.g. a string where a dict is expected).
    except (orjson.JSONDecodeError, ValidationError, KeyError, AttributeError, TypeError) as e:
        log.error(f"Invalid request format: {e}")
        raise HTTPException(422, "Invalid request format.") from e
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
|
||||
@app.get("/")
async def root() -> ORJSONResponse:
    """Return a small JSON payload identifying the service."""
    return ORJSONResponse({"message": "Immich ML"})
|
||||
|
||||
|
||||
@app.get("/ping")
def ping() -> PlainTextResponse:
    """Liveness probe: always answers with the plain text "pong"."""
    return PlainTextResponse("pong")
|
||||
|
||||
|
||||
@app.post("/predict", dependencies=[Depends(update_state)])
async def predict(
    entries: InferenceEntries = Depends(get_entries),
    image: bytes | None = File(default=None),
    text: str | None = Form(default=None),
) -> Any:
    """Run the requested models on the uploaded image or text and return their outputs as JSON.

    An image takes precedence when both inputs are supplied; it is decoded through
    ``run`` before inference. A request with neither input is rejected with 400.
    """
    inputs: Image | str
    if image is not None:
        inputs = await run(decode_pil, image)
    elif text is not None:
        inputs = text
    else:
        raise HTTPException(400, "Either image or text must be provided")
    return ORJSONResponse(await run_inference(inputs, entries))
|
||||
|
||||
|
||||
async def run_inference(payload: Image | str, entries: InferenceEntries) -> InferenceResponse:
    """Run every requested model on ``payload`` and collect the results keyed by task.

    Entries without dependencies run concurrently first; entries with dependencies
    run concurrently afterwards and receive the earlier outputs as extra inputs.
    For image payloads the image dimensions are added to the response.

    Raises:
        HTTPException: 400 when an entry depends on an output that was not produced.
    """
    outputs: dict[ModelIdentity, Any] = {}
    response: InferenceResponse = {}

    async def _run_inference(entry: InferenceEntry) -> None:
        model = await model_cache.get(
            entry["name"], entry["type"], entry["task"], ttl=settings.model_ttl, **entry["options"]
        )
        # The payload always comes first, followed by the outputs of each dependency in declared order.
        inputs = [payload]
        for dep in model.depends:
            if dep not in outputs:
                message = f"Task {entry['task']} of type {entry['type']} depends on output of {dep}"
                raise HTTPException(400, message)
            inputs.append(outputs[dep])
        model = await load(model)
        output = await run(model.predict, *inputs, **entry["options"])
        outputs[model.identity] = output
        response[entry["task"]] = output

    without_deps, with_deps = entries
    await asyncio.gather(*map(_run_inference, without_deps))
    if with_deps:
        await asyncio.gather(*map(_run_inference, with_deps))
    if isinstance(payload, Image):
        response["imageHeight"] = payload.height
        response["imageWidth"] = payload.width

    return response
|
||||
|
||||
|
||||
async def run(func: Callable[..., T], *args: Any, **kwargs: Any) -> T:
    """Call ``func(*args, **kwargs)`` on the request thread pool, or inline when no pool exists."""
    if thread_pool is not None:
        loop = asyncio.get_running_loop()
        bound_call = partial(func, *args, **kwargs)
        return await loop.run_in_executor(thread_pool, bound_call)
    return func(*args, **kwargs)
|
||||
|
||||
|
||||
async def load(model: InferenceModel) -> InferenceModel:
    """Ensure ``model`` is loaded, retrying once with a cleared cache if the first attempt fails.

    Args:
        model: The model to load; returned unchanged if it is already loaded.

    Returns:
        The same model instance, loaded.

    Raises:
        HTTPException: 500 once the model has already been attempted more than once.
    """
    if model.loaded:
        return model

    def _load(model: InferenceModel) -> InferenceModel:
        # Stop retrying once more than one load attempt has been recorded for this model.
        if model.load_attempts > 1:
            raise HTTPException(500, f"Failed to load model '{model.model_name}'")
        # The module-level lock serializes model loading.
        with lock:
            try:
                model.load()
            except FileNotFoundError as e:
                # There is no further format to fall back to from ONNX, so propagate the error.
                if model.model_format == ModelFormat.ONNX:
                    raise e
                log.warning(
                    f"{model.model_format.upper()} is available, but model '{model.model_name}' does not support it.",
                    exc_info=e,
                )
                # Fall back to the ONNX variant of the model and try again.
                model.model_format = ModelFormat.ONNX
                model.load()
        return model

    try:
        return await run(_load, model)
    # NOTE(review): presumably these indicate a corrupt or incomplete model download - confirm.
    # FileNotFoundError is an OSError, so a missing ONNX file re-raised above also lands here.
    except (OSError, InvalidProtobuf, BadZipFile, NoSuchFile):
        log.warning(f"Failed to load {model.model_type.replace('_', ' ')} model '{model.model_name}'. Clearing cache.")
        model.clear_cache()
        return await run(_load, model)
|
||||
|
||||
|
||||
async def idle_shutdown_task() -> None:
    """Poll until the service has been idle for longer than the model TTL, then interrupt this process.

    "Idle" means at least one request has been seen, none is in flight, no model
    load holds the lock, and the last request is older than ``settings.model_ttl``.
    """

    def is_idle() -> bool:
        if last_called is None or active_requests or lock.locked():
            return False
        return time.time() - last_called > settings.model_ttl

    while not is_idle():
        await asyncio.sleep(settings.model_ttl_poll_s)

    log.info("Shutting down due to inactivity.")
    os.kill(os.getpid(), signal.SIGINT)
|
||||
48
machine-learning/immich_ml/models/__init__.py
Normal file
48
machine-learning/immich_ml/models/__init__.py
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
from typing import Any
|
||||
|
||||
from immich_ml.models.base import InferenceModel
|
||||
from immich_ml.models.clip.textual import MClipTextualEncoder, OpenClipTextualEncoder
|
||||
from immich_ml.models.clip.visual import OpenClipVisualEncoder
|
||||
from immich_ml.models.ocr.detection import TextDetector
|
||||
from immich_ml.models.ocr.recognition import TextRecognizer
|
||||
from immich_ml.schemas import ModelSource, ModelTask, ModelType
|
||||
|
||||
from .constants import get_model_source
|
||||
from .facial_recognition.detection import FaceDetector
|
||||
from .facial_recognition.recognition import FaceRecognizer
|
||||
|
||||
|
||||
def get_model_class(model_name: str, model_type: ModelType, model_task: ModelTask) -> type[InferenceModel]:
    """Resolve the model class for a (model source, type, task) combination.

    The source is derived from ``model_name``; the class itself is not instantiated.

    Raises:
        ValueError: If no class is registered for the combination.
    """
    source = get_model_source(model_name)
    match source, model_type, model_task:
        # Both CLIP sources share the same visual encoder; only the textual encoder differs.
        case ModelSource.OPENCLIP | ModelSource.MCLIP, ModelType.VISUAL, ModelTask.SEARCH:
            return OpenClipVisualEncoder

        case ModelSource.OPENCLIP, ModelType.TEXTUAL, ModelTask.SEARCH:
            return OpenClipTextualEncoder

        case ModelSource.MCLIP, ModelType.TEXTUAL, ModelTask.SEARCH:
            return MClipTextualEncoder

        case ModelSource.INSIGHTFACE, ModelType.DETECTION, ModelTask.FACIAL_RECOGNITION:
            return FaceDetector

        case ModelSource.INSIGHTFACE, ModelType.RECOGNITION, ModelTask.FACIAL_RECOGNITION:
            return FaceRecognizer

        case ModelSource.PADDLE, ModelType.DETECTION, ModelTask.OCR:
            return TextDetector

        case ModelSource.PADDLE, ModelType.RECOGNITION, ModelTask.OCR:
            return TextRecognizer

        case _:
            raise ValueError(f"Unknown model combination: {source}, {model_type}, {model_task}")
|
||||
|
||||
|
||||
def from_model_type(model_name: str, model_type: ModelType, model_task: ModelTask, **kwargs: Any) -> InferenceModel:
    """Instantiate the model class matching the name/type/task, forwarding ``kwargs`` to its constructor."""
    model_class = get_model_class(model_name, model_type, model_task)
    return model_class(model_name, **kwargs)
|
||||
|
||||
|
||||
def get_model_deps(model_name: str, model_type: ModelType, model_task: ModelTask) -> list[tuple[ModelType, ModelTask]]:
    """Return the dependencies declared by the model class matching the name/type/task."""
    model_class = get_model_class(model_name, model_type, model_task)
    return model_class.depends
|
||||
176
machine-learning/immich_ml/models/base.py
Normal file
176
machine-learning/immich_ml/models/base.py
Normal file
|
|
@ -0,0 +1,176 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from shutil import rmtree
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
import immich_ml.sessions.ann.loader
|
||||
import immich_ml.sessions.rknn as rknn
|
||||
from immich_ml.sessions.ort import OrtSession
|
||||
|
||||
from ..config import clean_name, log, settings
|
||||
from ..schemas import ModelFormat, ModelIdentity, ModelSession, ModelTask, ModelType
|
||||
from ..sessions.ann import AnnSession
|
||||
|
||||
|
||||
class InferenceModel(ABC):
|
||||
depends: ClassVar[list[ModelIdentity]]
|
||||
identity: ClassVar[ModelIdentity]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
cache_dir: Path | str | None = None,
|
||||
model_format: ModelFormat | None = None,
|
||||
session: ModelSession | None = None,
|
||||
**model_kwargs: Any,
|
||||
) -> None:
|
||||
self.loaded = session is not None
|
||||
self.load_attempts = 0
|
||||
self.model_name = clean_name(model_name)
|
||||
self.cache_dir = Path(cache_dir) if cache_dir is not None else self._cache_dir_default
|
||||
self.model_format = model_format if model_format is not None else self._model_format_default
|
||||
if session is not None:
|
||||
self.session = session
|
||||
|
||||
def download(self) -> None:
|
||||
if not self.cached:
|
||||
model_type = self.model_type.replace("-", " ")
|
||||
log.info(f"Downloading {model_type} model '{self.model_name}' to {self.model_path}. This may take a while.")
|
||||
self._download()
|
||||
|
||||
def load(self) -> None:
|
||||
if self.loaded:
|
||||
return
|
||||
self.load_attempts += 1
|
||||
|
||||
self.download()
|
||||
attempt = f"Attempt #{self.load_attempts} to load" if self.load_attempts > 1 else "Loading"
|
||||
log.info(f"{attempt} {self.model_type.replace('-', ' ')} model '{self.model_name}' to memory")
|
||||
self.session = self._load()
|
||||
self.loaded = True
|
||||
|
||||
def predict(self, *inputs: Any, **model_kwargs: Any) -> Any:
|
||||
self.load()
|
||||
if model_kwargs:
|
||||
self.configure(**model_kwargs)
|
||||
return self._predict(*inputs)
|
||||
|
||||
@abstractmethod
|
||||
def _predict(self, *inputs: Any, **model_kwargs: Any) -> Any: ...
|
||||
|
||||
def configure(self, **kwargs: Any) -> None:
|
||||
pass
|
||||
|
||||
def _download(self) -> None:
|
||||
ignored_patterns: dict[ModelFormat, list[str]] = {
|
||||
ModelFormat.ONNX: ["*.armnn", "*.rknn"],
|
||||
ModelFormat.ARMNN: ["*.rknn"],
|
||||
ModelFormat.RKNN: ["*.armnn"],
|
||||
}
|
||||
|
||||
snapshot_download(
|
||||
f"immich-app/{clean_name(self.model_name)}",
|
||||
cache_dir=self.cache_dir,
|
||||
local_dir=self.cache_dir,
|
||||
ignore_patterns=ignored_patterns.get(self.model_format, []),
|
||||
)
|
||||
|
||||
def _load(self) -> ModelSession:
|
||||
return self._make_session(self.model_path)
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
if not self.cache_dir.exists():
|
||||
log.warning(
|
||||
f"Attempted to clear cache for model '{self.model_name}', but cache directory does not exist",
|
||||
)
|
||||
return
|
||||
if not rmtree.avoids_symlink_attacks:
|
||||
raise RuntimeError("Attempted to clear cache, but rmtree is not safe on this platform")
|
||||
|
||||
if self.cache_dir.is_dir():
|
||||
log.info(f"Cleared cache directory for model '{self.model_name}'.")
|
||||
rmtree(self.cache_dir)
|
||||
else:
|
||||
log.warning(
|
||||
(
|
||||
f"Encountered file instead of directory at cache path "
|
||||
f"for '{self.model_name}'. Removing file and replacing with a directory."
|
||||
),
|
||||
)
|
||||
self.cache_dir.unlink()
|
||||
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def _make_session(self, model_path: Path) -> ModelSession:
|
||||
if not model_path.is_file():
|
||||
raise FileNotFoundError(f"Model file not found: {model_path}")
|
||||
|
||||
match model_path.suffix:
|
||||
case ".armnn":
|
||||
session: ModelSession = AnnSession(model_path)
|
||||
case ".onnx":
|
||||
session = OrtSession(model_path)
|
||||
case ".rknn":
|
||||
session = rknn.RknnSession(model_path)
|
||||
case _:
|
||||
raise ValueError(f"Unsupported model file type: {model_path.suffix}")
|
||||
return session
|
||||
|
||||
def model_path_for_format(self, model_format: ModelFormat) -> Path:
|
||||
model_path_prefix = rknn.model_prefix if model_format == ModelFormat.RKNN else None
|
||||
if model_path_prefix:
|
||||
return self.model_dir / model_path_prefix / f"model.{model_format}"
|
||||
return self.model_dir / f"model.{model_format}"
|
||||
|
||||
@property
|
||||
def model_dir(self) -> Path:
|
||||
return self.cache_dir / self.model_type.value
|
||||
|
||||
@property
|
||||
def model_path(self) -> Path:
|
||||
return self.model_path_for_format(self.model_format)
|
||||
|
||||
@property
|
||||
def model_task(self) -> ModelTask:
|
||||
return self.identity[1]
|
||||
|
||||
@property
|
||||
def model_type(self) -> ModelType:
|
||||
return self.identity[0]
|
||||
|
||||
@property
|
||||
def cache_dir(self) -> Path:
|
||||
return self._cache_dir
|
||||
|
||||
@cache_dir.setter
|
||||
def cache_dir(self, cache_dir: Path) -> None:
|
||||
self._cache_dir = cache_dir
|
||||
|
||||
@property
|
||||
def _cache_dir_default(self) -> Path:
|
||||
return settings.cache_folder / self.model_task.value / self.model_name
|
||||
|
||||
@property
|
||||
def cached(self) -> bool:
|
||||
return self.model_path.is_file()
|
||||
|
||||
@property
|
||||
def model_format(self) -> ModelFormat:
|
||||
return self._model_format
|
||||
|
||||
@model_format.setter
|
||||
def model_format(self, model_format: ModelFormat) -> None:
|
||||
log.debug(f"Setting model format to {model_format}")
|
||||
self._model_format = model_format
|
||||
|
||||
@property
|
||||
def _model_format_default(self) -> ModelFormat:
|
||||
if rknn.is_available:
|
||||
return ModelFormat.RKNN
|
||||
elif immich_ml.sessions.ann.loader.is_available and settings.ann:
|
||||
return ModelFormat.ARMNN
|
||||
else:
|
||||
return ModelFormat.ONNX
|
||||
60
machine-learning/immich_ml/models/cache.py
Normal file
60
machine-learning/immich_ml/models/cache.py
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
from typing import Any
|
||||
|
||||
from aiocache.backends.memory import SimpleMemoryCache
|
||||
from aiocache.lock import OptimisticLock
|
||||
from aiocache.plugins import TimingPlugin
|
||||
|
||||
from immich_ml.models import from_model_type
|
||||
from immich_ml.models.base import InferenceModel
|
||||
|
||||
from ..schemas import ModelTask, ModelType, has_profiling
|
||||
|
||||
|
||||
class ModelCache:
    """In-memory cache of models that instantiates a model the first time its key is requested."""

    def __init__(
        self,
        revalidate: bool = False,
        timeout: int | None = None,
        profiling: bool = False,
    ) -> None:
        """
        Args:
            revalidate: Resets TTL on cache hit. Useful to keep models in memory while active. Defaults to False.
            timeout: Maximum allowed time for model to load. Disabled if None. Defaults to None.
            profiling: Collects metrics for cache operations, adding slight overhead. Defaults to False.
        """
        self.should_revalidate = revalidate
        plugins = [TimingPlugin()] if profiling else []
        self.cache = SimpleMemoryCache(timeout=timeout, plugins=plugins, namespace=None)

    async def get(
        self, model_name: str, model_type: ModelType, model_task: ModelTask, **model_kwargs: Any
    ) -> InferenceModel:
        """Return the cached model for the name/type/task, creating and storing it on a miss.

        ``model_kwargs`` are passed to the model constructor on a miss; their optional
        ``ttl`` entry is also used as the cache TTL (and for revalidation on a hit).
        """
        key = f"{model_name}{model_type}{model_task}"
        ttl = model_kwargs.get("ttl", None)

        async with OptimisticLock(self.cache, key) as lock:
            model: InferenceModel | None = await self.cache.get(key)
            if model is not None:
                if self.should_revalidate:
                    await self.revalidate(key, ttl)
            else:
                model = from_model_type(model_name, model_type, model_task, **model_kwargs)
                await lock.cas(model, ttl=ttl)
        return model

    async def get_profiling(self) -> dict[str, float] | None:
        """Return the cache's profiling metrics, or None when profiling is not available."""
        return self.cache.profiling if has_profiling(self.cache) else None

    async def revalidate(self, key: str, ttl: int | None) -> None:
        """Reset the expiry of ``key`` to ``ttl``, but only for keys that already have an expiry handler."""
        if ttl is None or key not in self.cache._handlers:
            return
        await self.cache.expire(key, ttl)
|
||||
120
machine-learning/immich_ml/models/clip/textual.py
Normal file
120
machine-learning/immich_ml/models/clip/textual.py
Normal file
|
|
@ -0,0 +1,120 @@
|
|||
import json
|
||||
from abc import abstractmethod
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
from tokenizers import Encoding, Tokenizer
|
||||
|
||||
from immich_ml.config import log
|
||||
from immich_ml.models.base import InferenceModel
|
||||
from immich_ml.models.constants import WEBLATE_TO_FLORES200
|
||||
from immich_ml.models.transforms import clean_text, serialize_np_array
|
||||
from immich_ml.schemas import ModelSession, ModelTask, ModelType
|
||||
|
||||
|
||||
class BaseCLIPTextualEncoder(InferenceModel):
    """Shared logic for CLIP text encoders: tokenizer loading, config access and prediction.

    Subclasses provide ``_load_tokenizer`` and ``tokenize``.
    """

    depends = []
    identity = (ModelType.TEXTUAL, ModelTask.SEARCH)

    def _predict(self, inputs: str, language: str | None = None) -> str:
        """Tokenize ``inputs``, run the session and return the serialized result of the first output."""
        tokens = self.tokenize(inputs, language=language)
        res: NDArray[np.float32] = self.session.run(None, tokens)[0][0]
        return serialize_np_array(res)

    def _load(self) -> ModelSession:
        """Create the model session, then load the tokenizer and derive tokenizer-related flags."""
        session = super()._load()
        log.debug(f"Loading tokenizer for CLIP model '{self.model_name}'")
        self.tokenizer = self._load_tokenizer()
        tokenizer_kwargs: dict[str, Any] | None = self.text_cfg.get("tokenizer_kwargs")
        # Text is canonicalized only when the model config explicitly asks for it.
        self.canonicalize = tokenizer_kwargs is not None and tokenizer_kwargs.get("clean") == "canonicalize"
        self.is_nllb = self.model_name.startswith("nllb")
        log.debug(f"Loaded tokenizer for CLIP model '{self.model_name}'")

        return session

    @abstractmethod
    def _load_tokenizer(self) -> Tokenizer:
        pass

    @abstractmethod
    def tokenize(self, text: str, language: str | None = None) -> dict[str, NDArray[np.int32]]:
        pass

    @property
    def model_cfg_path(self) -> Path:
        """Path of the model config file in the cache directory."""
        return self.cache_dir / "config.json"

    @property
    def tokenizer_file_path(self) -> Path:
        """Path of the tokenizer definition in the model directory."""
        return self.model_dir / "tokenizer.json"

    @property
    def tokenizer_cfg_path(self) -> Path:
        """Path of the tokenizer config in the model directory."""
        return self.model_dir / "tokenizer_config.json"

    @cached_property
    def model_cfg(self) -> dict[str, Any]:
        """Parsed model config; read from disk on first access and cached afterwards."""
        log.debug(f"Loading model config for CLIP model '{self.model_name}'")
        # JSON is read as UTF-8 and the file is closed as soon as it has been parsed.
        with self.model_cfg_path.open(encoding="utf-8") as cfg_file:
            model_cfg: dict[str, Any] = json.load(cfg_file)
        log.debug(f"Loaded model config for CLIP model '{self.model_name}'")
        return model_cfg

    @property
    def text_cfg(self) -> dict[str, Any]:
        """The ``text_cfg`` section of the model config."""
        text_cfg: dict[str, Any] = self.model_cfg["text_cfg"]
        return text_cfg

    @cached_property
    def tokenizer_file(self) -> dict[str, Any]:
        """Parsed tokenizer definition; read from disk on first access and cached afterwards."""
        log.debug(f"Loading tokenizer file for CLIP model '{self.model_name}'")
        with self.tokenizer_file_path.open(encoding="utf-8") as tokenizer_json:
            tokenizer_file: dict[str, Any] = json.load(tokenizer_json)
        log.debug(f"Loaded tokenizer file for CLIP model '{self.model_name}'")
        return tokenizer_file

    @cached_property
    def tokenizer_cfg(self) -> dict[str, Any]:
        """Parsed tokenizer config; read from disk on first access and cached afterwards."""
        log.debug(f"Loading tokenizer config for CLIP model '{self.model_name}'")
        with self.tokenizer_cfg_path.open(encoding="utf-8") as cfg_file:
            tokenizer_cfg: dict[str, Any] = json.load(cfg_file)
        log.debug(f"Loaded tokenizer config for CLIP model '{self.model_name}'")
        return tokenizer_cfg
|
||||
|
||||
|
||||
class OpenClipTextualEncoder(BaseCLIPTextualEncoder):
    """Text encoder for OpenCLIP models, with language-prefix support for NLLB variants."""

    def _load_tokenizer(self) -> Tokenizer:
        """Load the tokenizer file and configure fixed-length padding and truncation."""
        max_len: int = self.text_cfg.get("context_length", 77)
        pad: str = self.tokenizer_cfg["pad_token"]

        tok: Tokenizer = Tokenizer.from_file(self.tokenizer_file_path.as_posix())

        # every encoding is padded and truncated to exactly `max_len` tokens
        tok.enable_padding(length=max_len, pad_token=pad, pad_id=tok.token_to_id(pad))
        tok.enable_truncation(max_length=max_len)

        return tok

    def tokenize(self, text: str, language: str | None = None) -> dict[str, NDArray[np.int32]]:
        """Clean `text`, optionally prefix it with a FLORES-200 code, and encode it.

        The prefix is only added for NLLB models when `language` is given. The
        lookup tries the full language tag first, then the part before the first
        "-", and finally falls back to English with a warning.
        """
        text = clean_text(text, canonicalize=self.canonicalize)
        if self.is_nllb and language is not None:
            for candidate in (language, language.split("-")[0]):
                prefix = WEBLATE_TO_FLORES200.get(candidate)
                if prefix is not None:
                    break
            else:
                log.warning(f"Language '{language}' not found, defaulting to 'en'")
                prefix = "eng_Latn"
            text = f"{prefix}{text}"
        encoding: Encoding = self.tokenizer.encode(text)
        # wrap the ids in a list so the array has a leading axis of length 1
        return {"text": np.array([encoding.ids], dtype=np.int32)}
|
||||
|
||||
|
||||
class MClipTextualEncoder(OpenClipTextualEncoder):
    """Text encoder for M-CLIP models; feeds token ids together with an attention mask."""

    def tokenize(self, text: str, language: str | None = None) -> dict[str, NDArray[np.int32]]:
        """Clean and encode `text`; `language` is accepted but not used here."""
        cleaned = clean_text(text, canonicalize=self.canonicalize)
        encoding: Encoding = self.tokenizer.encode(cleaned)
        input_ids = np.array([encoding.ids], dtype=np.int32)
        attention_mask = np.array([encoding.attention_mask], dtype=np.int32)
        return {"input_ids": input_ids, "attention_mask": attention_mask}
|
||||
77
machine-learning/immich_ml/models/clip/visual.py
Normal file
77
machine-learning/immich_ml/models/clip/visual.py
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
import json
|
||||
from abc import abstractmethod
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
from PIL import Image
|
||||
|
||||
from immich_ml.config import log
|
||||
from immich_ml.models.base import InferenceModel
|
||||
from immich_ml.models.transforms import (
|
||||
crop_pil,
|
||||
decode_pil,
|
||||
get_pil_resampling,
|
||||
normalize,
|
||||
resize_pil,
|
||||
serialize_np_array,
|
||||
to_numpy,
|
||||
)
|
||||
from immich_ml.schemas import ModelSession, ModelTask, ModelType
|
||||
|
||||
|
||||
class BaseCLIPVisualEncoder(InferenceModel):
    """Base class for CLIP image encoders.

    Decodes the input into a PIL image, converts it to the session's input feed
    via `transform`, and returns the embedding as a serialized string.
    Subclasses implement `transform`.
    """

    depends = []
    identity = (ModelType.VISUAL, ModelTask.SEARCH)

    def _predict(self, inputs: Image.Image | bytes) -> str:
        """Embed `inputs` and return the result of `serialize_np_array` on the embedding."""
        image = decode_pil(inputs)
        # first output of the session, first row of that output
        res: NDArray[np.float32] = self.session.run(None, self.transform(image))[0][0]
        return serialize_np_array(res)

    @abstractmethod
    def transform(self, image: Image.Image) -> dict[str, NDArray[np.float32]]:
        """Convert `image` into the named float32 arrays passed to the session."""
        pass

    @property
    def model_cfg_path(self) -> Path:
        return self.cache_dir / "config.json"

    @property
    def preprocess_cfg_path(self) -> Path:
        return self.model_dir / "preprocess_cfg.json"

    @cached_property
    def model_cfg(self) -> dict[str, Any]:
        """Parsed JSON from `model_cfg_path`; read once and cached on the instance."""
        log.debug(f"Loading model config for CLIP model '{self.model_name}'")
        # close the handle deterministically instead of leaving it to the garbage collector
        with self.model_cfg_path.open() as f:
            model_cfg: dict[str, Any] = json.load(f)
        log.debug(f"Loaded model config for CLIP model '{self.model_name}'")
        return model_cfg

    @cached_property
    def preprocess_cfg(self) -> dict[str, Any]:
        """Parsed JSON from `preprocess_cfg_path`; read once and cached on the instance."""
        log.debug(f"Loading visual preprocessing config for CLIP model '{self.model_name}'")
        with self.preprocess_cfg_path.open() as f:
            preprocess_cfg: dict[str, Any] = json.load(f)
        log.debug(f"Loaded visual preprocessing config for CLIP model '{self.model_name}'")
        return preprocess_cfg
|
||||
|
||||
|
||||
class OpenClipVisualEncoder(BaseCLIPVisualEncoder):
    """Image encoder for OpenCLIP models: resize, center-crop, normalize, channels-first."""

    def _load(self) -> ModelSession:
        """Read preprocessing parameters from `preprocess_cfg`, then load the session."""
        cfg_size: list[int] | int = self.preprocess_cfg["size"]
        if isinstance(cfg_size, list):
            # only the first entry of a list-valued size is used
            self.size = cfg_size[0]
        else:
            self.size = cfg_size

        # NOTE(review): `resampling` is stored here, but `transform` below calls
        # `resize_pil` without passing it - confirm whether that is intended.
        self.resampling = get_pil_resampling(self.preprocess_cfg["interpolation"])
        self.mean = np.array(self.preprocess_cfg["mean"], dtype=np.float32)
        self.std = np.array(self.preprocess_cfg["std"], dtype=np.float32)

        return super()._load()

    def transform(self, image: Image.Image) -> dict[str, NDArray[np.float32]]:
        """Return `{"image": array}` where the array is the preprocessed image with a leading axis."""
        cropped = crop_pil(resize_pil(image, self.size), self.size)
        pixels = normalize(to_numpy(cropped), self.mean, self.std)
        # move the last axis to the front, then add a leading axis of length 1
        return {"image": np.expand_dims(pixels.transpose(2, 0, 1), 0)}
|
||||
178
machine-learning/immich_ml/models/constants.py
Normal file
178
machine-learning/immich_ml/models/constants.py
Normal file
|
|
@ -0,0 +1,178 @@
|
|||
from immich_ml.config import clean_name
|
||||
from immich_ml.schemas import ModelSource
|
||||
|
||||
_OPENCLIP_MODELS = {
|
||||
"RN101__openai",
|
||||
"RN101__yfcc15m",
|
||||
"RN50__cc12m",
|
||||
"RN50__openai",
|
||||
"RN50__yfcc15m",
|
||||
"RN50x16__openai",
|
||||
"RN50x4__openai",
|
||||
"RN50x64__openai",
|
||||
"ViT-B-16-SigLIP-256__webli",
|
||||
"ViT-B-16-SigLIP-384__webli",
|
||||
"ViT-B-16-SigLIP-512__webli",
|
||||
"ViT-B-16-SigLIP-i18n-256__webli",
|
||||
"ViT-B-16-SigLIP__webli",
|
||||
"ViT-B-16-plus-240__laion400m_e31",
|
||||
"ViT-B-16-plus-240__laion400m_e32",
|
||||
"ViT-B-16__laion400m_e31",
|
||||
"ViT-B-16__laion400m_e32",
|
||||
"ViT-B-16__openai",
|
||||
"ViT-B-32__laion2b-s34b-b79k",
|
||||
"ViT-B-32__laion2b_e16",
|
||||
"ViT-B-32__laion400m_e31",
|
||||
"ViT-B-32__laion400m_e32",
|
||||
"ViT-B-32__openai",
|
||||
"ViT-H-14-378-quickgelu__dfn5b",
|
||||
"ViT-H-14-quickgelu__dfn5b",
|
||||
"ViT-H-14__laion2b-s32b-b79k",
|
||||
"ViT-L-14-336__openai",
|
||||
"ViT-L-14-quickgelu__dfn2b",
|
||||
"ViT-L-14__laion2b-s32b-b82k",
|
||||
"ViT-L-14__laion400m_e31",
|
||||
"ViT-L-14__laion400m_e32",
|
||||
"ViT-L-14__openai",
|
||||
"ViT-L-16-SigLIP-256__webli",
|
||||
"ViT-L-16-SigLIP-384__webli",
|
||||
"ViT-SO400M-14-SigLIP-384__webli",
|
||||
"ViT-g-14__laion2b-s12b-b42k",
|
||||
"XLM-Roberta-Base-ViT-B-32__laion5b_s13b_b90k",
|
||||
"XLM-Roberta-Large-ViT-H-14__frozen_laion5b_s13b_b90k",
|
||||
"nllb-clip-base-siglip__mrl",
|
||||
"nllb-clip-base-siglip__v1",
|
||||
"nllb-clip-large-siglip__mrl",
|
||||
"nllb-clip-large-siglip__v1",
|
||||
"ViT-B-16-SigLIP2__webli",
|
||||
"ViT-B-32-SigLIP2-256__webli",
|
||||
"ViT-L-16-SigLIP2-256__webli",
|
||||
"ViT-L-16-SigLIP2-384__webli",
|
||||
"ViT-L-16-SigLIP2-512__webli",
|
||||
"ViT-SO400M-14-SigLIP2-378__webli",
|
||||
"ViT-SO400M-14-SigLIP2__webli",
|
||||
"ViT-SO400M-16-SigLIP2-256__webli",
|
||||
"ViT-SO400M-16-SigLIP2-384__webli",
|
||||
"ViT-SO400M-16-SigLIP2-512__webli",
|
||||
"ViT-gopt-16-SigLIP2-256__webli",
|
||||
"ViT-gopt-16-SigLIP2-384__webli",
|
||||
}
|
||||
|
||||
|
||||
_MCLIP_MODELS = {
|
||||
"LABSE-Vit-L-14",
|
||||
"XLM-Roberta-Large-Vit-B-16Plus",
|
||||
"XLM-Roberta-Large-Vit-B-32",
|
||||
"XLM-Roberta-Large-Vit-L-14",
|
||||
}
|
||||
|
||||
|
||||
_INSIGHTFACE_MODELS = {
|
||||
"antelopev2",
|
||||
"buffalo_s",
|
||||
"buffalo_m",
|
||||
"buffalo_l",
|
||||
}
|
||||
|
||||
|
||||
_PADDLE_MODELS = {
|
||||
"PP-OCRv5_server",
|
||||
"PP-OCRv5_mobile",
|
||||
"CH__PP-OCRv5_server",
|
||||
"CH__PP-OCRv5_mobile",
|
||||
"EL__PP-OCRv5_mobile",
|
||||
"EN__PP-OCRv5_mobile",
|
||||
"ESLAV__PP-OCRv5_mobile",
|
||||
"KOREAN__PP-OCRv5_mobile",
|
||||
"LATIN__PP-OCRv5_mobile",
|
||||
"TH__PP-OCRv5_mobile",
|
||||
}
|
||||
|
||||
SUPPORTED_PROVIDERS = [
|
||||
"CUDAExecutionProvider",
|
||||
"ROCMExecutionProvider",
|
||||
"OpenVINOExecutionProvider",
|
||||
"CoreMLExecutionProvider",
|
||||
"CPUExecutionProvider",
|
||||
]
|
||||
|
||||
RKNN_SUPPORTED_SOCS = ["rk3566", "rk3568", "rk3576", "rk3588"]
|
||||
RKNN_COREMASK_SUPPORTED_SOCS = ["rk3576", "rk3588"]
|
||||
|
||||
|
||||
# Maps Weblate locale codes to FLORES-200 language codes.
# Every value must be a code from the FLORES-200 list: Japanese is "jpn_Jpan"
# and Latvian is "lvs_Latn" ("jpn_Hira" and "lav_Latn" are not in that list).
WEBLATE_TO_FLORES200 = {
    "af": "afr_Latn",
    "ar": "arb_Arab",
    "az": "azj_Latn",
    "be": "bel_Cyrl",
    "bg": "bul_Cyrl",
    "ca": "cat_Latn",
    "cs": "ces_Latn",
    "da": "dan_Latn",
    "de": "deu_Latn",
    "el": "ell_Grek",
    "en": "eng_Latn",
    "es": "spa_Latn",
    "et": "est_Latn",
    "fa": "pes_Arab",
    "fi": "fin_Latn",
    "fr": "fra_Latn",
    "he": "heb_Hebr",
    "hi": "hin_Deva",
    "hr": "hrv_Latn",
    "hu": "hun_Latn",
    "hy": "hye_Armn",
    "id": "ind_Latn",
    "it": "ita_Latn",
    "ja": "jpn_Jpan",
    "kmr": "kmr_Latn",
    "ko": "kor_Hang",
    "lb": "ltz_Latn",
    "lt": "lit_Latn",
    "lv": "lvs_Latn",
    "mfa": "zsm_Latn",
    "mk": "mkd_Cyrl",
    "mn": "khk_Cyrl",
    "mr": "mar_Deva",
    "ms": "zsm_Latn",
    "nb-NO": "nob_Latn",
    "nn": "nno_Latn",
    "nl": "nld_Latn",
    "pl": "pol_Latn",
    "pt-BR": "por_Latn",
    "pt": "por_Latn",
    "ro": "ron_Latn",
    "ru": "rus_Cyrl",
    "sk": "slk_Latn",
    "sl": "slv_Latn",
    "sr-Cyrl": "srp_Cyrl",
    "sv": "swe_Latn",
    "ta": "tam_Taml",
    "te": "tel_Telu",
    "th": "tha_Thai",
    "tr": "tur_Latn",
    "uk": "ukr_Cyrl",
    "ur": "urd_Arab",
    "vi": "vie_Latn",
    "zh-CN": "zho_Hans",
    "zh-Hans": "zho_Hans",
    "zh-TW": "zho_Hant",
}
|
||||
|
||||
|
||||
def get_model_source(model_name: str) -> ModelSource | None:
    """Return the source a model name belongs to, or None when it is in no known set.

    The name is passed through `clean_name` before the lookup.
    """
    cleaned = clean_name(model_name)

    # checked in this order; the first set containing the name wins
    known_sources = (
        (_INSIGHTFACE_MODELS, ModelSource.INSIGHTFACE),
        (_MCLIP_MODELS, ModelSource.MCLIP),
        (_OPENCLIP_MODELS, ModelSource.OPENCLIP),
        (_PADDLE_MODELS, ModelSource.PADDLE),
    )
    for names, source in known_sources:
        if cleaned in names:
            return source

    return None
|
||||
|
|
@ -0,0 +1,41 @@
|
|||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from insightface.model_zoo import RetinaFace
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from immich_ml.models.base import InferenceModel
|
||||
from immich_ml.models.transforms import decode_cv2
|
||||
from immich_ml.schemas import FaceDetectionOutput, ModelSession, ModelTask, ModelType
|
||||
|
||||
|
||||
class FaceDetector(InferenceModel):
    """Face detector that wraps an InsightFace `RetinaFace` model around the inference session."""

    depends = []
    identity = (ModelType.DETECTION, ModelTask.FACIAL_RECOGNITION)

    def __init__(self, model_name: str, min_score: float = 0.7, **model_kwargs: Any) -> None:
        """Store the score threshold; a `minScore` entry in `model_kwargs` overrides `min_score`."""
        # `minScore` is removed from the kwargs before they reach the parent constructor
        self.min_score = model_kwargs.pop("minScore", min_score)
        super().__init__(model_name, **model_kwargs)

    def _load(self) -> ModelSession:
        """Create the session and prepare the RetinaFace wrapper with a 640x640 input size."""
        session = self._make_session(self.model_path)
        detector = RetinaFace(session=session)
        detector.prepare(ctx_id=0, det_thresh=self.min_score, input_size=(640, 640))
        self.model = detector

        return session

    def _predict(self, inputs: NDArray[np.uint8] | bytes) -> FaceDetectionOutput:
        """Detect faces and return rounded boxes, scores and landmarks."""
        image = decode_cv2(inputs)

        detections, landmarks = self._detect(image)
        # the first four columns are box coordinates, the fifth is the score
        boxes = detections[:, :4].round()
        scores = detections[:, 4]
        return {"boxes": boxes, "scores": scores, "landmarks": landmarks}

    def _detect(self, inputs: NDArray[np.uint8] | bytes) -> tuple[NDArray[np.float32], NDArray[np.float32]]:
        """Thin wrapper around `self.model.detect`."""
        return self.model.detect(inputs)  # type: ignore

    def configure(self, **kwargs: Any) -> None:
        """Update the detection threshold from `minScore`, keeping the current value when absent."""
        current = self.model.det_thresh
        self.model.det_thresh = kwargs.pop("minScore", current)
|
||||
|
|
@ -0,0 +1,92 @@
|
|||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import onnx
|
||||
import onnxruntime as ort
|
||||
from insightface.model_zoo import ArcFaceONNX
|
||||
from insightface.utils.face_align import norm_crop
|
||||
from numpy.typing import NDArray
|
||||
from onnx.tools.update_model_dims import update_inputs_outputs_dims
|
||||
from PIL import Image
|
||||
|
||||
from immich_ml.config import log, settings
|
||||
from immich_ml.models.base import InferenceModel
|
||||
from immich_ml.models.transforms import decode_cv2, serialize_np_array
|
||||
from immich_ml.schemas import (
|
||||
FaceDetectionOutput,
|
||||
FacialRecognitionOutput,
|
||||
ModelFormat,
|
||||
ModelSession,
|
||||
ModelTask,
|
||||
ModelType,
|
||||
)
|
||||
|
||||
|
||||
class FaceRecognizer(InferenceModel):
    """Face embedder that wraps an InsightFace `ArcFaceONNX` model and consumes detector output."""

    depends = [(ModelType.DETECTION, ModelTask.FACIAL_RECOGNITION)]
    identity = (ModelType.RECOGNITION, ModelTask.FACIAL_RECOGNITION)

    def __init__(self, model_name: str, **model_kwargs: Any) -> None:
        """Initialise the model and pick the batch size from settings or the default."""
        super().__init__(model_name, **model_kwargs)
        configured = settings.max_batch_size.facial_recognition if settings.max_batch_size else None
        # a missing or falsy configured value falls back to the provider-dependent default
        self.batch_size = configured or self._batch_size_default

    def _load(self) -> ModelSession:
        """Create the session, adding a dynamic batch axis to the model file first if needed."""
        session = self._make_session(self.model_path)
        wants_batching = not self.batch_size or self.batch_size > 1
        if wants_batching and str(session.get_inputs()[0].shape[0]) != "batch":
            # rewrite the model on disk, then reload it so the session sees the new axis
            self._add_batch_axis(self.model_path)
            session = self._make_session(self.model_path)
        onnx_path = self.model_path_for_format(ModelFormat.ONNX).as_posix()
        self.model = ArcFaceONNX(onnx_path, session=session)
        return session

    def _predict(
        self, inputs: NDArray[np.uint8] | bytes | Image.Image, faces: FaceDetectionOutput
    ) -> FacialRecognitionOutput:
        """Embed every detected face; returns an empty list when `faces` has no boxes."""
        if faces["boxes"].shape[0] == 0:
            return []
        image = decode_cv2(inputs)
        crops = self._crop(image, faces)
        return self.postprocess(faces, self._predict_batch(crops))

    def _predict_batch(self, cropped_faces: list[NDArray[np.uint8]]) -> NDArray[np.float32]:
        """Run the model over `cropped_faces`, in chunks of `batch_size` when a limit applies."""
        limit = self.batch_size
        if not limit or len(cropped_faces) <= limit:
            # no limit, or everything fits in one call
            embeddings: NDArray[np.float32] = self.model.get_feat(cropped_faces)
            return embeddings

        chunks: list[NDArray[np.float32]] = [
            self.model.get_feat(cropped_faces[start : start + limit])
            for start in range(0, len(cropped_faces), limit)
        ]
        return np.concatenate(chunks, axis=0)

    def postprocess(self, faces: FaceDetectionOutput, embeddings: NDArray[np.float32]) -> FacialRecognitionOutput:
        """Pair each box and score with its serialized embedding."""
        results: FacialRecognitionOutput = []
        for box, embedding, score in zip(faces["boxes"], embeddings, faces["scores"]):
            x1, y1, x2, y2 = box
            results.append(
                {
                    "boundingBox": {"x1": x1, "y1": y1, "x2": x2, "y2": y2},
                    "embedding": serialize_np_array(embedding),
                    "score": score,
                }
            )
        return results

    def _crop(self, image: NDArray[np.uint8], faces: FaceDetectionOutput) -> list[NDArray[np.uint8]]:
        """Produce one `norm_crop` of `image` per landmark set in `faces`."""
        crops = []
        for landmark in faces["landmarks"]:
            crops.append(norm_crop(image, landmark))
        return crops

    def _add_batch_axis(self, model_path: Path) -> None:
        """Replace the first dimension of the model's first input and output with "batch" and save in place."""
        log.debug(f"Adding batch axis to model {model_path}")
        proto = onnx.load(model_path)
        graph_input = proto.graph.input[0]
        graph_output = proto.graph.output[0]
        # keep every dimension after the first exactly as stored in the model
        fixed_in = [dim.dim_value for dim in graph_input.type.tensor_type.shape.dim[1:]]
        fixed_out = [dim.dim_value for dim in graph_output.type.tensor_type.shape.dim[1:]]
        updated_proto = update_inputs_outputs_dims(
            proto,
            {graph_input.name: ["batch"] + fixed_in},
            {graph_output.name: ["batch"] + fixed_out},
        )
        onnx.save(updated_proto, model_path)

    @property
    def _batch_size_default(self) -> int | None:
        """None for ONNX models when the OpenVINO provider is not available, otherwise 1."""
        providers = ort.get_available_providers()
        if self.model_format == ModelFormat.ONNX and "OpenVINOExecutionProvider" not in providers:
            return None
        return 1
|
||||
125
machine-learning/immich_ml/models/ocr/detection.py
Normal file
125
machine-learning/immich_ml/models/ocr/detection.py
Normal file
|
|
@ -0,0 +1,125 @@
|
|||
from typing import Any
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
from PIL import Image
|
||||
from rapidocr.ch_ppocr_det.utils import DBPostProcess
|
||||
from rapidocr.inference_engine.base import FileInfo, InferSession
|
||||
from rapidocr.utils.download_file import DownloadFile, DownloadFileInput
|
||||
from rapidocr.utils.typings import EngineType, LangDet, OCRVersion, TaskType
|
||||
from rapidocr.utils.typings import ModelType as RapidModelType
|
||||
|
||||
from immich_ml.config import log
|
||||
from immich_ml.models.base import InferenceModel
|
||||
from immich_ml.schemas import ModelFormat, ModelSession, ModelTask, ModelType
|
||||
from immich_ml.sessions.ort import OrtSession
|
||||
|
||||
from .schemas import TextDetectionOutput
|
||||
|
||||
|
||||
class TextDetector(InferenceModel):
|
||||
depends = []
|
||||
identity = (ModelType.DETECTION, ModelTask.OCR)
|
||||
|
||||
def __init__(self, model_name: str, **model_kwargs: Any) -> None:
    """Initialise the detector using the part of `model_name` after the last "__".

    Sets the default resolution limit, the normalisation constants, the shared
    empty result and the `DBPostProcess` instance (its box threshold comes from
    `minScore` in `model_kwargs`, defaulting to 0.5).
    """
    base_name = model_name.split("__")[-1]
    super().__init__(base_name, **model_kwargs, model_format=ModelFormat.ONNX)
    self.max_resolution = 736
    self.mean = np.array([0.5, 0.5, 0.5], dtype=np.float32)
    # stored as a reciprocal so normalisation is a multiplication
    std = np.array([0.5, 0.5, 0.5], dtype=np.float32)
    self.std_inv = np.float32(1.0) / (std * 255.0)
    no_boxes = np.empty(0, dtype=np.float32)
    no_scores = np.empty(0, dtype=np.float32)
    self._empty: TextDetectionOutput = {"boxes": no_boxes, "scores": no_scores}
    box_threshold = model_kwargs.get("minScore", 0.5)
    self.postprocess = DBPostProcess(
        thresh=0.3,
        box_thresh=box_threshold,
        max_candidates=1000,
        unclip_ratio=1.6,
        use_dilation=True,
        score_mode="fast",
    )
|
||||
|
||||
def _download(self) -> None:
    """Resolve the PP-OCRv5 detection model URL through RapidOCR and download it to `model_path`."""
    # the variant is chosen by whether "mobile" appears in the model name
    if "mobile" in self.model_name:
        variant = RapidModelType.MOBILE
    else:
        variant = RapidModelType.SERVER
    file_info = FileInfo(
        engine_type=EngineType.ONNXRUNTIME,
        ocr_version=OCRVersion.PPOCRV5,
        task_type=TaskType.DET,
        lang_type=LangDet.CH,
        model_type=variant,
    )
    model_info = InferSession.get_model_url(file_info)
    DownloadFile.run(
        DownloadFileInput(
            file_url=model_info["model_dir"],
            sha256=model_info["SHA256"],
            save_path=self.model_path,
            logger=log,
        )
    )
|
||||
|
||||
def _load(self) -> ModelSession:
    """Create an `OrtSession` for the model at `model_path`."""
    # TODO: support other runtime sessions
    session = OrtSession(self.model_path)
    return session
|
||||
|
||||
# partly adapted from RapidOCR
|
||||
def _predict(self, inputs: Image.Image) -> TextDetectionOutput:
    """Detect text boxes in `inputs`.

    Returns the shared empty result when either side of the image is under
    32 pixels or when post-processing yields no boxes.
    """
    width, height = inputs.size
    if min(width, height) < 32:
        return self._empty
    feed = {"x": self._transform(inputs)}
    raw = self.session.run(None, feed)[0]
    # post-processing receives the original size as (height, width)
    boxes, scores = self.postprocess(raw, (height, width))
    if len(boxes) == 0:
        return self._empty
    ordered = self.sorted_boxes(boxes)
    return {"boxes": ordered, "scores": np.array(scores, dtype=np.float32)}
|
||||
|
||||
# adapted from RapidOCR
|
||||
def _transform(self, img: Image.Image) -> NDArray[np.float32]:
    """Resize, convert to BGR float32, normalise and return the image with a leading axis.

    The shorter side is scaled down to at most `max_resolution` (never scaled
    up), then both sides are rounded to a multiple of 32.
    """
    shorter_side = min(img.height, img.width)
    ratio = min(float(self.max_resolution) / shorter_side, 1.0)

    scaled_h = int(img.height * ratio)
    scaled_w = int(img.width * ratio)

    # snap both sides to the nearest multiple of 32
    target_h = int(round(scaled_h / 32) * 32)
    target_w = int(round(scaled_w / 32) * 32)
    resized = img.resize((int(target_w), int(target_h)), resample=Image.Resampling.LANCZOS)

    pixels: NDArray[np.float32] = cv2.cvtColor(np.array(resized, dtype=np.float32), cv2.COLOR_RGB2BGR)  # type: ignore
    # NOTE(review): `mean` is subtracted from 0-255 pixel values before scaling by
    # `std_inv`. With the constructor's values (0.5 and 1 / (0.5 * 255)) this gives
    # roughly [0, 2] rather than a range centred on 0 - confirm this is intended.
    pixels -= self.mean
    pixels *= self.std_inv
    # move the last axis to the front, then add a leading axis of length 1
    return pixels.transpose(2, 0, 1)[np.newaxis, ...]
|
||||
|
||||
def sorted_boxes(self, dt_boxes: NDArray[np.float32]) -> NDArray[np.float32]:
    """Order boxes by line, then by x within each line.

    Boxes are sorted by the y of their first point; a new line starts wherever
    two consecutive y values differ by 10 or more. Within a line, boxes are
    ordered by the x of their first point.

    Args:
        dt_boxes: array indexed as `[box, point, (x, y)]`.

    Returns:
        The same boxes, reordered. An empty input is returned unchanged.
    """
    if len(dt_boxes) == 0:
        return dt_boxes

    # Sort by y, then identify lines, then sort by (line, x)
    y_order = np.argsort(dt_boxes[:, 0, 1], kind="stable")
    sorted_y = dt_boxes[y_order, 0, 1]

    # line_ids[i] is the line of the i-th box *in y-sorted order*
    line_ids = np.empty(len(dt_boxes), dtype=np.int32)
    line_ids[0] = 0
    np.cumsum(np.abs(np.diff(sorted_y)) >= 10, out=line_ids[1:])

    # Composite key: line id shifted by a large factor, plus x for ordering
    # within the line. `line_ids` is already aligned with `y_order`, so it must
    # not be indexed by `y_order` again; only the x values need that lookup.
    sort_key = line_ids * 1e6 + dt_boxes[y_order, 0, 0]
    final_order = np.argsort(sort_key, kind="stable")
    sorted_boxes: NDArray[np.float32] = dt_boxes[y_order[final_order]]
    return sorted_boxes
|
||||
|
||||
def configure(self, **kwargs: Any) -> None:
    """Apply `maxResolution`, `minScore` and `scoreMode` when present and not None."""
    max_resolution = kwargs.get("maxResolution")
    if max_resolution is not None:
        self.max_resolution = max_resolution

    min_score = kwargs.get("minScore")
    if min_score is not None:
        self.postprocess.box_thresh = min_score

    score_mode = kwargs.get("scoreMode")
    if score_mode is not None:
        self.postprocess.score_mode = score_mode
|
||||
153
machine-learning/immich_ml/models/ocr/recognition.py
Normal file
153
machine-learning/immich_ml/models/ocr/recognition.py
Normal file
|
|
@ -0,0 +1,153 @@
|
|||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
from PIL import Image
|
||||
from rapidocr.ch_ppocr_rec import TextRecInput
|
||||
from rapidocr.ch_ppocr_rec import TextRecognizer as RapidTextRecognizer
|
||||
from rapidocr.inference_engine.base import FileInfo, InferSession
|
||||
from rapidocr.utils.download_file import DownloadFile, DownloadFileInput
|
||||
from rapidocr.utils.typings import EngineType, LangRec, OCRVersion, TaskType
|
||||
from rapidocr.utils.typings import ModelType as RapidModelType
|
||||
from rapidocr.utils.vis_res import VisRes
|
||||
|
||||
from immich_ml.config import log, settings
|
||||
from immich_ml.models.base import InferenceModel
|
||||
from immich_ml.models.transforms import pil_to_cv2
|
||||
from immich_ml.schemas import ModelFormat, ModelSession, ModelTask, ModelType
|
||||
from immich_ml.sessions.ort import OrtSession
|
||||
|
||||
from .schemas import OcrOptions, TextDetectionOutput, TextRecognitionOutput
|
||||
|
||||
|
||||
class TextRecognizer(InferenceModel):
|
||||
depends = [(ModelType.DETECTION, ModelTask.OCR)]
|
||||
identity = (ModelType.RECOGNITION, ModelTask.OCR)
|
||||
|
||||
def __init__(self, model_name: str, **model_kwargs: Any) -> None:
    """Initialise the recognizer.

    The recognition language is the part of `model_name` before the first "__"
    (looked up in `LangRec`), or `LangRec.CH` when the name has no "__".
    """
    if "__" in model_name:
        self.language = LangRec[model_name.split("__")[0]]
    else:
        self.language = LangRec.CH
    self.min_score = model_kwargs.get("minScore", 0.9)
    self._empty: TextRecognitionOutput = {
        "box": np.empty(0, dtype=np.float32),
        "boxScore": np.empty(0, dtype=np.float32),
        "text": [],
        "textScore": np.empty(0, dtype=np.float32),
    }

    def _skip_vis_init(self: Any, **kwargs: Any) -> None:
        return None

    # replace the visualiser's constructor with a no-op
    VisRes.__init__ = _skip_vis_init  # pyright: ignore[reportAttributeAccessIssue]
    super().__init__(model_name, **model_kwargs, model_format=ModelFormat.ONNX)
|
||||
|
||||
def _download(self) -> None:
    """Resolve the PP-OCRv5 recognition model URL for `self.language` and download it to `model_path`."""
    # the variant is chosen by whether "mobile" appears in the model name
    if "mobile" in self.model_name:
        variant = RapidModelType.MOBILE
    else:
        variant = RapidModelType.SERVER
    file_info = FileInfo(
        engine_type=EngineType.ONNXRUNTIME,
        ocr_version=OCRVersion.PPOCRV5,
        task_type=TaskType.REC,
        lang_type=self.language,
        model_type=variant,
    )
    model_info = InferSession.get_model_url(file_info)
    DownloadFile.run(
        DownloadFileInput(
            file_url=model_info["model_dir"],
            sha256=model_info["SHA256"],
            save_path=self.model_path,
            logger=log,
        )
    )
|
||||
|
||||
def _load(self) -> ModelSession:
    """Create the ORT session and wrap it in RapidOCR's text recognizer."""
    # TODO: support other runtimes
    session = OrtSession(self.model_path)
    # batch size comes from settings when configured, otherwise 6
    if settings.max_batch_size is not None:
        batch_num = settings.max_batch_size.text_recognition
    else:
        batch_num = 6
    options = OcrOptions(
        session=session.session,
        rec_batch_num=batch_num,
        rec_img_shape=(3, 48, 320),
        lang_type=self.language,
    )
    self.model = RapidTextRecognizer(options)
    return session
|
||||
|
||||
def _predict(self, img: Image.Image, texts: TextDetectionOutput) -> TextRecognitionOutput:
    """Recognise the text in each detected box and keep results scoring above `min_score`.

    Returns the shared empty result when there are no boxes or the recognizer
    yields no texts. Box coordinates are divided by the image width and height.
    """
    boxes = texts["boxes"]
    box_scores = texts["scores"]
    if boxes.shape[0] == 0:
        return self._empty
    crops = self.get_crop_img_list(img, boxes)
    rec = self.model(TextRecInput(img=crops))
    if rec.txts is None:
        return self._empty

    # NOTE(review): these divisions modify the array held by `texts` in place.
    boxes[:, :, 0] /= img.width
    boxes[:, :, 1] /= img.height

    text_scores = np.array(rec.scores)
    keep = text_scores > self.min_score
    keep_flags = keep.tolist()
    # each kept box is flattened to 8 values, then all kept boxes are concatenated
    flat_boxes = boxes.reshape(-1, 8)[keep].reshape(-1)
    kept_texts = [txt for idx, txt in enumerate(rec.txts) if keep_flags[idx]]
    return {
        "box": flat_boxes,
        "text": kept_texts,
        "boxScore": box_scores[keep],
        "textScore": text_scores[keep],
    }
|
||||
|
||||
def get_crop_img_list(self, img: Image.Image, boxes: NDArray[np.float32]) -> list[NDArray[np.uint8]]:
    """Cut one perspective-corrected crop out of `img` for every box.

    Each crop's width and height are the longer of the two opposite edge
    lengths of the box, truncated to int. Crops whose height is at least 1.5
    times their width are rotated by 90 degrees.
    """
    # edge lengths between the box's points (0-1, 3-2, 3-0, 2-1)
    edge_01 = np.linalg.norm(boxes[:, 1] - boxes[:, 0], axis=1)
    edge_32 = np.linalg.norm(boxes[:, 2] - boxes[:, 3], axis=1)
    edge_30 = np.linalg.norm(boxes[:, 0] - boxes[:, 3], axis=1)
    edge_21 = np.linalg.norm(boxes[:, 1] - boxes[:, 2], axis=1)
    crop_w = np.maximum(edge_01, edge_32).astype(np.int32)
    crop_h = np.maximum(edge_30, edge_21).astype(np.int32)

    # target rectangle per box: (0, 0), (w, 0), (w, h), (0, h)
    targets = np.zeros((crop_w.shape[0], 4, 2), dtype=np.float32)
    targets[:, 1:3, 0] = crop_w[:, None]
    targets[:, 2:4, 1] = crop_h[:, None]

    sizes = np.stack([crop_w, crop_h], axis=1)
    all_coeffs = self._get_perspective_transform(targets, boxes)
    crops: list[NDArray[np.uint8]] = []
    for coeffs, size in zip(all_coeffs, sizes):
        crop = img.transform(
            size=tuple(size),
            method=Image.Transform.PERSPECTIVE,
            data=tuple(coeffs),
            resample=Image.Resampling.BICUBIC,
        )

        out_w, out_h = crop.size
        if out_h * 1.0 / out_w >= 1.5:
            crop = crop.rotate(90, expand=True)
        crops.append(pil_to_cv2(crop))

    return crops
|
||||
|
||||
def _get_perspective_transform(self, src: NDArray[np.float32], dst: NDArray[np.float32]) -> NDArray[np.float32]:
|
||||
N = src.shape[0]
|
||||
x, y = src[:, :, 0], src[:, :, 1]
|
||||
u, v = dst[:, :, 0], dst[:, :, 1]
|
||||
A = np.zeros((N, 8, 9), dtype=np.float32)
|
||||
|
||||
# Fill even rows (0, 2, 4, 6): [x, y, 1, 0, 0, 0, -u*x, -u*y, -u]
|
||||
A[:, ::2, 0] = x
|
||||
A[:, ::2, 1] = y
|
||||
A[:, ::2, 2] = 1
|
||||
A[:, ::2, 6] = -u * x
|
||||
A[:, ::2, 7] = -u * y
|
||||
A[:, ::2, 8] = -u
|
||||
|
||||
# Fill odd rows (1, 3, 5, 7): [0, 0, 0, x, y, 1, -v*x, -v*y, -v]
|
||||
A[:, 1::2, 3] = x
|
||||
A[:, 1::2, 4] = y
|
||||
A[:, 1::2, 5] = 1
|
||||
A[:, 1::2, 6] = -v * x
|
||||
A[:, 1::2, 7] = -v * y
|
||||
A[:, 1::2, 8] = -v
|
||||
|
||||
# Solve using SVD for all matrices at once
|
||||
_, _, Vt = np.linalg.svd(A)
|
||||
H = Vt[:, -1, :].reshape(N, 3, 3)
|
||||
H = H / H[:, 2:3, 2:3]
|
||||
|
||||
# Extract the 8 coefficients for each transformation
|
||||
return np.column_stack(
|
||||
[H[:, 0, 0], H[:, 0, 1], H[:, 0, 2], H[:, 1, 0], H[:, 1, 1], H[:, 1, 2], H[:, 2, 0], H[:, 2, 1]]
|
||||
) # pyright: ignore[reportReturnType]
|
||||
|
||||
def configure(self, **kwargs: Any) -> None:
    """Set `min_score` from `minScore`, keeping the current value when the key is absent."""
    current = self.min_score
    self.min_score = kwargs.get("minScore", current)
|
||||
27
machine-learning/immich_ml/models/ocr/schemas.py
Normal file
27
machine-learning/immich_ml/models/ocr/schemas.py
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
from typing import Any, Iterable
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
from rapidocr.utils.typings import EngineType, LangRec
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
class TextDetectionOutput(TypedDict):
    """Result of text detection: the detected boxes and one score per box."""

    # box array, indexed by the detector as [box, point, (x, y)]
    boxes: npt.NDArray[np.float32]
    # detection scores as float32
    scores: npt.NDArray[np.float32]
|
||||
|
||||
|
||||
class TextRecognitionOutput(TypedDict):
    """Result of text recognition for the boxes that passed the score filter."""

    # box coordinates of all kept boxes, flattened into one array
    box: npt.NDArray[np.float32]
    # detection score of each kept box
    boxScore: npt.NDArray[np.float32]
    # recognised text of each kept box
    text: Iterable[str]
    # recognition score of each kept box
    textScore: npt.NDArray[np.float32]
|
||||
|
||||
|
||||
# RapidOCR expects `engine_type`, `lang_type`, and `font_path` to be attributes
|
||||
class OcrOptions(dict[str, Any]):
    """Options mapping that also exposes `engine_type`, `lang_type` and `font_path` as attributes."""

    def __init__(self, lang_type: LangRec | None = None, **options: Any) -> None:
        """Store `options` as dict entries and set the three attributes.

        `lang_type` is kept as an attribute only; it is not added to the dict.
        """
        super().__init__(**options)
        self.engine_type = EngineType.ONNXRUNTIME
        self.lang_type = lang_type
        self.font_path = None
|
||||
80
machine-learning/immich_ml/models/transforms.py
Normal file
80
machine-learning/immich_ml/models/transforms.py
Normal file
|
|
@ -0,0 +1,80 @@
|
|||
import string
|
||||
from io import BytesIO
|
||||
from typing import IO
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import orjson
|
||||
from numpy.typing import NDArray
|
||||
from PIL import Image
|
||||
|
||||
_PIL_RESAMPLING_METHODS = {resampling.name.lower(): resampling for resampling in Image.Resampling}
|
||||
_PUNCTUATION_TRANS = str.maketrans("", "", string.punctuation)
|
||||
|
||||
|
||||
def resize_pil(img: Image.Image, size: int) -> Image.Image:
    """Resize `img` with bicubic resampling so its shorter side equals `size`, keeping the aspect ratio."""
    if img.width < img.height:
        # the width is the shorter side: pin it to `size` and scale the height
        target = (size, int((img.height / img.width) * size))
    else:
        target = (int((img.width / img.height) * size), size)
    return img.resize(target, resample=Image.Resampling.BICUBIC)
|
||||
|
||||
|
||||
# https://stackoverflow.com/a/60883103
|
||||
def crop_pil(img: Image.Image, size: int) -> Image.Image:
    """Return a `size` x `size` crop taken around the centre of `img`."""
    width, height = img.size
    left = int((width / 2) - (size / 2))
    upper = int((height / 2) - (size / 2))

    # box is (left, upper, right, lower)
    return img.crop((left, upper, left + size, upper + size))
|
||||
|
||||
|
||||
def to_numpy(img: Image.Image) -> NDArray[np.float32]:
    """Convert `img` to a float32 array divided by 255, converting to RGB first when needed."""
    rgb = img if img.mode == "RGB" else img.convert("RGB")
    pixels = np.asarray(rgb, dtype=np.float32)
    return pixels / 255.0
|
||||
|
||||
|
||||
def normalize(
|
||||
img: NDArray[np.float32], mean: float | NDArray[np.float32], std: float | NDArray[np.float32]
|
||||
) -> NDArray[np.float32]:
|
||||
return (img - mean) / std
|
||||
|
||||
|
||||
def get_pil_resampling(resample: str) -> Image.Resampling:
    """Look up a PIL resampling method by name, ignoring case; raises KeyError for unknown names."""
    key = resample.lower()
    return _PIL_RESAMPLING_METHODS[key]
|
||||
|
||||
|
||||
def pil_to_cv2(image: Image.Image) -> NDArray[np.uint8]:
    """Convert a PIL image to an array with the channel order swapped from RGB to BGR."""
    rgb = np.array(image)
    return cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)  # type: ignore
|
||||
|
||||
|
||||
def decode_pil(image_bytes: bytes | bytearray | memoryview | IO[bytes] | Image.Image) -> Image.Image:
    """Decode encoded image data into a fully loaded RGB PIL image.

    Args:
        image_bytes: Encoded image as a bytes-like object (``bytes``,
            ``bytearray`` or ``memoryview``), a binary file object, or an
            already decoded ``Image.Image``.

    Returns:
        The decoded image converted to RGB. An ``Image.Image`` argument is
        returned unchanged, whatever its mode.
    """
    if isinstance(image_bytes, Image.Image):
        return image_bytes

    # ``decode_cv2`` forwards bytearray/memoryview inputs here as well; they are
    # not file objects, so every bytes-like input must be wrapped in BytesIO
    # before it reaches ``Image.open``.
    if isinstance(image_bytes, (bytes, bytearray, memoryview)):
        stream: IO[bytes] = BytesIO(image_bytes)
    else:
        stream = image_bytes

    image: Image.Image = Image.open(stream)
    # Force decoding now rather than lazily on first pixel access.
    image.load()
    if image.mode != "RGB":
        image = image.convert("RGB")
    return image
|
||||
|
||||
|
||||
def decode_cv2(image_bytes: NDArray[np.uint8] | bytes | Image.Image) -> NDArray[np.uint8]:
    """Return ``image_bytes`` as a BGR array.

    Bytes-like inputs are decoded through ``decode_pil`` and PIL images are
    converted with ``pil_to_cv2``. Any other input is returned unchanged.
    """
    if isinstance(image_bytes, (bytes, memoryview, bytearray)):
        return pil_to_cv2(decode_pil(image_bytes))  # pillow is much faster than cv2
    if isinstance(image_bytes, Image.Image):
        return pil_to_cv2(image_bytes)
    return image_bytes
|
||||
|
||||
|
||||
def clean_text(text: str, canonicalize: bool = False) -> str:
    """Collapse whitespace runs in ``text`` to single spaces and trim the ends.

    Args:
        text: Text to clean.
        canonicalize: When true, also strip ASCII punctuation and lowercase.

    Returns:
        The cleaned text.
    """
    collapsed = " ".join(text.split())
    if not canonicalize:
        return collapsed
    return collapsed.translate(_PUNCTUATION_TRANS).lower()
|
||||
|
||||
|
||||
# this allows the client to use the array as a string without deserializing only to serialize back to a string
|
||||
# TODO: use this in a less invasive way
|
||||
def serialize_np_array(arr: NDArray[np.float32]) -> str:
    """Serialize ``arr`` to a JSON string using orjson's NumPy support."""
    encoded = orjson.dumps(arr, option=orjson.OPT_SERIALIZE_NUMPY)
    return encoded.decode()
|
||||
122
machine-learning/immich_ml/schemas.py
Normal file
122
machine-learning/immich_ml/schemas.py
Normal file
|
|
@ -0,0 +1,122 @@
|
|||
from enum import Enum
|
||||
from typing import Any, Literal, Protocol, TypeGuard, TypeVar
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
class StrEnum(str, Enum):
    """String-valued enum base whose ``str()`` is the member's value."""

    # Narrow the type of ``value`` for type checkers; this creates no member.
    value: str

    def __str__(self) -> str:
        return self.value
|
||||
|
||||
|
||||
class BoundingBox(TypedDict):
    """Integer box described by two corner points ``(x1, y1)`` and ``(x2, y2)``."""

    x1: int
    y1: int
    x2: int
    y2: int
|
||||
|
||||
|
||||
class ModelTask(StrEnum):
    """Task identifiers; each member's value is its string form."""

    FACIAL_RECOGNITION = "facial-recognition"
    SEARCH = "clip"
    OCR = "ocr"
|
||||
|
||||
|
||||
class ModelType(StrEnum):
    """Model type identifiers; each member's value is its string form."""

    DETECTION = "detection"
    RECOGNITION = "recognition"
    TEXTUAL = "textual"
    VISUAL = "visual"
|
||||
|
||||
|
||||
class ModelFormat(StrEnum):
    """Model file format identifiers; each member's value is its string form."""

    ARMNN = "armnn"
    ONNX = "onnx"
    RKNN = "rknn"
|
||||
|
||||
|
||||
class ModelSource(StrEnum):
    """Model source identifiers; each member's value is its string form."""

    INSIGHTFACE = "insightface"
    MCLIP = "mclip"
    OPENCLIP = "openclip"
    PADDLE = "paddle"
|
||||
|
||||
|
||||
class ModelPrecision(StrEnum):
    """Numeric precision identifiers; values are upper-case, unlike the other enums here."""

    FP16 = "FP16"
    FP32 = "FP32"
|
||||
|
||||
|
||||
ModelIdentity = tuple[ModelType, ModelTask]
|
||||
|
||||
|
||||
class SessionNode(Protocol):
|
||||
@property
|
||||
def name(self) -> str | None: ...
|
||||
|
||||
@property
|
||||
def shape(self) -> tuple[int, ...]: ...
|
||||
|
||||
|
||||
class ModelSession(Protocol):
    """Structural type for an inference session.

    A session runs a feed of named input arrays and can list its input and
    output nodes.
    """

    def run(
        self,
        output_names: list[str] | None,
        input_feed: dict[str, npt.NDArray[np.float32]] | dict[str, npt.NDArray[np.int32]],
        run_options: Any = None,
    ) -> list[npt.NDArray[np.float32]]: ...

    def get_inputs(self) -> list[SessionNode]: ...

    def get_outputs(self) -> list[SessionNode]: ...
|
||||
|
||||
|
||||
class HasProfiling(Protocol):
    """Structural type for objects exposing a ``profiling`` dict of floats keyed by string."""

    profiling: dict[str, float]
|
||||
|
||||
|
||||
class FaceDetectionOutput(TypedDict):
    """Face detection result holding ``boxes``, ``scores`` and ``landmarks`` as float32 arrays."""

    boxes: npt.NDArray[np.float32]
    scores: npt.NDArray[np.float32]
    landmarks: npt.NDArray[np.float32]
|
||||
|
||||
|
||||
class DetectedFace(TypedDict):
    """One detected face: its bounding box, a string-encoded embedding and a score."""

    boundingBox: BoundingBox
    embedding: str
    score: float
|
||||
|
||||
|
||||
FacialRecognitionOutput = list[DetectedFace]
|
||||
|
||||
|
||||
class PipelineEntry(TypedDict):
    """A model name together with its free-form options mapping."""

    modelName: str
    options: dict[str, Any]
|
||||
|
||||
|
||||
PipelineRequest = dict[ModelTask, dict[ModelType, PipelineEntry]]
|
||||
|
||||
|
||||
class InferenceEntry(TypedDict):
    """A model identified by name, task and type, plus its options mapping."""

    name: str
    task: ModelTask
    type: ModelType
    options: dict[str, Any]
|
||||
|
||||
|
||||
InferenceEntries = tuple[list[InferenceEntry], list[InferenceEntry]]
|
||||
|
||||
|
||||
InferenceResponse = dict[ModelTask | Literal["imageHeight"] | Literal["imageWidth"], Any]
|
||||
|
||||
|
||||
def has_profiling(obj: Any) -> TypeGuard[HasProfiling]:
|
||||
return hasattr(obj, "profiling") and isinstance(obj.profiling, dict)
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
0
machine-learning/immich_ml/sessions/__init__.py
Normal file
0
machine-learning/immich_ml/sessions/__init__.py
Normal file
58
machine-learning/immich_ml/sessions/ann/__init__.py
Normal file
58
machine-learning/immich_ml/sessions/ann/__init__.py
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from immich_ml.config import log, settings
|
||||
from immich_ml.schemas import SessionNode
|
||||
|
||||
from .loader import Ann
|
||||
|
||||
|
||||
class AnnSession:
    """
    Wrapper for ANN to be drop-in replacement for ONNX session.

    Construction acquires the shared ``Ann`` instance and loads the model;
    finalization unloads the model and drops the reference to ``Ann``.
    """

    def __init__(self, model_path: Path, cache_dir: Path = settings.cache_folder) -> None:
        """Load ``model_path`` through ANN.

        Args:
            model_path: Model file to load. The network cache path is the same
                path with the ``.anncache`` suffix.
            cache_dir: Directory that holds the ``gpu-tuning.ann`` tuning file.
        """
        self.model_path = model_path
        self.cache_dir = cache_dir
        self.ann = Ann(tuning_level=settings.ann_tuning_level, tuning_file=(cache_dir / "gpu-tuning.ann").as_posix())

        log.info("Loading ANN model %s ...", model_path)
        self.model = self.ann.load(
            model_path.as_posix(),
            cached_network_path=model_path.with_suffix(".anncache").as_posix(),
            fp16=settings.ann_fp16_turbo,
        )
        log.info("Loaded ANN model with ID %d", self.model)

    def __del__(self) -> None:
        # __del__ also runs for instances whose __init__ raised, so neither
        # attribute is guaranteed to exist here.
        ann = getattr(self, "ann", None)
        if ann is None:
            return
        model = getattr(self, "model", None)
        try:
            if model is not None:
                ann.unload(model)
                log.info("Unloaded ANN model %d", model)
        finally:
            # Always give back the reference taken in __init__, even when the
            # model never loaded or unloading failed.
            ann.destroy()

    def get_inputs(self) -> list[SessionNode]:
        """Return one unnamed node per model input, carrying its shape."""
        shapes = self.ann.input_shapes[self.model]
        return [AnnNode(None, s) for s in shapes]

    def get_outputs(self) -> list[SessionNode]:
        """Return one unnamed node per model output, carrying its shape."""
        shapes = self.ann.output_shapes[self.model]
        return [AnnNode(None, s) for s in shapes]

    def run(
        self,
        output_names: list[str] | None,
        input_feed: dict[str, NDArray[np.float32]] | dict[str, NDArray[np.int32]],
        run_options: Any = None,
    ) -> list[NDArray[np.float32]]:
        """Execute the model on the values of ``input_feed``, in dict order.

        ``output_names`` and ``run_options`` are accepted for interface
        compatibility and are not used.
        """
        # Ann.execute rejects arrays that are not C-contiguous.
        inputs: list[NDArray[np.float32]] = [np.ascontiguousarray(v) for v in input_feed.values()]
        return self.ann.execute(self.model, inputs)
|
||||
|
||||
|
||||
class AnnNode(NamedTuple):
    """Input/output node description: an optional name and a shape tuple."""

    name: str | None
    shape: tuple[int, ...]
|
||||
169
machine-learning/immich_ml/sessions/ann/loader.py
Normal file
169
machine-learning/immich_ml/sessions/ann/loader.py
Normal file
|
|
@ -0,0 +1,169 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from ctypes import CDLL, Array, c_bool, c_char_p, c_int, c_ulong, c_void_p
|
||||
from os.path import exists
|
||||
from typing import Any, Protocol, TypeVar
|
||||
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from immich_ml.config import log
|
||||
|
||||
# Bind the native ANN library and declare its C signatures. ``is_available``
# records whether both shared libraries could be loaded.
try:
    CDLL("libmali.so")  # fail if libmali.so is not mounted into container
    libann = CDLL("libann.so")

    # init(log_level, tuning_level, tuning_file) -> context handle
    libann.init.argtypes = c_int, c_int, c_char_p
    libann.init.restype = c_void_p

    # load(handle, model_path, fast_math, fp16, save_cached_network, cached_network_path) -> network id
    libann.load.argtypes = c_void_p, c_char_p, c_bool, c_bool, c_bool, c_char_p
    libann.load.restype = c_int

    # execute(handle, network_id, input_pointers, output_pointers)
    libann.execute.argtypes = c_void_p, c_int, Array[c_void_p], Array[c_void_p]

    # unload(handle, network_id) / destroy(handle)
    libann.unload.argtypes = c_void_p, c_int
    libann.destroy.argtypes = (c_void_p,)

    # shape(handle, network_id, is_input, index) -> packed shape value
    libann.shape.argtypes = c_void_p, c_int, c_bool, c_int
    libann.shape.restype = c_ulong

    # tensors(handle, network_id, is_input) -> tensor count
    libann.tensors.argtypes = c_void_p, c_int, c_bool
    libann.tensors.restype = c_int

    is_available = True
except OSError as e:
    log.debug("Could not load ANN shared libraries, using ONNX: %s", e)
    is_available = False
|
||||
|
||||
T = TypeVar("T", covariant=True)
|
||||
|
||||
|
||||
class Newable(Protocol[T]):
    """Structural type for objects exposing ``new()``, which ``_Singleton`` calls on reuse."""

    def new(self) -> None: ...
|
||||
|
||||
|
||||
class _Singleton(type, Newable[T]):
    """Metaclass that keeps a single instance per class.

    The first call constructs the instance normally. Later calls ignore their
    arguments, return the stored instance and call its ``new()`` method.
    """

    # Shared by every class using this metaclass, keyed by the class itself.
    _instances: dict[_Singleton[T], Newable[T]] = {}

    def __call__(cls, *args: Any, **kwargs: Any) -> Newable[T]:
        if cls in cls._instances:
            instance: Newable[T] = cls._instances[cls]
            # new() is invoked only on reuse; the first construction does not call it here.
            instance.new()
            return instance

        instance = super().__call__(*args, **kwargs)
        cls._instances[cls] = instance
        return instance
|
||||
|
||||
|
||||
class Ann(metaclass=_Singleton):
    """Python front end for the native ``libann`` library.

    ``_Singleton`` shares one instance between all callers. ``new()`` and
    ``destroy()`` maintain a reference count around the native context handle,
    which is created on first use and destroyed when the count drops to zero.
    """

    def __init__(self, log_level: int = 3, tuning_level: int = 1, tuning_file: str | None = None) -> None:
        """Validate the settings and acquire the native context.

        Args:
            log_level: 0 (trace) to 5 (fatal).
            tuning_level: 0 (load from ``tuning_file``), 1, 2 or 3.
            tuning_file: Path of the tuning file; required when ``tuning_level`` is 0.

        Raises:
            RuntimeError: If the ANN shared libraries could not be loaded.
            ValueError: If an argument is out of range.
        """
        if not is_available:
            raise RuntimeError("libann is not available!")
        if tuning_level == 0 and tuning_file is None:
            raise ValueError("tuning_level == 0 reads existing tuning information and requires a tuning_file")
        if tuning_level < 0 or tuning_level > 3:
            raise ValueError("tuning_level must be 0 (load from tuning_file), 1, 2 or 3.")
        if log_level < 0 or log_level > 5:
            raise ValueError("log_level must be 0 (trace), 1 (debug), 2 (info), 3 (warning), 4 (error) or 5 (fatal)")
        self.log_level = log_level
        self.tuning_level = tuning_level
        self.tuning_file = tuning_file
        # Shapes of each loaded network's tensors, keyed by network id.
        self.output_shapes: dict[int, tuple[tuple[int, ...], ...]] = {}
        self.input_shapes: dict[int, tuple[tuple[int, ...], ...]] = {}
        self.ann: int | None = None
        self.new()

        if self.tuning_file is not None:
            # make sure tuning file exists (without clearing contents)
            # once filled, the tuning file reduces the cost/time of the first
            # inference after model load by 10s of seconds
            open(self.tuning_file, "a").close()

    def new(self) -> None:
        """Take one reference, creating the native context if there is none."""
        if self.ann is None:
            self.ann = libann.init(
                self.log_level,
                self.tuning_level,
                self.tuning_file.encode() if self.tuning_file is not None else None,
            )
            self.ref_count = 0

        self.ref_count += 1

    def destroy(self) -> None:
        """Drop one reference and destroy the native context once none remain."""
        self.ref_count -= 1
        if self.ref_count <= 0 and self.ann is not None:
            libann.destroy(self.ann)
            self.ann = None

    def __del__(self) -> None:
        # __init__ may have raised before ``self.ann`` was assigned.
        if getattr(self, "ann", None) is not None:
            libann.destroy(self.ann)
            self.ann = None

    def load(
        self,
        model_path: str,
        fast_math: bool = True,
        fp16: bool = False,
        cached_network_path: str | None = None,
    ) -> int:
        """Load a model and record the shapes of its input and output tensors.

        Args:
            model_path: Existing ``.armnn``, ``.tflite`` or ``.onnx`` file.
            fast_math: Forwarded to the native loader.
            fp16: Forwarded to the native loader.
            cached_network_path: Optional network cache file. If it does not
                exist yet, an empty file is created and the loader is told to
                save the cached network.

        Returns:
            The network id assigned by the native library.

        Raises:
            ValueError: If the path is unsupported or missing, or loading fails.
        """
        if not model_path.endswith((".armnn", ".tflite", ".onnx")):
            raise ValueError("model_path must be a file with extension .armnn, .tflite or .onnx")
        if not exists(model_path):
            raise ValueError("model_path must point to an existing file!")

        save_cached_network = False
        if cached_network_path is not None and not exists(cached_network_path):
            save_cached_network = True
            # create empty model cache file
            open(cached_network_path, "a").close()

        net_id: int = libann.load(
            self.ann,
            model_path.encode(),
            fast_math,
            fp16,
            save_cached_network,
            cached_network_path.encode() if cached_network_path is not None else None,
        )
        if net_id < 0:
            raise ValueError("Cannot load model!")

        self.input_shapes[net_id] = tuple(
            self.shape(net_id, input=True, index=i) for i in range(self.tensors(net_id, input=True))
        )
        self.output_shapes[net_id] = tuple(
            self.shape(net_id, input=False, index=i) for i in range(self.tensors(net_id, input=False))
        )
        return net_id

    def unload(self, network_id: int) -> None:
        """Unload a network and forget its recorded shapes.

        Raises:
            KeyError: If no output shapes are recorded for ``network_id``.
        """
        libann.unload(self.ann, network_id)
        # Drop both shape entries so no stale input shapes survive the unload.
        self.input_shapes.pop(network_id, None)
        del self.output_shapes[network_id]

    def execute(self, network_id: int, input_tensors: list[NDArray[np.float32]]) -> list[NDArray[np.float32]]:
        """Run a loaded network on ``input_tensors``.

        Args:
            network_id: Id returned by ``load``.
            input_tensors: One C-contiguous array per network input, each
                matching the recorded input shape.

        Returns:
            Newly allocated float32 arrays, one per network output.

        Raises:
            ValueError: If the inputs are not a list, or differ from the
                network in count, shape or memory layout.
        """
        if not isinstance(input_tensors, list):
            raise ValueError("input_tensors needs to be a list!")
        net_input_shapes = self.input_shapes[network_id]
        if len(input_tensors) != len(net_input_shapes):
            raise ValueError(f"input_tensors lengths {len(input_tensors)} != network inputs {len(net_input_shapes)}")
        for net_input_shape, input_tensor in zip(net_input_shapes, input_tensors):
            if net_input_shape != input_tensor.shape:
                raise ValueError(f"input_tensor shape {input_tensor.shape} != network input shape {net_input_shape}")
            if not input_tensor.flags.c_contiguous:
                raise ValueError("input_tensors must be c_contiguous numpy ndarrays")
        # Uninitialized buffers that the native call receives as output pointers.
        output_tensors: list[NDArray[np.float32]] = [
            np.ndarray(s, dtype=np.float32) for s in self.output_shapes[network_id]
        ]
        input_type = c_void_p * len(input_tensors)
        inputs = input_type(*[t.ctypes.data_as(c_void_p) for t in input_tensors])
        output_type = c_void_p * len(output_tensors)
        outputs = output_type(*[t.ctypes.data_as(c_void_p) for t in output_tensors])
        libann.execute(self.ann, network_id, inputs, outputs)
        return output_tensors

    def shape(self, network_id: int, input: bool = False, index: int = 0) -> tuple[int, ...]:
        """Return the shape of one input or output tensor of a network.

        The native call returns one integer; it is decoded here 16 bits at a
        time, lowest bits first, until the remaining value is zero.
        """
        packed = libann.shape(self.ann, network_id, input, index)
        dims = []
        while packed != 0:
            dims.append(packed & 0xFFFF)
            packed >>= 16
        return tuple(dims)

    def tensors(self, network_id: int, input: bool = False) -> int:
        """Return the number of input (``input=True``) or output tensors of a network."""
        tensors: int = libann.tensors(self.ann, network_id, input)
        return tensors
|
||||
147
machine-learning/immich_ml/sessions/ort.py
Normal file
147
machine-learning/immich_ml/sessions/ort.py
Normal file
|
|
@ -0,0 +1,147 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from immich_ml.models.constants import SUPPORTED_PROVIDERS
|
||||
from immich_ml.schemas import SessionNode
|
||||
|
||||
from ..config import log, settings
|
||||
|
||||
|
||||
class OrtSession:
    """Inference session backed by ONNX Runtime.

    Execution providers, provider options and session options that are not
    passed in are derived from the available providers and from ``settings``.
    """

    session: ort.InferenceSession

    def __init__(
        self,
        model_path: Path | str,
        providers: list[str] | None = None,
        provider_options: list[dict[str, Any]] | None = None,
        sess_options: ort.SessionOptions | None = None,
    ):
        self.model_path = Path(model_path)
        # Order matters: both option defaults read ``self.providers``.
        self.providers = self._providers_default if providers is None else providers
        self.provider_options = (
            self._provider_options_default if provider_options is None else provider_options
        )
        self.sess_options = self._sess_options_default if sess_options is None else sess_options
        self.session = ort.InferenceSession(
            self.model_path.as_posix(),
            providers=self.providers,
            provider_options=self.provider_options,
            sess_options=self.sess_options,
        )

    def get_inputs(self) -> list[SessionNode]:
        """Return the input nodes of the underlying session."""
        nodes: list[SessionNode] = self.session.get_inputs()
        return nodes

    def get_outputs(self) -> list[SessionNode]:
        """Return the output nodes of the underlying session."""
        nodes: list[SessionNode] = self.session.get_outputs()
        return nodes

    def run(
        self,
        output_names: list[str] | None,
        input_feed: dict[str, NDArray[np.float32]] | dict[str, NDArray[np.int32]],
        run_options: Any = None,
    ) -> list[NDArray[np.float32]]:
        """Forward the call unchanged to ``InferenceSession.run``."""
        results: list[NDArray[np.float32]] = self.session.run(output_names, input_feed, run_options)
        return results

    @property
    def providers(self) -> list[str]:
        return self._providers

    @providers.setter
    def providers(self, providers: list[str]) -> None:
        log.info(f"Setting execution providers to {providers}, in descending order of preference")
        self._providers = providers

    @property
    def _providers_default(self) -> list[str]:
        """Supported providers that are available, in ``SUPPORTED_PROVIDERS`` order."""
        available_providers = set(ort.get_available_providers())
        log.debug(f"Available ORT providers: {available_providers}")

        openvino = "OpenVINOExecutionProvider"
        if openvino in available_providers:
            device_ids: list[str] = ort.capi._pybind_state.get_available_openvino_device_ids()
            log.debug(f"Available OpenVINO devices: {device_ids}")

            # OpenVINO is only kept when it reports at least one GPU device.
            has_gpu = any(device_id.startswith("GPU") for device_id in device_ids)
            if not has_gpu:
                log.warning("No GPU device found in OpenVINO. Falling back to CPU.")
                available_providers.remove(openvino)

        return [provider for provider in SUPPORTED_PROVIDERS if provider in available_providers]

    @property
    def provider_options(self) -> list[dict[str, Any]]:
        return self._provider_options

    @provider_options.setter
    def provider_options(self, provider_options: list[dict[str, Any]]) -> None:
        log.debug(f"Setting execution provider options to {provider_options}")
        self._provider_options = provider_options

    @property
    def _provider_options_default(self) -> list[dict[str, Any]]:
        """One options dict per entry of ``self.providers``, in the same order."""
        provider_options = []
        for provider in self.providers:
            if provider == "CPUExecutionProvider":
                options = {"arena_extend_strategy": "kSameAsRequested"}
            elif provider in ("CUDAExecutionProvider", "ROCMExecutionProvider"):
                options = {"arena_extend_strategy": "kSameAsRequested", "device_id": settings.device_id}
            elif provider == "OpenVINOExecutionProvider":
                # The cache lives in an "openvino" directory next to the model file.
                openvino_dir = self.model_path.parent / "openvino"
                device = f"GPU.{settings.device_id}"
                options = {
                    "device_type": device,
                    "precision": settings.openvino_precision.value,
                    "cache_dir": openvino_dir.as_posix(),
                }
            elif provider == "CoreMLExecutionProvider":
                options = {
                    "ModelFormat": "MLProgram",
                    "MLComputeUnits": "ALL",
                    "SpecializationStrategy": "FastPrediction",
                    "AllowLowPrecisionAccumulationOnGPU": "1",
                    "ModelCacheDirectory": (self.model_path.parent / "coreml").as_posix(),
                }
            else:
                options = {}
            provider_options.append(options)
        return provider_options

    @property
    def sess_options(self) -> ort.SessionOptions:
        return self._sess_options

    @sess_options.setter
    def sess_options(self, sess_options: ort.SessionOptions) -> None:
        log.debug(f"Setting execution_mode to {sess_options.execution_mode.name}")
        log.debug(f"Setting inter_op_num_threads to {sess_options.inter_op_num_threads}")
        log.debug(f"Setting intra_op_num_threads to {sess_options.intra_op_num_threads}")
        self._sess_options = sess_options

    @property
    def _sess_options_default(self) -> ort.SessionOptions:
        """Session options built from ``settings`` and the selected providers."""
        sess_options = ort.SessionOptions()
        sess_options.enable_cpu_mem_arena = settings.model_arena

        cpu_only = self.providers == ["CPUExecutionProvider"]
        inter_threads = settings.model_inter_op_threads
        intra_threads = settings.model_intra_op_threads

        # avoid thread contention between models
        if inter_threads > 0:
            sess_options.inter_op_num_threads = inter_threads
        # these defaults work well for CPU, but bottleneck GPU
        elif inter_threads == 0 and cpu_only:
            sess_options.inter_op_num_threads = 1

        if intra_threads > 0:
            sess_options.intra_op_num_threads = intra_threads
        elif intra_threads == 0 and cpu_only:
            sess_options.intra_op_num_threads = 2

        # Parallel execution is enabled only with more than one inter-op thread.
        if sess_options.inter_op_num_threads > 1:
            sess_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL

        return sess_options
|
||||
76
machine-learning/immich_ml/sessions/rknn/__init__.py
Normal file
76
machine-learning/immich_ml/sessions/rknn/__init__.py
Normal file
|
|
@ -0,0 +1,76 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from immich_ml.config import log, settings
|
||||
from immich_ml.schemas import SessionNode
|
||||
|
||||
from .rknnpool import RknnPoolExecutor, is_available, soc_name
|
||||
|
||||
is_available = is_available and settings.rknn
|
||||
model_prefix = Path("rknpu") / soc_name if is_available and soc_name is not None else None
|
||||
|
||||
|
||||
def run_inference(rknn_lite: Any, input: list[NDArray[np.float32]]) -> list[NDArray[np.float32]]:
    """Run ``rknn_lite.inference`` on ``input`` using the ``"nchw"`` data format."""
    results: list[NDArray[np.float32]] = rknn_lite.inference(inputs=input, data_format="nchw")
    return results
|
||||
|
||||
|
||||
input_output_mapping: dict[str, dict[str, Any]] = {
|
||||
"detection": {
|
||||
"input": {"norm_tensor:0": (1, 3, 640, 640)},
|
||||
"output": {
|
||||
"norm_tensor:1": (12800, 1),
|
||||
"norm_tensor:2": (3200, 1),
|
||||
"norm_tensor:3": (800, 1),
|
||||
"norm_tensor:4": (12800, 4),
|
||||
"norm_tensor:5": (3200, 4),
|
||||
"norm_tensor:6": (800, 4),
|
||||
"norm_tensor:7": (12800, 10),
|
||||
"norm_tensor:8": (3200, 10),
|
||||
"norm_tensor:9": (800, 10),
|
||||
},
|
||||
},
|
||||
"recognition": {"input": {"norm_tensor:0": (1, 3, 112, 112)}, "output": {"norm_tensor:1": (1, 512)}},
|
||||
}
|
||||
|
||||
|
||||
class RknnSession:
    """Inference session that runs a model through a pool of RKNN runtimes."""

    def __init__(self, model_path: Path) -> None:
        """Create the runtime pool for ``model_path``.

        The model is treated as "detection" when that word is one of the path
        components and as "recognition" otherwise.
        """
        is_detection = "detection" in model_path.parts
        self.model_type = "detection" if is_detection else "recognition"
        self.tpe = settings.rknn_threads

        log.info(f"Loading RKNN model from {model_path} with {self.tpe} threads.")
        self.rknnpool = RknnPoolExecutor(model_path=model_path.as_posix(), tpes=self.tpe, func=run_inference)
        log.info(f"Loaded RKNN model from {model_path} with {self.tpe} threads.")

    def get_inputs(self) -> list[SessionNode]:
        """Return the input nodes listed for this model type in ``input_output_mapping``."""
        spec = input_output_mapping[self.model_type]["input"]
        return [RknnNode(name=k, shape=v) for k, v in spec.items()]

    def get_outputs(self) -> list[SessionNode]:
        """Return the output nodes listed for this model type in ``input_output_mapping``."""
        spec = input_output_mapping[self.model_type]["output"]
        return [RknnNode(name=k, shape=v) for k, v in spec.items()]

    def run(
        self,
        output_names: list[str] | None,
        input_feed: dict[str, NDArray[np.float32]] | dict[str, NDArray[np.int32]],
        run_options: Any = None,
    ) -> list[NDArray[np.float32]]:
        """Submit the values of ``input_feed`` to the pool and return the next result.

        ``output_names`` and ``run_options`` are accepted for interface
        compatibility and are not used.

        Raises:
            RuntimeError: If the pool yields no result.
        """
        input_data: list[NDArray[np.float32]] = [np.ascontiguousarray(v) for v in input_feed.values()]
        self.rknnpool.put(input_data)
        outputs = self.rknnpool.get()
        if outputs is not None:
            return outputs
        raise RuntimeError("RKNN inference failed!")
|
||||
|
||||
|
||||
class RknnNode(NamedTuple):
    """Input/output node description: an optional name and a shape tuple."""

    name: str | None
    shape: tuple[int, ...]
|
||||
|
||||
|
||||
__all__ = ["RknnSession", "RknnNode", "is_available", "soc_name", "model_prefix"]
|
||||
91
machine-learning/immich_ml/sessions/rknn/rknnpool.py
Normal file
91
machine-learning/immich_ml/sessions/rknn/rknnpool.py
Normal file
|
|
@ -0,0 +1,91 @@
|
|||
# This code is from leafqycc/rknn-multi-threaded
|
||||
# Following Apache License 2.0
|
||||
|
||||
import logging
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
from queue import Queue
|
||||
from typing import Callable
|
||||
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from immich_ml.config import log
|
||||
from immich_ml.models.constants import RKNN_COREMASK_SUPPORTED_SOCS, RKNN_SUPPORTED_SOCS
|
||||
|
||||
|
||||
def get_soc(device_tree_path: Path | str) -> str | None:
    """Return the first supported SoC name found in the device tree file.

    Args:
        device_tree_path: File whose text is searched for each entry of
            ``RKNN_SUPPORTED_SOCS``, in order.

    Returns:
        The first matching SoC name, or ``None`` when nothing matches or the
        file cannot be read. Both failure cases log a warning.
    """
    try:
        with Path(device_tree_path).open() as f:
            device_compatible_str = f.read()
            for soc in RKNN_SUPPORTED_SOCS:
                if soc in device_compatible_str:
                    return soc
            log.warning("Device is not supported for RKNN")
    except OSError as e:
        # Pass the path as a logging argument instead of interpolating it into
        # the format string, so a '%' in the path cannot break formatting.
        log.warning("Could not read %s. Reason: %s", device_tree_path, e)
    return None
|
||||
|
||||
|
||||
# Defaults used when the optional rknnlite package cannot be imported.
soc_name = None
is_available = False
try:
    from rknnlite.api import RKNNLite

    # RKNN is usable only when the device tree names a supported SoC.
    soc_name = get_soc("/proc/device-tree/compatible")
    is_available = soc_name is not None
except ImportError:
    log.debug("RKNN is not available")
|
||||
|
||||
|
||||
def init_rknn(model_path: str) -> "RKNNLite":
    """Create an ``RKNNLite`` instance, load ``model_path`` and initialize its runtime.

    Raises:
        RuntimeError: If RKNN is unavailable, or if loading the model or
            initializing the runtime returns a non-zero status.
    """
    if not is_available:
        raise RuntimeError("rknn is not available!")

    rknn_lite = RKNNLite()
    rknn_lite.rknn_log.logger.setLevel(logging.ERROR)

    if rknn_lite.load_rknn(model_path) != 0:
        raise RuntimeError("Failed to load RKNN model")

    if soc_name in RKNN_COREMASK_SUPPORTED_SOCS:
        status = rknn_lite.init_runtime(core_mask=RKNNLite.NPU_CORE_AUTO)
    else:
        status = rknn_lite.init_runtime()  # Please do not set this parameter on other platforms.

    if status != 0:
        raise RuntimeError("Failed to initialize RKNN runtime environment")

    return rknn_lite
|
||||
|
||||
|
||||
class RknnPoolExecutor:
|
||||
def __init__(
|
||||
self,
|
||||
model_path: str,
|
||||
tpes: int,
|
||||
func: Callable[["RKNNLite", list[NDArray[np.float32]]], list[NDArray[np.float32]]],
|
||||
) -> None:
|
||||
self.tpes = tpes
|
||||
self.queue: Queue[Future[list[NDArray[np.float32]]]] = Queue()
|
||||
self.rknn_pool = [init_rknn(model_path) for _ in range(tpes)]
|
||||
self.pool = ThreadPoolExecutor(max_workers=tpes)
|
||||
self.func = func
|
||||
self.num = 0
|
||||
|
||||
def put(self, inputs: list[NDArray[np.float32]]) -> None:
|
||||
self.queue.put(self.pool.submit(self.func, self.rknn_pool[self.num % self.tpes], inputs))
|
||||
self.num += 1
|
||||
|
||||
def get(self) -> list[NDArray[np.float32]] | None:
|
||||
if self.queue.empty():
|
||||
return None
|
||||
fut = self.queue.get()
|
||||
return fut.result()
|
||||
|
||||
def release(self) -> None:
|
||||
self.pool.shutdown()
|
||||
for rknn_lite in self.rknn_pool:
|
||||
rknn_lite.release()
|
||||
|
||||
def __del__(self) -> None:
|
||||
self.release()
|
||||
Loading…
Add table
Add a link
Reference in a new issue