aolko commited on
Commit
870b8a1
1 Parent(s): 6eac492

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -10
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import AutoProcessor, AutoModelForZeroShotImageClassification
4
  from diffusers import DiffusionPipeline
5
  import requests
6
  from PIL import Image
@@ -40,8 +40,7 @@ def transcribe_image(image, image_type, transcriber, booru_tags=None):
40
  tags = [anime_labels[i] for i in top_indices]
41
  else:
42
  prompt = "<MORE_DETAILED_CAPTION>"
43
- inputs = processor(text=prompt, images=image, return_tensors="pt")
44
-
45
  generated_ids = photo_model.generate(
46
  input_ids=inputs["input_ids"],
47
  pixel_values=inputs["pixel_values"],
@@ -49,14 +48,9 @@ def transcribe_image(image, image_type, transcriber, booru_tags=None):
49
  do_sample=False,
50
  num_beams=3,
51
  )
52
- generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
53
- parsed_answer = processor.post_process_generation(generated_text, task="<OD>", image_size=(image.width, image.height))
54
 
55
- # Extract tags from parsed_answer
56
- tags = [obj['class'] for obj in parsed_answer]
57
-
58
- if booru_tags:
59
- tags = list(set(tags + booru_tags))
60
 
61
  return ", ".join(tags)
62
 
 
1
  import gradio as gr
2
  import torch
3
+ from transformers import AutoProcessor, AutoModelForCasualLLM
4
  from diffusers import DiffusionPipeline
5
  import requests
6
  from PIL import Image
 
40
  tags = [anime_labels[i] for i in top_indices]
41
  else:
42
  prompt = "<MORE_DETAILED_CAPTION>"
43
+ inputs = processor(images=image, text=prompt, return_tensors="pt")
 
44
  generated_ids = photo_model.generate(
45
  input_ids=inputs["input_ids"],
46
  pixel_values=inputs["pixel_values"],
 
48
  do_sample=False,
49
  num_beams=3,
50
  )
51
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
 
52
 
53
+ tags = generated_text # Use generated text as the description
 
 
 
 
54
 
55
  return ", ".join(tags)
56