lllchenlll committed on
Commit
054d7f8
1 Parent(s): 3425021

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -5
app.py CHANGED
@@ -1,8 +1,10 @@
1
  import gradio as gr
 
2
  import openai
3
 
4
  from sentence_transformers import SentenceTransformer
5
  from langchain.prompts import PromptTemplate
 
6
 
7
 
8
  def process(api, caption, category, asr, ocr):
@@ -10,19 +12,69 @@ def process(api, caption, category, asr, ocr):
10
  preference = "兴趣标签"
11
  example = "例如,给定一个视频,它的\"标题\"为\"长安系最便宜的轿车,4W起很多人都看不上它,但我知道车只是代步工具,又需要什么面子呢" \
12
  "!\",\"类别\"为\"汽车\",\"ocr\"为\"长安系最便宜的一款轿车\",\"asr\"为\"我不否认现在的国产和合资还有一定的差距," \
13
- "但确实是他们让我们5万开了MP V8万开上了轿车,10万开张了ICV15万开张了大七座。\"{}生成机器人推断出合理的\"{}\"为\"" \
14
- "安轿车报价、最便宜的长安轿车、新款长安轿车\"。".format(preference, preference)
15
 
16
  prompt = PromptTemplate(
17
  input_variables=["preference", "caption", "ocr", "asr", "category", "example"],
18
- template="你是一个视频的{preference}生成机器人,根据输入的视频标题、类别、ocr、asr推理出合理的\"{preference}\",以多个多"
19
  "于两字的标签形式进行表达,以顿号隔开。{example}那么,给定一个新的视频,它的\"标题\"为\"{caption}\",\"类别\"为"
20
  "\"{category}\",\"ocr\"为\"{ocr}\",\"asr\"为\"{asr}\",请推断出该视频的\"{preference}\":"
21
  )
22
 
23
  text = prompt.format(preference=preference, caption=caption, category=category, ocr=ocr, asr=asr, example=example)
24
 
25
- return text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
 
28
  with gr.Blocks() as demo:
@@ -39,4 +91,4 @@ with gr.Blocks() as demo:
39
 
40
 
41
  if __name__ == "__main__":
42
- demo.launch()
 
1
  import gradio as gr
2
+ import numpy as np
3
  import openai
4
 
5
  from sentence_transformers import SentenceTransformer
6
  from langchain.prompts import PromptTemplate
7
+ from collections import Counter
8
 
9
 
10
  def process(api, caption, category, asr, ocr):
 
12
  preference = "兴趣标签"
13
  example = "例如,给定一个视频,它的\"标题\"为\"长安系最便宜的轿车,4W起很多人都看不上它,但我知道车只是代步工具,又需要什么面子呢" \
14
  "!\",\"类别\"为\"汽车\",\"ocr\"为\"长安系最便宜的一款轿车\",\"asr\"为\"我不否认现在的国产和合资还有一定的差距," \
15
+ "但确实是他们让我们5万开了MP V8万开上了轿车,10万开张了ICV15万开张了大七座。\",\"{}\"生成机器人推断出合理的\"{}\"为\"" \
16
+ "长安轿车报价、最便宜的长安轿车、新款长安轿车\"。".format(preference, preference)
17
 
18
  prompt = PromptTemplate(
19
  input_variables=["preference", "caption", "ocr", "asr", "category", "example"],
20
+ template="你是一个视频的\"{preference}\"生成机器人,根据输入的视频标题、类别、ocr、asr推理出合理的\"{preference}\",以多个多"
21
  "于两字的标签形式进行表达,以顿号隔开。{example}那么,给定一个新的视频,它的\"标题\"为\"{caption}\",\"类别\"为"
22
  "\"{category}\",\"ocr\"为\"{ocr}\",\"asr\"为\"{asr}\",请推断出该视频的\"{preference}\":"
23
  )
24
 
25
  text = prompt.format(preference=preference, caption=caption, category=category, ocr=ocr, asr=asr, example=example)
26
 
27
+ try:
28
+ completion = openai.ChatCompletion.create(
29
+ model="gpt-3.5-turbo",
30
+ messages=[{"role": "user", "content": text}],
31
+ temperature=1.5,
32
+ n=5
33
+ )
34
+
35
+ res = []
36
+ for j in range(5):
37
+ ans = completion.choices[j].message["content"].strip()
38
+ ans = ans.replace("\n", "")
39
+ ans = ans.replace("。", "")
40
+ ans = ans.replace(",", "、")
41
+ res += ans.split('、')
42
+
43
+ tag_count = Counter(res)
44
+ tag_count = sorted(tag_count.items(), key=lambda x: x[1], reverse=True)[:10]
45
+
46
+ tags_embed = np.load('./tag_data/tags_embed.npy')
47
+ tags_dis = np.load('./tag_data/tags_dis.npy')
48
+
49
+ candidate_tags = [_[0] for _ in tag_count]
50
+ encoder = SentenceTransformer("hfl/chinese-roberta-wwm-ext-large")
51
+ candidate_tags_embed = encoder.encode(candidate_tags)
52
+ candidate_tags_dis = [np.sqrt(np.dot(_, _.T)) for _ in candidate_tags_embed]
53
+
54
+ scores = np.dot(candidate_tags_embed, tags_embed.T)
55
+ f = open('./tag_data/tags.txt', 'r')
56
+ all_tags = []
57
+ for line in f.readlines():
58
+ all_tags.append(line.strip())
59
+ f.close()
60
+
61
+ final_ans = []
62
+ for i in range(scores.shape[0]):
63
+ for j in range(scores.shape[1]):
64
+ score = scores[i][j] / (candidate_tags_dis[i] * tags_dis[j])
65
+ if score > 0.8:
66
+ final_ans.append(all_tags[j])
67
+
68
+ print(final_ans)
69
+
70
+ final_ans = Counter(final_ans)
71
+ final_ans = sorted(final_ans.items(), key=lambda x: x[1], reverse=True)[:5]
72
+ final_ans = [_[0] for _ in final_ans]
73
+
74
+ return "、".join(final_ans)
75
+
76
+ except:
77
+ return 'api error'
78
 
79
 
80
  with gr.Blocks() as demo:
 
91
 
92
 
93
  if __name__ == "__main__":
94
+ demo.launch(share=True)