-
-
Notifications
You must be signed in to change notification settings - Fork 134
/
Copy pathmain.py
280 lines (246 loc) · 9.66 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
import shutil
import threading
from functools import wraps
from io import BytesIO
from flask import Flask, abort, jsonify, redirect, request, send_file, session, url_for
from database import get_image_path_by_id, is_video_exist, get_pexels_video_count
from init import *
from models import DatabaseSession, DatabaseSessionPexelsVideo
from process_assets import match_text_and_image, process_image, process_text
from scan import Scanner
from search import (
clean_cache,
search_image_by_image,
search_image_by_text_path_time,
search_video_by_image,
search_video_by_text_path_time,
search_pexels_video_by_text,
)
from utils import crop_video, get_hash, resize_image_with_aspect_ratio
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Flask application object; the route decorators in this file register their views on it.
app = Flask(__name__)
# Key Flask uses to sign the session cookie.
# NOTE(review): this is a fixed, publicly known string, so anyone who knows it can forge
# session cookies; consider loading the key from configuration instead — confirm with deployment.
app.secret_key = "https://github.com/chn-lee-yumi/MaterialSearch"
# Shared Scanner instance used by init() and by the /api/scan and /api/status views.
scanner = Scanner()
def init():
    """
    Prepare the application before it starts serving requests.

    Logs a warning for every entry of ASSETS_PATH that is not an existing
    directory, recreates the temporary directory tree (``upload`` and
    ``video_clips`` under TEMP_PATH), initialises the scanner (per the
    original note this includes database initialisation) and, when AUTO_SCAN
    is enabled, starts ``scanner.auto_scan`` in its own thread.
    """
    # Check that every configured asset path exists.
    for path in ASSETS_PATH:
        if not os.path.isdir(path):
            logger.warning(f"ASSETS_PATH检查:路径 {path} 不存在!请检查输入的路径是否正确!")
    # Remove everything left over in the temporary directory.
    shutil.rmtree(f'{TEMP_PATH}', ignore_errors=True)
    # rmtree ignores errors, so parts of the tree may survive (e.g. a file that
    # is still in use); exist_ok keeps start-up from crashing in that case.
    os.makedirs(f'{TEMP_PATH}/upload', exist_ok=True)
    os.makedirs(f'{TEMP_PATH}/video_clips', exist_ok=True)
    # Initialise the scanner.
    scanner.init()
    if AUTO_SCAN:
        auto_scan_thread = threading.Thread(target=scanner.auto_scan, args=())
        auto_scan_thread.start()
def login_required(view_func):
    """
    Decorator that guards a view behind the login check.

    When ENABLE_LOGIN is truthy and the session holds no "username" entry,
    the request is redirected to the login view; otherwise the wrapped view
    is called with the original arguments.
    """

    @wraps(view_func)
    def wrapper(*args, **kwargs):
        # The session is only inspected when the login switch is on.
        login_missing = ENABLE_LOGIN and "username" not in session
        if login_missing:
            return redirect(url_for("login"))
        # Either login is disabled or the user is already logged in.
        return view_func(*args, **kwargs)

    return wrapper
@app.route("/", methods=["GET"])
@login_required
def index_page():
    """Home page: serve the static ``index.html`` file."""
    home_page = "index.html"
    return app.send_static_file(home_page)
@app.route("/login", methods=["GET", "POST"])
def login():
    """
    Login view.

    GET serves the static ``login.html`` page. POST reads the ``username`` and
    ``password`` form fields and compares them with USERNAME and PASSWORD:
    on a match the username is stored in the session and the client is
    redirected to the home page, otherwise it is redirected back to the login
    page. Both outcomes are logged together with the client address.
    """
    import hmac  # local import: only this view needs it

    def _same(supplied, expected):
        # A non-string configured value never compared equal to form data
        # under ``==`` either, so keep rejecting it.
        if not isinstance(expected, str):
            return False
        # Constant-time comparison so response timing does not reveal how
        # much of the credential was correct.
        return hmac.compare_digest(supplied.encode("utf-8"), expected.encode("utf-8"))

    if request.method != "POST":
        return app.send_static_file("login.html")
    # Client address, preferring the X-Forwarded-For header when present.
    # The header is client-controlled, so it is used for logging only.
    ip_addr = request.environ.get("HTTP_X_FORWARDED_FOR", request.remote_addr)
    # Form data.
    username = request.form["username"]
    password = request.form["password"]
    # Evaluate both comparisons unconditionally (no short-circuit) so a wrong
    # username and a wrong password take the same code path.
    username_ok = _same(username, USERNAME)
    password_ok = _same(password, PASSWORD)
    if username_ok and password_ok:
        # Login succeeded: remember the user in the session.
        logger.info(f"用户登录成功 {ip_addr}")
        session["username"] = username
        return redirect(url_for("index_page"))
    # Login failed: back to the login page.
    logger.info(f"用户登录失败 {ip_addr}")
    return redirect(url_for("login"))
@app.route("/logout", methods=["GET", "POST"])
def logout():
    """Log out: drop all session data, then redirect to the login page."""
    session.clear()
    login_url = url_for("login")
    return redirect(login_url)
@app.route("/api/scan", methods=["GET"])
@login_required
def api_scan():
    """
    Start a scan unless one is already running.
    :return: JSON object whose "status" is "start scanning" or "already scanning"
    """
    if scanner.is_scanning:
        return jsonify({"status": "already scanning"})
    # Run scanner.scan(False) in its own thread so the request returns at once.
    worker = threading.Thread(target=scanner.scan, args=(False,))
    worker.start()
    return jsonify({"status": "start scanning"})
@app.route("/api/status", methods=["GET"])
@login_required
def api_status():
    """
    Report the scanner status.
    :return: JSON of scanner.get_status() extended with "total_pexels_videos"
    """
    status = scanner.get_status()
    with DatabaseSessionPexelsVideo() as db_session:
        pexels_total = get_pexels_video_count(db_session)
    status["total_pexels_videos"] = pexels_total
    return jsonify(status)
@app.route("/api/clean_cache", methods=["GET", "POST"])
@login_required
def api_clean_cache():
    """
    Clear the cache by calling clean_cache() (imported from the search module).
    :return: 204 No Content
    """
    clean_cache()
    empty_body, no_content = "", 204
    return empty_body, no_content
@app.route("/api/match", methods=["POST"])
@login_required
def api_match():
    """
    Find the assets that match the submitted text and/or image.

    The request body must be a JSON object with the keys ``top_n``,
    ``search_type``, ``positive_threshold``, ``negative_threshold``,
    ``image_threshold``, ``img_id``, ``path``, ``start_time`` and
    ``end_time``. Text-driven searches additionally need ``positive`` (types
    0, 2, 9), ``negative`` (types 0, 2) or ``text`` (type 4).

    ``search_type`` values:
        0 - search images by text
        1 - search images by the uploaded image
        2 - search videos by text
        3 - search videos by the uploaded image
        4 - similarity score between ``text`` and the uploaded image
        5 - search images by an image already in the database (``img_id``)
        6 - search videos by an image already in the database (``img_id``)
        9 - search pexels videos by text

    :return: JSON list with at most ``top_n`` results, or ``{"score": "..."}``
        for type 4. Responds with 400 when the body is malformed, when a
        required upload is missing, or when ``search_type`` is unknown.
    """
    data = request.get_json()
    try:
        top_n = int(data["top_n"])
        search_type = data["search_type"]
        positive_threshold = data["positive_threshold"]
        negative_threshold = data["negative_threshold"]
        image_threshold = data["image_threshold"]
        img_id = data["img_id"]
        path = data["path"]
        start_time = data["start_time"]
        end_time = data["end_time"]
        # Text fields are only required by the search types that use them.
        positive = data["positive"] if search_type in (0, 2, 9) else None
        negative = data["negative"] if search_type in (0, 2) else None
        text = data["text"] if search_type == 4 else None
    except (KeyError, TypeError, ValueError):
        # Missing key, body that is not a JSON object, or non-integer top_n:
        # this is a client error, not a server failure.
        logger.warning(f"Malformed /api/match request body: {data!r}")
        abort(400)
    # The uploaded file is consumed by this request: read its path and clear it.
    upload_file_path = session.get('upload_file_path', '')
    session['upload_file_path'] = ""
    if search_type in (1, 3, 4) and (not upload_file_path or not os.path.exists(upload_file_path)):
        return "你没有上传文件!", 400
    logger.debug(data)
    # Run the search that corresponds to search_type.
    if search_type == 0:  # search images by text
        results = search_image_by_text_path_time(positive, negative, positive_threshold, negative_threshold,
                                                 path, start_time, end_time)
    elif search_type == 1:  # search images by uploaded image
        results = search_image_by_image(upload_file_path, image_threshold)
    elif search_type == 2:  # search videos by text
        results = search_video_by_text_path_time(positive, negative, positive_threshold, negative_threshold,
                                                 path, start_time, end_time)
    elif search_type == 3:  # search videos by uploaded image
        results = search_video_by_image(upload_file_path, image_threshold)
    elif search_type == 4:  # text/image similarity score
        score = match_text_and_image(process_text(text), process_image(upload_file_path)) * 100
        return jsonify({"score": "%.2f" % score})
    elif search_type == 5:  # search images by an image from the database
        results = search_image_by_image(img_id, image_threshold)
    elif search_type == 6:  # search videos by an image from the database
        results = search_video_by_image(img_id, image_threshold)
    elif search_type == 9:  # search pexels videos by text
        results = search_pexels_video_by_text(positive, positive_threshold)
    else:  # unknown search type
        logger.warning(f"search_type不正确:{search_type}")
        abort(400)
    return jsonify(results[:top_n])
@app.route("/api/get_image/<int:image_id>", methods=["GET"])
@login_required
def api_get_image(image_id):
    """
    Return an image file.

    When the ``thumbnail`` query argument is set and the file is not a GIF,
    the image is passed through resize_image_with_aspect_ratio with a
    (640, 480) target and returned as a JPEG of quality 60. GIFs and
    non-thumbnail requests get the original file.

    :param image_id: int, id of the image in the database
    :return: the image file, or 404 when no path is found for the id
    """
    with DatabaseSession() as db_session:
        path = get_image_path_by_id(db_session, image_id)
        logger.debug(path)
    if not path:
        # NOTE(review): presumably get_image_path_by_id returns a falsy value
        # for an unknown id — confirm. Without a path nothing can be served.
        abort(404)
    # os.path.splitext keeps the leading dot, so compare against ".gif";
    # lower() also covers upper-case extensions.
    is_gif = os.path.splitext(path)[-1].lower() == ".gif"
    # Static images are compressed before being returned.
    if request.args.get("thumbnail") and not is_gif:
        image = resize_image_with_aspect_ratio(path, (640, 480), convert_rgb=True)
        image_io = BytesIO()
        image.save(image_io, 'JPEG', quality=60)
        image_io.seek(0)
        return send_file(image_io, mimetype='image/jpeg')
    return send_file(path)
@app.route("/api/get_video/<video_path>", methods=["GET"])
@login_required
def api_get_video(video_path):
    """
    Return a video file.

    :param video_path: string produced by base64.urlsafe_b64encode; decoding it
        yields the absolute path of the video on the server
    :return: the video file; 400 when video_path cannot be decoded, 404 when
        the decoded path is not known to the database
    """
    try:
        path = base64.urlsafe_b64decode(video_path).decode()
    except ValueError:
        # binascii.Error (bad base64) and UnicodeDecodeError (bad UTF-8) are
        # both ValueError subclasses; treat them as a malformed request.
        abort(400)
    logger.debug(path)
    with DatabaseSession() as db_session:
        # Paths that are not in the database get a 404, which prevents
        # arbitrary file reads.
        if not is_video_exist(db_session, path):
            abort(404)
    return send_file(path)
@app.route(
    "/api/download_video_clip/<video_path>/<int:start_time>/<int:end_time>",
    methods=["GET"],
)
@login_required
def api_download_video_clip(video_path, start_time, end_time):
    """
    Download a clip of a video.

    The requested range is widened by VIDEO_EXTENSION_LENGTH on both sides
    (the start is clamped to 0). Clips are written under
    ``TEMP_PATH/video_clips`` and reused when the same clip is requested again.

    :param video_path: string produced by base64.urlsafe_b64encode; decoding it
        yields the absolute path of the video on the server
    :param start_time: int, clip start in seconds
    :param end_time: int, clip end in seconds
    :return: the clip file; 400 when video_path cannot be decoded, 404 when
        the decoded path is not known to the database
    """
    import hashlib  # local import: only this view needs it

    try:
        path = base64.urlsafe_b64decode(video_path).decode()
    except ValueError:
        # binascii.Error (bad base64) and UnicodeDecodeError (bad UTF-8) are
        # both ValueError subclasses; treat them as a malformed request.
        abort(400)
    logger.debug(path)
    with DatabaseSession() as db_session:
        # Paths that are not in the database get a 404, which prevents
        # arbitrary file reads.
        if not is_video_exist(db_session, path):
            abort(404)
    # Widen the range by VIDEO_EXTENSION_LENGTH.
    start_time -= VIDEO_EXTENSION_LENGTH
    end_time += VIDEO_EXTENSION_LENGTH
    if start_time < 0:
        start_time = 0
    # The cache name includes a digest of the full path: the basename alone is
    # not unique, and two videos with the same file name in different
    # directories would otherwise be served each other's clips.
    path_digest = hashlib.sha256(path.encode("utf-8")).hexdigest()[:16]
    output_path = f"{TEMP_PATH}/video_clips/{start_time}_{end_time}_{path_digest}_" + os.path.basename(path)
    # An existing file means this clip was already cut; otherwise cut it now.
    if not os.path.exists(output_path):
        crop_video(path, output_path, start_time, end_time)
    return send_file(output_path)
@app.route("/api/upload", methods=["POST"])
@login_required
def api_upload():
    """
    Upload a file.

    Removes the file recorded in the session by a previous upload (if it still
    exists), stores the new file under ``TEMP_PATH/upload`` using its hash as
    the file name, and records the new path in the session.
    :return: 200
    """
    logger.debug(request.files)
    # Drop the file left behind by the previous upload.
    previous_path = session.get('upload_file_path', '')
    if previous_path and os.path.exists(previous_path):
        os.remove(previous_path)
    # Store the new file under a name derived from its content hash.
    uploaded = request.files["file"]
    filehash = get_hash(uploaded.stream)
    new_path = f"{TEMP_PATH}/upload/{filehash}"
    uploaded.save(new_path)
    session['upload_file_path'] = new_path
    return "file uploaded successfully"
if __name__ == "__main__":
    # Prepare the temporary directories and the scanner before serving.
    init()
    # Apply the configured log level to werkzeug's request logger.
    werkzeug_logger = logging.getLogger('werkzeug')
    werkzeug_logger.setLevel(LOG_LEVEL)
    # NOTE(review): init2 is not defined in the visible part of this file;
    # presumably it comes from ``from init import *`` — confirm.
    init2()
    app.run(host=HOST, port=PORT, debug=FLASK_DEBUG)