Commit 124fb3cf by 庄欣

test

parent d0b4651f
version: "3"
services:
searcher-server:
image: chenglong555/pic_search_demo:0.5.1
volumes:
- /home/code001/data/avatar:/tmp/pic1
environment:
DATA_PATH: /tmp/images-data
networks:
- image-searcher
ports:
- 5002:5000
searcher-client:
image: chenglong555/pic_search_demo_web:0.2.0
ports:
- 8050:80
depends_on:
- searcher-server
links:
- searcher-server
environment:
API_URL: "http://127.0.0.1:5002"
networks:
- image-searcher
milvus:
build:
context: ./milvus
dockerfile: Dockerfile
ports:
- 19530:19530
networks:
- image-searcher
networks:
image-searcher:
driver: bridge
@@ -15,14 +15,18 @@ services:
       - TZ=Asia/Shanghai
     entrypoint: "/var/lib/milvus/bin/milvus_server -c /var/lib/milvus/conf/server_config.yaml"
   server:
-    image: milvusbootcamp/pic-search-webserver:0.10.0
+    build:
+      context: ./webserver
+      dockerfile: ./Dockerfile
+    container_name: iserver
     volumes:
+      - ./webserver:/app
       - /home/code001/Pictures:/tmp/tains
-      - ./data/models/:/app/data/models
     environment:
       - DATA_PATH=/tmp/images-data # path where the trained images are stored
       - MILVUS_HOST=milvus # Milvus host address
       - TZ=Asia/Shanghai
+    working_dir: /app/src
     networks:
       - image-searcher
     links:
@@ -33,6 +37,7 @@ services:
       - milvus
   client:
     image: milvusbootcamp/pic-search-webclient:0.2.0
+    container_name: iclient
     ports:
       - 8050:80
     depends_on:
FROM tensorflow/tensorflow
WORKDIR /app/src
COPY . /app
ENV TF_XLA_FLAGS --tf_xla_cpu_global_jit
RUN mkdir -p /root/.keras/models && mv /app/data/models/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5 /root/.keras/models/
RUN apt-get update && apt-get install python3-pip python3 -y
RUN pip3 install -r /app/requirements.txt -i https://mirrors.aliyun.com/pypi/simple --trusted-host mirrors.aliyun.com
RUN mkdir -p /tmp/search-images
#CMD gunicorn --bind 0.0.0.0:5000 -w 2 app:app --preload
EXPOSE 5000
CMD python3 app.py
flask-cors
Keras==2.3.1
numpy==1.16.5
Pillow==7.1.0
pymilvus==0.2.13
diskcache
flask
flask_restful
gunicorn==20.0.0
tensorflow==1.15.4
futures==3.1.1
requests
import os
os.system("")  # placeholder no-op shell call (os has no "call" function)
import os
import os.path as path
import logging
from common.config import DATA_PATH, DEFAULT_TABLE
from common.const import UPLOAD_PATH
from common.const import input_shape
from common.const import default_cache_dir
from service.train import do_train
from service.search import do_search
from service.count import do_count
from service.delete import do_delete
from service.theardpool import thread_runner
from preprocessor.vggnet import vgg_extract_feat
from indexer.index import milvus_client, create_table, insert_vectors, delete_table, search_vectors, create_index
from service.search import query_name_from_ids
from flask_cors import CORS
from flask import Flask, request, send_file, jsonify
from flask_restful import reqparse
from werkzeug.utils import secure_filename
from keras.applications.vgg16 import VGG16
from keras.applications.vgg16 import preprocess_input as preprocess_input_vgg
from keras.preprocessing import image
import numpy as np
from numpy import linalg as LA
import tensorflow as tf
from tensorflow.python.keras.backend import set_session
from tensorflow.python.keras.models import load_model
from diskcache import Cache
import shutil
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
config.gpu_options.per_process_gpu_memory_fraction = 0.5
global sess
sess = tf.Session(config=config)
set_session(sess)
app = Flask(__name__)
ALLOWED_EXTENSIONS = set(['jpg', 'png'])
app.config['UPLOAD_FOLDER'] = UPLOAD_PATH
app.config['JSON_SORT_KEYS'] = False
CORS(app)
model = None
def load_model():
global graph
graph = tf.get_default_graph()
global model
model = VGG16(weights='imagenet',
input_shape=input_shape,
pooling='max',
include_top=False)
@app.route('/api/v1/train', methods=['POST'])
def do_train_api():
args = reqparse.RequestParser(). \
add_argument('Table', type=str). \
add_argument('File', type=str). \
parse_args()
table_name = args['Table']
file_path = args['File']
try:
thread_runner(1, do_train, table_name, file_path)
filenames = os.listdir(file_path)
if not os.path.exists(DATA_PATH):
os.mkdir(DATA_PATH)
for filename in filenames:
shutil.copy(file_path + '/' + filename, DATA_PATH)
return "Start"
except Exception as e:
return "Error with {}".format(e)
@app.route('/api/v1/delete', methods=['POST'])
def do_delete_api():
args = reqparse.RequestParser(). \
add_argument('Table', type=str). \
parse_args()
table_name = args['Table']
print("delete table.")
status = do_delete(table_name)
try:
shutil.rmtree(DATA_PATH)
except:
print("cannot remove", DATA_PATH)
return "{}".format(status)
@app.route('/api/v1/count', methods=['POST'])
def do_count_api():
args = reqparse.RequestParser(). \
add_argument('Table', type=str). \
parse_args()
table_name = args['Table']
rows = do_count(table_name)
return "{}".format(rows)
@app.route('/api/v1/process')
def thread_status_api():
cache = Cache(default_cache_dir)
return "current: {}, total: {}".format(cache['current'], cache['total'])
@app.route('/data/<image_name>')
def image_path(image_name):
file_name = DATA_PATH + '/' + image_name
if path.exists(file_name):
return send_file(file_name)
return "file not exist"
@app.route("/api/v1/train1", methods=['POST'])
def train1():
args = reqparse.RequestParser(). \
add_argument("Path", type=str). \
parse_args()
    file_path = args['Path']
    table_name = DEFAULT_TABLE
try:
thread_runner(1, do_train, table_name, file_path)
filenames = os.listdir(file_path)
if not os.path.exists(DATA_PATH):
os.mkdir(DATA_PATH)
for filename in filenames:
shutil.copy(file_path + '/' + filename, DATA_PATH)
return "ok"
except Exception as e:
return "Error with {}".format(e)
@app.route('/api/v1/search', methods=['POST'])
def do_search_api():
args = reqparse.RequestParser(). \
add_argument("Table", type=str). \
add_argument("Num", type=int, default=1). \
parse_args()
table_name = args['Table']
if not table_name:
table_name = DEFAULT_TABLE
top_k = args['Num']
file = request.files.get('file', "")
if not file:
return "no file data", 400
    if not file.filename:
return "need file name", 400
if file:
filename = secure_filename(file.filename)
file_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
file.save(file_path)
res_id,res_distance = do_search(table_name, file_path, top_k, model, graph, sess)
if isinstance(res_id, str):
return res_id
res_img = [request.url_root +"data/" + x for x in res_id]
res = dict(zip(res_img,res_distance))
res = sorted(res.items(),key=lambda item:item[1])
return jsonify(res), 200
return "not found", 400
if __name__ == "__main__":
load_model()
app.run(host="0.0.0.0")
import os
MILVUS_HOST = os.getenv("MILVUS_HOST", "127.0.0.1")
MILVUS_PORT = int(os.getenv("MILVUS_PORT", 19530))
VECTOR_DIMENSION = int(os.getenv("VECTOR_DIMENSION", 512))
DATA_PATH = os.getenv("DATA_PATH", "/data/jpegimages")
DEFAULT_TABLE = os.getenv("DEFAULT_TABLE", "milvus")
UPLOAD_PATH = "/tmp/search-images"
UPLOAD_PATH = "/tmp/search-images"
default_indexer = "milvus"
default_cache_dir = "./tmp"
input_shape = (224, 224, 3)
import os
import numpy as np
from common.config import DATA_PATH as database_path
from encoder.utils import get_imlist
from preprocessor.vggnet import VGGNet
from diskcache import Cache
from common.const import default_cache_dir
def feature_extract(database_path, model):
cache = Cache(default_cache_dir)
feats = []
names = []
img_list = get_imlist(database_path)
model = model
for i, img_path in enumerate(img_list):
norm_feat = model.vgg_extract_feat(img_path)
img_name = os.path.split(img_path)[1]
feats.append(norm_feat)
names.append(img_name.encode())
current = i+1
total = len(img_list)
cache['current'] = current
cache['total'] = total
print ("extracting feature from image No. %d , %d images in total" %(current, total))
# feats = np.array(feats)
return feats, names
import os
def get_imlist(path):
return [os.path.join(path, f) for f in os.listdir(path) if (f.endswith('.jpg') or f.endswith('.png'))]
import logging as log
from milvus import Milvus, IndexType, MetricType, Status
from common.config import MILVUS_HOST, MILVUS_PORT, VECTOR_DIMENSION
def milvus_client():
try:
milvus = Milvus(host=MILVUS_HOST, port=MILVUS_PORT)
# status = milvus.connect(MILVUS_HOST, MILVUS_PORT)
return milvus
except Exception as e:
log.error(e)
def create_table(client, table_name=None, dimension=VECTOR_DIMENSION,
index_file_size=1024, metric_type=MetricType.L2):
table_param = {
'collection_name': table_name,
'dimension': dimension,
'index_file_size':index_file_size,
'metric_type': metric_type
}
try:
status = client.create_collection(table_param)
return status
except Exception as e:
log.error(e)
def insert_vectors(client, table_name, vectors):
if not client.has_collection(collection_name=table_name):
log.error("collection %s not exist", table_name)
return
try:
status, ids = client.insert(collection_name=table_name, records=vectors)
return status, ids
except Exception as e:
log.error(e)
def create_index(client, table_name):
param = {'nlist': 16384}
# status = client.create_index(table_name, param)
status = client.create_index(table_name, IndexType.IVF_FLAT, param)
return status
def delete_table(client, table_name):
status = client.drop_collection(collection_name=table_name)
print(status)
return status
def search_vectors(client, table_name, vectors, top_k):
search_param = {'nprobe': 16}
status, res = client.search(collection_name=table_name, query_records=vectors, top_k=top_k, params=search_param)
return status, res
def has_table(client, table_name):
status = client.has_collection(collection_name=table_name)
return status
def count_table(client, table_name):
status, num = client.count_entities(collection_name=table_name)
return num
import numpy as np
from keras.applications.vgg16 import VGG16
from keras.applications.vgg16 import preprocess_input as preprocess_input_vgg
from keras.preprocessing import image
from numpy import linalg as LA
from common.const import input_shape
class VGGNet:
def __init__(self):
self.input_shape = (224, 224, 3)
self.weight = 'imagenet'
self.pooling = 'max'
self.model_vgg = VGG16(weights=self.weight,
input_shape=(self.input_shape[0], self.input_shape[1], self.input_shape[2]),
pooling=self.pooling,
include_top=False)
self.model_vgg.predict(np.zeros((1, 224, 224, 3)))
def vgg_extract_feat(self, img_path):
img = image.load_img(img_path, target_size=(self.input_shape[0], self.input_shape[1]))
img = image.img_to_array(img)
img = np.expand_dims(img, axis=0)
img = preprocess_input_vgg(img)
feat = self.model_vgg.predict(img)
norm_feat = feat[0] / LA.norm(feat[0])
norm_feat = [i.item() for i in norm_feat]
return norm_feat
def vgg_extract_feat(img_path, model, graph, sess):
with sess.as_default():
with graph.as_default():
img = image.load_img(img_path, target_size=(input_shape[0], input_shape[1]))
img = image.img_to_array(img)
img = np.expand_dims(img, axis=0)
img = preprocess_input_vgg(img)
feat = model.predict(img)
norm_feat = feat[0] / LA.norm(feat[0])
norm_feat = [i.item() for i in norm_feat]
return norm_feat
import logging
logging.basicConfig(filename='app.log', filemode='w', format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.INFO)
# import tensorflow as tf
# config = tf.ConfigProto()
# config.gpu_options.allow_growth = True
# config.gpu_options.per_process_gpu_memory_fraction = 0.5
# sess = tf.Session(config=config)
import logging
import time
from common.config import DEFAULT_TABLE
from common.const import default_cache_dir
# from common.config import DATA_PATH as database_path
from encoder.encode import feature_extract
from preprocessor.vggnet import VGGNet
from diskcache import Cache
from indexer.index import milvus_client, create_table, insert_vectors, delete_table, search_vectors, create_index, count_table
def do_count(table_name):
if not table_name:
table_name = DEFAULT_TABLE
try:
index_client = milvus_client()
print("get table rows:",table_name)
num = count_table(index_client, table_name=table_name)
return num
except Exception as e:
logging.error(e)
return "Error with {}".format(e)
import logging
import time
from common.config import DEFAULT_TABLE
from common.const import default_cache_dir
# from common.config import DATA_PATH as database_path
from encoder.encode import feature_extract
from preprocessor.vggnet import VGGNet
from diskcache import Cache
from indexer.index import milvus_client, create_table, insert_vectors, delete_table, search_vectors, create_index
def do_delete(table_name):
if not table_name:
table_name = DEFAULT_TABLE
try:
index_client = milvus_client()
status = delete_table(index_client, table_name=table_name)
return status
except Exception as e:
logging.error(e)
return "Error with {}".format(e)
import logging
from common.const import default_cache_dir
from indexer.index import milvus_client, create_table, insert_vectors, delete_table, search_vectors, create_index
from preprocessor.vggnet import VGGNet
from preprocessor.vggnet import vgg_extract_feat
from diskcache import Cache
def query_name_from_ids(vids):
res = []
cache = Cache(default_cache_dir)
for i in vids:
if i in cache:
res.append(cache[i])
return res
def do_search(table_name, img_path, top_k, model, graph, sess):
try:
feats = []
index_client = milvus_client()
feat = vgg_extract_feat(img_path, model, graph, sess)
feats.append(feat)
_, vectors = search_vectors(index_client, table_name, feats, top_k)
vids = [x.id for x in vectors[0]]
# print(vids)
# res = [x.decode('utf-8') for x in query_name_from_ids(vids)]
res_id = [x.decode('utf-8') for x in query_name_from_ids(vids)]
# print(res_id)
res_distance = [x.distance for x in vectors[0]]
# print(res_distance)
# res = dict(zip(res_id,distance))
return res_id,res_distance
except Exception as e:
logging.error(e)
return "Fail with error {}".format(e)
import threading
from concurrent.futures import ThreadPoolExecutor
from service.train import do_train
def thread_runner(thread_num, func, *args):
executor = ThreadPoolExecutor(thread_num)
    # run the given function asynchronously; previously this always submitted do_train and ignored func
    executor.submit(func, *args)
import logging
import time
from common.config import DEFAULT_TABLE
from common.const import default_cache_dir
# from common.config import DATA_PATH as database_path
from encoder.encode import feature_extract
from preprocessor.vggnet import VGGNet
from diskcache import Cache
from indexer.index import milvus_client, create_table, insert_vectors, delete_table, search_vectors, create_index,has_table
def do_train(table_name, database_path):
if not table_name:
table_name = DEFAULT_TABLE
cache = Cache(default_cache_dir)
try:
vectors, names = feature_extract(database_path, VGGNet())
index_client = milvus_client()
# delete_table(index_client, table_name=table_name)
# time.sleep(1)
status, ok = has_table(index_client, table_name)
if not ok:
print("create table.")
create_table(index_client, table_name=table_name)
print("insert into:", table_name)
status, ids = insert_vectors(index_client, table_name, vectors)
create_index(index_client, table_name)
for i in range(len(names)):
# cache[names[i]] = ids[i]
cache[ids[i]] = names[i]
print("Train finished")
return "Train finished"
except Exception as e:
logging.error(e)
return "Error with {}".format(e)
@@ -88,7 +88,7 @@
   box-shadow: 2px 2px 10px #FF9F2A;
 }
-.x {
+.__x11 {
   width: 20px;
   height: 20px;
   background-color: white;
@@ -48,7 +48,7 @@ function openUi() {
   $(closeBtnDiv).click(function() {
     toggleButton();
   });
-  $(closeBtnDiv).addClass("x");
+  $(closeBtnDiv).addClass("__x11");
   closeBtnDiv.append(closeBtn);
   body.append(closeBtnDiv);
   body.append(container);
@@ -94,7 +94,7 @@ function openUi() {
 function closeUi() {
   $(container).remove();
-  $(".x").remove();
+  $(".__x11").remove();
 }