When more than one request came in at a time, the server would crash while trying to write the output file. Each worker had a single detector with one DN network, so if more requests came in at once than there were workers, one network would be used by several requests at the same time—and it isn't thread-safe. I've added a queue of nets to pull from, so each request has exclusive access to a specific network.
126 lines
3.8 KiB
Python
126 lines
3.8 KiB
Python
#!/usr/bin/env python3
|
|
|
|
from detector import Detector
|
|
from fastapi import FastAPI, HTTPException, Request, UploadFile, File
|
|
from fastapi.responses import FileResponse
|
|
from os import listdir, remove
|
|
from os.path import getatime, splitext, basename, isfile
|
|
from pathlib import Path
|
|
from threading import Thread
|
|
from time import sleep
|
|
import hashlib
|
|
import os
|
|
import re
|
|
import tempfile
|
|
|
|
|
|
def get_envvar(name):
    """Look up *name* in the process environment; None when it is not set."""
    return os.getenv(name)
|
|
|
|
def get_envvar_or_fail(name):
    """Return environment variable *name*, insisting that it is set.

    A variable that is unset and one set to the empty string are both
    treated as missing.

    Raises:
        EnvironmentError: if the variable is missing or empty.
    """
    value = get_envvar(name)
    if not value:
        raise EnvironmentError('Missing required environment variable: ' + name)
    return value
|
|
|
|
def to_int(input_int):
    """Convert *input_int* (typically an env-var string) to an int.

    Falsy input (None, '', 0) yields None. Anything else is passed to int(),
    which raises ValueError for non-numeric text.
    """
    return int(input_int) if input_int else None
|
|
|
|
# --- Configuration -----------------------------------------------------------
# Required: YOLOv3 model files and the class-label list. get_envvar_or_fail
# raises EnvironmentError if any of these is unset or empty.
yolo_config = get_envvar_or_fail('OBJECTIFIER_YOLOV3_CONFIG')
yolo_weights = get_envvar_or_fail('OBJECTIFIER_YOLOV3_WEIGHTS')
yolo_labels_file = get_envvar_or_fail('OBJECTIFIER_YOLOV3_LABELS')

# Optional integer settings. Where a fallback follows `or`, it is used when the
# variable is unset, empty, or parses to 0. The two cleanup settings have no
# fallback and stay None when unset.
buffer_size = to_int(get_envvar('OBJECTIFIER_BUFFER_SIZE')) or 524288  # upload read chunk, in bytes
max_file_age = to_int(get_envvar('OBJECTIFIER_CLEANUP_MAX_AGE'))
file_cleanup_delay = to_int(get_envvar('OBJECTIFIER_CLEANUP_DELAY'))
detection_timeout = to_int(get_envvar('OBJECTIFIER_TIMEOUT')) or 5
pool_size = to_int(get_envvar('OBJECTIFIER_POOL_SIZE')) or 10

# Class labels: one per line of the labels file, surrounding whitespace removed.
with open(yolo_labels_file, "r") as label_stream:
    yolo_labels = [raw_line.strip() for raw_line in label_stream]

# Upload scratch directory and output directory (both required).
incoming_dir = Path(get_envvar_or_fail('CACHE_DIRECTORY'))
outgoing_dir = Path(get_envvar_or_fail('STATE_DIRECTORY'))
|
|
|
|
# Shared detector for all requests. pool_size presumably sets how many
# networks it keeps so concurrent requests don't share one — see Detector.
detector = Detector(
    cfg=yolo_config,
    weights=yolo_weights,
    classes=yolo_labels,
    tempdir=outgoing_dir,
    pool_size=pool_size,
)

app = FastAPI()

# NOTE(review): not referenced anywhere else in this file.
analyzed_images = {}

# Extensions (without the dot) handed to the cleanup thread as the file types
# it may delete from outgoing_dir.
img_formats = ["png", "jpg", "gif", "bmp"]
|
|
|
|
def cleanup_old_files(path, extensions, max_age):
    """Delete stale image files from a directory (non-recursive).

    A regular file in *path* is removed when its extension is listed in
    *extensions* and more than *max_age* seconds have passed since it was last
    accessed.

    Args:
        path: directory to scan.
        extensions: extensions to consider, with or without the leading dot
            (e.g. ["png", "jpg"]); matched case-insensitively.
        max_age: maximum allowed time since last access, in seconds.

    Files that disappear during the scan are skipped. This can happen when
    another worker's cleanup thread removes them first.
    """
    import time

    wanted = {ext.lower().lstrip('.') for ext in extensions}
    now = time.time()
    for filename in listdir(path):
        # splitext returns (root, ".ext"); compare only the bare extension.
        if splitext(filename)[1].lstrip('.').lower() not in wanted:
            continue
        # listdir yields bare names; resolve them against the scanned directory.
        full_path = os.path.join(path, filename)
        if not isfile(full_path):
            continue
        try:
            if now - getatime(full_path) > max_age:
                print("removing old output file: " + filename)
                remove(full_path)
        except FileNotFoundError:
            # Already gone (removed concurrently); nothing left to do.
            continue
|
|
|
|
def run_cleanup_thread(path, extensions, age, delay):
    """Call cleanup_old_files(path, extensions, age) forever, sleeping *delay* seconds between passes.

    Meant as a daemon-thread target. An OSError raised by a single pass (for
    example an unreadable directory or a file that cannot be deleted) is
    reported and the loop continues. Without this, one failure would
    terminate the thread and stop cleanup for the life of the process.
    """
    while True:
        try:
            cleanup_old_files(path, extensions, age)
        except OSError as err:
            print("output file cleanup failed: " + str(err))
        sleep(delay)
|
|
|
|
# Background removal of old output files. Both settings are optional (they are
# None when unset). With either one missing the loop cannot work: sleep(None)
# raises TypeError and the thread would die after its first pass. So only
# start the thread when both are configured.
if max_file_age is not None and file_cleanup_delay is not None:
    cleanup_thread = Thread(
        target=run_cleanup_thread,
        args=(outgoing_dir, img_formats, max_file_age, file_cleanup_delay),
        daemon=True,  # don't keep the process alive just for cleanup
    )
    cleanup_thread.start()
else:
    cleanup_thread = None
    print("output file cleanup disabled: set OBJECTIFIER_CLEANUP_MAX_AGE "
          "and OBJECTIFIER_CLEANUP_DELAY to enable it")
|
|
|
|
def detection_to_dict(d):
    """Convert one detection into a JSON-friendly dict.

    Reads d.label, d.confidence and the first four items of d.box. The box
    items are emitted as x, y, width and height, in that order.
    """
    label, confidence, box = d.label, d.confidence, d.box
    bounds = {
        "x": box[0],
        "y": box[1],
        "width": box[2],
        "height": box[3],
    }
    return {"label": label, "confidence": confidence, "box": bounds}
|
|
|
|
def result_to_dict(res, base_url):
    """Build the JSON response body for one detection result.

    The "output" entry is *base_url* followed by the basename of res.outfile.
    This is the link the client uses to fetch the output image.
    """
    detections = res.detections
    labels = [det.label for det in detections]
    serialized = [detection_to_dict(det) for det in detections]
    output_url = base_url + basename(res.outfile)
    return {"labels": labels, "detections": serialized, "output": output_url}
|
|
|
|
def _save_upload(upload, dest_dir):
    """Stream *upload* into a new, uniquely named file inside *dest_dir*.

    The client-supplied filename is used only for its extension, never as a
    path. Uploads therefore cannot escape *dest_dir*, and concurrent uploads
    with the same name cannot overwrite each other.

    Returns:
        (path, hex_digest): the Path of the written file and the SHA-256 hex
        digest of its contents.
    """
    suffix = splitext(basename(upload.filename or ""))[1]
    if not suffix[1:].isalnum():
        # Keep only a plain extension such as ".jpg"; drop anything odd.
        suffix = ""
    fd, tmp_name = tempfile.mkstemp(suffix=suffix, dir=str(dest_dir))
    path = Path(tmp_name)
    file_hash = hashlib.sha256()
    try:
        with os.fdopen(fd, "wb") as f:
            chunk = upload.file.read(buffer_size)
            while chunk:
                file_hash.update(chunk)
                print("writing " + str(len(chunk)) + " bytes")
                f.write(chunk)
                chunk = upload.file.read(buffer_size)
    except Exception:
        # Don't leave a partial upload behind.
        remove(path)
        raise
    return path, file_hash.hexdigest()


@app.post("/images")
def analyze_image(request: Request, image: UploadFile):
    """Run object detection on an uploaded image and return the result as JSON.

    The upload exists under incoming_dir only for the duration of the request.
    The detector is given "<sha256 of upload>.png" in outgoing_dir as its
    output path. The response links to the output through the
    /analyzed_images/ route.
    """
    # Derive the /analyzed_images/ base from the request URL. Strip any query
    # string first so the trailing "/images" anchor in the regex still matches.
    request_url = str(request.url).split("?", 1)[0]
    base_url = re.sub(r'\/images\/?$', '/analyzed_images/', request_url)
    infile, digest = _save_upload(image, incoming_dir)
    try:
        result = detector.detect_objects(
            infile,
            str(outgoing_dir / (digest + ".png")),
            detection_timeout)
    finally:
        # Always discard the upload, even if detection fails or times out.
        remove(infile)
    return result_to_dict(result, base_url)
|
|
|
|
@app.get("/analyzed_images/{image_name}", response_class=FileResponse)
def get_analyzed_image(image_name: str):
    """Serve a previously produced output image from outgoing_dir.

    Returns the file's path as a string. With response_class=FileResponse,
    FastAPI sends that file as the response body.

    Raises:
        HTTPException: 404 if no regular file named *image_name* exists there.
    """
    filename = outgoing_dir / image_name
    if not isfile(filename):
        # Echo only the requested name; don't disclose the server-side directory.
        raise HTTPException(status_code=404, detail="file not found: " + image_name)
    return str(filename)
|