objectifier/src/objectifier.py
niten ed962240b3 Use a pool of detectors.
When more than one request came in at a time, the server would crash
while trying to write the output file. Each worker had a single
detector with one DNN network, so if more than <WORKER#> requests came
in at once, a network would be used by multiple requests concurrently--and
the network isn't threadsafe.

I've added a queue of nets to pull from, so each request has exclusive
access to a specific network.
2023-03-15 12:31:01 -07:00
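
A rough, hypothetical sketch of the "queue of nets" idea described in the commit message (NetPool and build_net are illustrative names, not the project's actual Detector API):

# Hypothetical sketch only; the real pooling lives inside the Detector class.
import queue

class NetPool:
    def __init__(self, build_net, pool_size):
        # Pre-build pool_size independent networks; queue.Queue is threadsafe.
        self._nets = queue.Queue()
        for _ in range(pool_size):
            self._nets.put(build_net())

    def run(self, fn, timeout=None):
        # Check a network out, run the detection with exclusive access,
        # and always return the network to the pool.
        net = self._nets.get(timeout=timeout)
        try:
            return fn(net)
        finally:
            self._nets.put(net)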


#!/usr/bin/env python3
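"""FastAPI service that runs YOLOv3 object detection on uploaded images.

POST /images accepts an image upload, runs detection, and returns the labels,
bounding boxes, and a URL for the annotated output image, which is served by
GET /analyzed_images/{image_name}.
"""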
from detector import Detector
from fastapi import FastAPI, HTTPException, Request, UploadFile, File
from fastapi.responses import FileResponse
from os import listdir, remove
from os.path import getatime, splitext, basename, isfile
from pathlib import Path
from threading import Thread
from time import sleep, time
import hashlib
import os
import re
import tempfile


def get_envvar(name):
    return os.environ.get(name)


def get_envvar_or_fail(name):
    result = get_envvar(name)
    if result:
        return result
    else:
        raise EnvironmentError('Missing required environment variable: ' + name)


def to_int(input_int):
    # Convert an environment variable string to int, passing None through.
    if input_int:
        return int(input_int)
    else:
        return None

# Required and optional configuration from the environment.
yolo_config = get_envvar_or_fail('OBJECTIFIER_YOLOV3_CONFIG')
yolo_weights = get_envvar_or_fail('OBJECTIFIER_YOLOV3_WEIGHTS')
yolo_labels_file = get_envvar_or_fail('OBJECTIFIER_YOLOV3_LABELS')
buffer_size = to_int(get_envvar('OBJECTIFIER_BUFFER_SIZE')) or 524288
max_file_age = to_int(get_envvar('OBJECTIFIER_CLEANUP_MAX_AGE'))
file_cleanup_delay = to_int(get_envvar('OBJECTIFIER_CLEANUP_DELAY'))
detection_timeout = to_int(get_envvar('OBJECTIFIER_TIMEOUT')) or 5
pool_size = to_int(get_envvar('OBJECTIFIER_POOL_SIZE')) or 10

# One class label per line in the labels file.
yolo_labels = []
with open(yolo_labels_file, "r") as f:
    yolo_labels = [line.strip() for line in f.readlines()]

incoming_dir = Path(get_envvar_or_fail('CACHE_DIRECTORY'))
outgoing_dir = Path(get_envvar_or_fail('STATE_DIRECTORY'))

# The detector keeps a pool of networks so concurrent requests never share one.
detector = Detector(
    weights=yolo_weights,
    cfg=yolo_config,
    classes=yolo_labels,
    tempdir=outgoing_dir,
    pool_size=pool_size)

app = FastAPI()
analyzed_images = {}
img_formats = ["png", "jpg", "gif", "bmp"]


def cleanup_old_files(path, extensions, max_age):
    # Remove output files older than max_age seconds (by last access time).
    for filename in listdir(path):
        filepath = Path(path) / filename
        extension = splitext(filename)[1].lstrip(".")
        if extension in extensions and (time() - getatime(filepath)) > max_age:
            print("removing old output file: " + str(filepath))
            remove(filepath)


def run_cleanup_thread(path, extensions, age, delay):
    while True:
        cleanup_old_files(path, extensions, age)
        sleep(delay)


# Only start the cleanup loop when both the age limit and the delay are configured.
if max_file_age and file_cleanup_delay:
    cleanup_thread = Thread(
        target=run_cleanup_thread,
        args=(outgoing_dir, img_formats, max_file_age, file_cleanup_delay))
    cleanup_thread.daemon = True
    cleanup_thread.start()


def detection_to_dict(d):
    return {
        "label": d.label,
        "confidence": d.confidence,
        "box": {
            "x": d.box[0],
            "y": d.box[1],
            "width": d.box[2],
            "height": d.box[3],
        },
    }


def result_to_dict(res, base_url):
    return {
        "labels": [detection.label for detection in res.detections],
        "detections": [detection_to_dict(detection) for detection in res.detections],
        "output": base_url + basename(res.outfile),
    }
@app.post("/images")
def analyze_image(request: Request, image: UploadFile):
base_url = re.sub(r'\/images\/?$', '/analyzed_images/', str(request.url))
infile = incoming_dir / image.filename
file_hash = hashlib.sha256()
with open(infile, "wb") as f:
chunk = image.file.read(buffer_size)
while chunk:
file_hash.update(chunk)
print ("writing " + str(buffer_size) + " bytes")
f.write(chunk)
chunk = image.file.read(buffer_size)
result = detector.detect_objects(
infile,
str(outgoing_dir / (file_hash.hexdigest() + ".png")),
detection_timeout)
remove(infile)
return result_to_dict(result, base_url)
@app.get("/analyzed_images/{image_name}", response_class=FileResponse)
def get_analyzed_image(image_name: str):
filename = outgoing_dir / image_name
if isfile(filename):
return str(filename)
else:
raise HTTPException(status_code=404, detail="file not found: " + str(filename))
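
# Example usage (assuming the app is served by uvicorn on its default port):
#   uvicorn objectifier:app
#   curl -F "image=@photo.jpg" http://localhost:8000/images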