ydshieh
commited on
Commit
•
a722530
1
Parent(s):
3c34a2e
upload coco summary script
Browse files- run_summarization_coco.py +11 -5
run_summarization_coco.py
CHANGED
@@ -37,6 +37,7 @@ import nltk # Here to have a nice missing dependency error message early on
|
|
37 |
import numpy as np
|
38 |
from datasets import Dataset, load_dataset, load_metric
|
39 |
from tqdm import tqdm
|
|
|
40 |
|
41 |
import jax
|
42 |
import jax.numpy as jnp
|
@@ -418,19 +419,24 @@ def main():
|
|
418 |
|
419 |
# Setting padding="max_length" as we need fixed length inputs for jitted functions
|
420 |
def preprocess_function(examples):
|
421 |
-
|
422 |
_pixel_values = []
|
423 |
-
|
|
|
424 |
with Image.open(y) as image:
|
425 |
-
|
|
|
|
|
|
|
426 |
x = encoder_inputs.pixel_values
|
427 |
_pixel_values.append(x)
|
|
|
428 |
pixel_values = np.concatenate(_pixel_values)
|
429 |
|
430 |
-
targets =
|
431 |
|
432 |
# Add eos_token!!
|
433 |
-
targets = [x
|
434 |
|
435 |
model_inputs = {}
|
436 |
model_inputs['pixel_values'] = pixel_values
|
|
|
37 |
import numpy as np
|
38 |
from datasets import Dataset, load_dataset, load_metric
|
39 |
from tqdm import tqdm
|
40 |
+
from PIL import Image
|
41 |
|
42 |
import jax
|
43 |
import jax.numpy as jnp
|
|
|
419 |
|
420 |
# Setting padding="max_length" as we need fixed length inputs for jitted functions
|
421 |
def preprocess_function(examples):
|
422 |
+
|
423 |
_pixel_values = []
|
424 |
+
_captions = []
|
425 |
+
for y, z in zip(examples[image_file_column], examples[caption_column]):
|
426 |
with Image.open(y) as image:
|
427 |
+
try:
|
428 |
+
encoder_inputs = feature_extractor(images=image, return_tensors="np")
|
429 |
+
except:
|
430 |
+
continue
|
431 |
x = encoder_inputs.pixel_values
|
432 |
_pixel_values.append(x)
|
433 |
+
_captions.append(z + ' ' + tokenizer.eos_token)
|
434 |
pixel_values = np.concatenate(_pixel_values)
|
435 |
|
436 |
+
targets = _captions
|
437 |
|
438 |
# Add eos_token!!
|
439 |
+
#targets = [x + ' ' + tokenizer.eos_token for x in targets]
|
440 |
|
441 |
model_inputs = {}
|
442 |
model_inputs['pixel_values'] = pixel_values
|