aolko committed on
Commit
6eac492
1 Parent(s): 55ff40c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -23
app.py CHANGED
@@ -11,13 +11,13 @@ from huggingface_hub import hf_hub_download
11
  # Initialize models
12
  anime_model_path = hf_hub_download("SmilingWolf/wd-convnext-tagger-v3", "model.onnx")
13
  anime_model = ort.InferenceSession(anime_model_path)
14
- photo_model = AutoModelForZeroShotImageClassification.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
15
  processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
16
 
17
  # Load labels for the anime model
18
  labels_path = hf_hub_download("SmilingWolf/wd-convnext-tagger-v3", "selected_tags.csv")
19
  with open(labels_path, 'r') as f:
20
- labels = [line.strip().split(',')[0] for line in f.readlines()[1:]] # Skip header
21
 
22
  def preprocess_image(image):
23
  image = image.convert('RGB')
@@ -28,6 +28,39 @@ def preprocess_image(image):
28
  image = image / 255.0
29
  return image[np.newaxis, ...]
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  def get_booru_image(booru, image_id):
32
  if booru == "Gelbooru":
33
  url = f"https://gelbooru.com/index.php?page=dapi&s=post&q=index&json=1&id={image_id}"
@@ -51,27 +84,6 @@ def get_booru_image(booru, image_id):
51
 
52
  return img, tags
53
 
54
- def transcribe_image(image, image_type, transcriber, booru_tags=None):
55
- if image_type == "Anime":
56
- input_image = preprocess_image(image)
57
- input_name = anime_model.get_inputs()[0].name
58
- output_name = anime_model.get_outputs()[0].name
59
- probs = anime_model.run([output_name], {input_name: input_image})[0]
60
-
61
- # Get top 50 tags
62
- top_indices = probs[0].argsort()[-50:][::-1]
63
- tags = [labels[i] for i in top_indices]
64
- else:
65
- inputs = processor(images=image, return_tensors="pt")
66
- outputs = photo_model(**inputs)
67
- tags = outputs.logits.topk(50).indices.squeeze().tolist()
68
- tags = [processor.config.id2label[t] for t in tags]
69
-
70
- if booru_tags:
71
- tags = list(set(tags + booru_tags))
72
-
73
- return ", ".join(tags)
74
-
75
  def update_image(image_type, booru, image_id, uploaded_image):
76
  if image_type == "Anime" and booru != "Upload":
77
  image, booru_tags = get_booru_image(booru, image_id)
 
11
  # Initialize models
12
  anime_model_path = hf_hub_download("SmilingWolf/wd-convnext-tagger-v3", "model.onnx")
13
  anime_model = ort.InferenceSession(anime_model_path)
14
+ photo_model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
15
  processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
16
 
17
  # Load labels for the anime model
18
  labels_path = hf_hub_download("SmilingWolf/wd-convnext-tagger-v3", "selected_tags.csv")
19
  with open(labels_path, 'r') as f:
20
+ anime_labels = [line.strip().split(',')[0] for line in f.readlines()[1:]] # Skip header
21
 
22
  def preprocess_image(image):
23
  image = image.convert('RGB')
 
28
  image = image / 255.0
29
  return image[np.newaxis, ...]
30
 
31
def transcribe_image(image, image_type, transcriber, booru_tags=None):
    """Return a comma-separated tag string describing ``image``.

    Args:
        image: Image to tag. The non-anime branch reads ``image.width`` and
            ``image.height``; the anime branch hands it to
            ``preprocess_image``. Presumably a PIL image -- confirm against
            the caller.
        image_type: ``"Anime"`` selects the ONNX tagger (``anime_model``);
            any other value selects the Florence-2 model (``photo_model``).
        transcriber: Not used by this function; kept so existing callers
            that pass it keep working.
        booru_tags: Optional iterable of extra tags to merge into the result.

    Returns:
        The tags joined with ``", "``. Model tags come first in model order,
        followed by any ``booru_tags`` not already present.
    """
    if image_type == "Anime":
        input_image = preprocess_image(image)
        input_name = anime_model.get_inputs()[0].name
        output_name = anime_model.get_outputs()[0].name
        probs = anime_model.run([output_name], {input_name: input_image})[0]

        # Keep the 50 highest-scoring labels, best first.
        top_indices = probs[0].argsort()[-50:][::-1]
        tags = [anime_labels[i] for i in top_indices]
    else:
        # The same task token must drive both the prompt and the
        # post-processing step; prompting for a caption and then parsing the
        # output as object detection yields nothing usable.
        task = "<OD>"
        inputs = processor(text=task, images=image, return_tensors="pt")

        generated_ids = photo_model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            do_sample=False,
            num_beams=3,
        )
        generated_text = processor.batch_decode(
            generated_ids, skip_special_tokens=False
        )[0]
        parsed_answer = processor.post_process_generation(
            generated_text,
            task=task,
            image_size=(image.width, image.height),
        )

        # Per the Florence-2 model card the result is keyed by the task
        # token: {"<OD>": {"bboxes": [...], "labels": [...]}}. The same label
        # can be detected several times, so de-duplicate while keeping order.
        tags = list(dict.fromkeys(parsed_answer[task]["labels"]))

    if booru_tags:
        # dict.fromkeys de-duplicates while preserving order, so the output
        # is deterministic (a set would scramble the ranking on every run).
        tags = list(dict.fromkeys([*tags, *booru_tags]))

    return ", ".join(tags)
62
+
63
+
64
  def get_booru_image(booru, image_id):
65
  if booru == "Gelbooru":
66
  url = f"https://gelbooru.com/index.php?page=dapi&s=post&q=index&json=1&id={image_id}"
 
84
 
85
  return img, tags
86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  def update_image(image_type, booru, image_id, uploaded_image):
88
  if image_type == "Anime" and booru != "Upload":
89
  image, booru_tags = get_booru_image(booru, image_id)