#!/usr/bin/env python3
"""HTTP service that runs object detection on uploaded images.

Configuration is read from environment variables at import time:

* ``OBJECTIFIER_YOLOV3_CONFIG``, ``OBJECTIFIER_YOLOV3_WEIGHTS`` and
  ``OBJECTIFIER_YOLOV3_LABELS`` (required): passed to ``Detector``; the
  labels file is read as one label per line.
* ``CACHE_DIRECTORY`` (required): where uploads are staged while analysed.
* ``STATE_DIRECTORY`` (required): where output images are written and served
  from; also passed to ``Detector`` as its ``tempdir``.
* ``OBJECTIFIER_BUFFER_SIZE`` (default 524288): upload read chunk size.
* ``OBJECTIFIER_CLEANUP_MAX_AGE`` (optional): output images older than this
  many seconds are deleted; when unset, no files are deleted.
* ``OBJECTIFIER_CLEANUP_DELAY`` (default 60): seconds between cleanup passes.
* ``OBJECTIFIER_TIMEOUT`` (default 5): timeout passed to
  ``Detector.detect_objects`` (units defined by ``Detector``).
* ``OBJECTIFIER_POOL_SIZE`` (default 10): ``pool_size`` passed to ``Detector``.
"""
from os import listdir, remove
from os.path import getctime, splitext, basename, isfile, join
from pathlib import Path
from threading import Thread
from time import sleep
import hashlib
import os
import re
import tempfile
import time

from fastapi import FastAPI, HTTPException, Request, UploadFile, File
from fastapi.responses import FileResponse

from detector import Detector


def get_envvar(name):
    """Return environment variable *name*, or ``None`` if it is not set."""
    return os.environ.get(name)


def get_envvar_or_fail(name):
    """Return environment variable *name*.

    Raises:
        EnvironmentError: if the variable is unset or set to an empty string.
    """
    result = get_envvar(name)
    if result:
        return result
    raise EnvironmentError('Missing required environment variable: ' + name)


def to_int(input_int):
    """Convert *input_int* to ``int``; ``None`` or ``""`` yields ``None``."""
    if input_int:
        return int(input_int)
    return None


yolo_config = get_envvar_or_fail('OBJECTIFIER_YOLOV3_CONFIG')
yolo_weights = get_envvar_or_fail('OBJECTIFIER_YOLOV3_WEIGHTS')
yolo_labels_file = get_envvar_or_fail('OBJECTIFIER_YOLOV3_LABELS')
buffer_size = to_int(get_envvar('OBJECTIFIER_BUFFER_SIZE')) or 524288
# No default: when unset, cleanup_old_files deletes nothing.
max_file_age = to_int(get_envvar('OBJECTIFIER_CLEANUP_MAX_AGE'))
file_cleanup_delay = to_int(get_envvar('OBJECTIFIER_CLEANUP_DELAY')) or 60
detection_timeout = to_int(get_envvar('OBJECTIFIER_TIMEOUT')) or 5
pool_size = to_int(get_envvar('OBJECTIFIER_POOL_SIZE')) or 10

# One label per line, surrounding whitespace stripped.
with open(yolo_labels_file, "r") as f:
    yolo_labels = [line.strip() for line in f]

incoming_dir = Path(get_envvar_or_fail('CACHE_DIRECTORY'))
outgoing_dir = Path(get_envvar_or_fail('STATE_DIRECTORY'))

detector = Detector(
    weights=yolo_weights,
    cfg=yolo_config,
    classes=yolo_labels,
    tempdir=outgoing_dir,
    pool_size=pool_size)

app = FastAPI()

# NOTE(review): not referenced anywhere in this file; kept in case other code
# imports it.
analyzed_images = {}

# Extensions (without the leading dot) that the cleanup thread may delete.
img_formats = ["png", "jpg", "gif", "bmp"]


def cleanup_old_files(path, extensions, max_age):
    """Delete old image files directly inside *path* (non-recursive).

    Args:
        path: directory to scan.
        extensions: extensions to consider, without the leading dot
            (e.g. ``"png"``); matched case-insensitively.
        max_age: files whose ``getctime`` is more than this many seconds in
            the past are removed.  ``None`` disables cleanup entirely.
    """
    if max_age is None:
        return
    wanted = {ext.lower().lstrip('.') for ext in extensions}
    now = time.time()
    for filename in listdir(path):
        # splitext keeps the dot (".png"), so strip it before comparing.
        ext = splitext(filename)[1].lstrip('.').lower()
        full_path = join(path, filename)
        if ext not in wanted or not isfile(full_path):
            continue
        try:
            # getctime is the inode change time on Unix, not creation time.
            if now - getctime(full_path) > max_age:
                print("removing old output file: " + filename)
                remove(full_path)
        except FileNotFoundError:
            # The file vanished between listdir() and stat/remove; skip it.
            continue


def run_cleanup_thread(path, extensions, age, delay):
    """Call ``cleanup_old_files`` forever, sleeping *delay* seconds between passes.

    An ``OSError`` in one pass is logged instead of propagated, so a single
    failure does not terminate the thread and stop all future cleanup.
    """
    while True:
        print("running cleanup thread")
        try:
            cleanup_old_files(path, extensions, age)
        except OSError as err:
            print("cleanup failed: " + str(err))
        sleep(delay)


cleanup_thread = Thread(
    target=run_cleanup_thread,
    args=(outgoing_dir, img_formats, max_file_age, file_cleanup_delay))
cleanup_thread.daemon = True
cleanup_thread.start()


def detection_to_dict(d):
    """Convert one detection into a JSON-serialisable dict.

    Reads ``d.label``, ``d.confidence`` and ``d.box``, mapping ``box[0..3]``
    to ``x``, ``y``, ``width`` and ``height`` respectively.
    """
    return {
        "label": d.label,
        "confidence": d.confidence,
        "box": {
            "x": d.box[0],
            "y": d.box[1],
            "width": d.box[2],
            "height": d.box[3],
        },
    }


def result_to_dict(res, base_url):
    """Convert a detection result into the JSON response body.

    ``output`` is *base_url* followed by the basename of ``res.outfile``.
    """
    return {
        "labels": [detection.label for detection in res.detections],
        "detections": [detection_to_dict(detection) for detection in res.detections],
        "output": base_url + basename(res.outfile),
    }


def _safe_suffix(filename):
    """Return the extension of an uploaded *filename* if it is harmless.

    Only a short alphanumeric extension (e.g. ``.jpg``) is kept, in case the
    detector relies on the staged file's extension (not verified here).
    Anything else, including a missing filename, yields ``""``.
    """
    suffix = splitext(basename(filename or ""))[1]
    return suffix if re.fullmatch(r'\.[A-Za-z0-9]{1,10}', suffix) else ""


@app.post("/images")
def analyze_image(request: Request, image: UploadFile):
    """Run object detection on an uploaded image.

    The upload is streamed into a uniquely named file in ``incoming_dir``
    while its SHA-256 digest is computed.  The client-supplied filename is
    never used as a path, because it could point outside that directory.
    Detection output goes to ``outgoing_dir/<sha256>.png``, which is served
    by ``get_analyzed_image``.  The staged input is always removed, even if
    detection raises.
    """
    base_url = re.sub(r'\/images\/?$', '/analyzed_images/', str(request.url))
    fd, infile_name = tempfile.mkstemp(
        dir=incoming_dir, suffix=_safe_suffix(image.filename))
    infile = Path(infile_name)
    try:
        file_hash = hashlib.sha256()
        with os.fdopen(fd, "wb") as f:
            chunk = image.file.read(buffer_size)
            while chunk:
                file_hash.update(chunk)
                print("writing " + str(len(chunk)) + " bytes")
                f.write(chunk)
                chunk = image.file.read(buffer_size)
        result = detector.detect_objects(
            infile,
            detection_timeout,
            str(outgoing_dir / (file_hash.hexdigest() + ".png")))
    finally:
        remove(infile)
    return result_to_dict(result, base_url)


@app.get("/analyzed_images/{image_name}", response_class=FileResponse)
def get_analyzed_image(image_name: str):
    """Serve an output image previously written to ``outgoing_dir``.

    Raises:
        HTTPException: 404 unless *image_name* is a bare file name (no
            directory components) naming a regular file in ``outgoing_dir``.
    """
    filename = outgoing_dir / image_name
    if basename(image_name) == image_name and isfile(filename):
        return str(filename)
    # Report only the requested name; the absolute server path is internal.
    raise HTTPException(status_code=404,
                        detail="file not found: " + image_name)