mattraj commited on
Commit
5256208
1 Parent(s): 992a3b7

lock versions

Browse files
Files changed (2) hide show
  1. app.py +5 -7
  2. requirements.txt +4 -6
app.py CHANGED
@@ -7,9 +7,6 @@ import os
7
  import string
8
  import functools
9
  import re
10
- import flax.linen as nn
11
- import jax
12
- import jax.numpy as jnp
13
  import numpy as np
14
  import spaces
15
  from PIL import Image
@@ -50,10 +47,11 @@ def infer(
50
  max_new_tokens: int
51
  ) -> str:
52
  inputs = processor(text=text, images=resize_and_pad(image, 448), return_tensors="pt", padding="longest", do_convert_rgb=True).to(device).to(dtype=model.dtype)
53
- generated_ids = model.generate(
54
- **inputs,
55
- max_length=2048
56
- )
 
57
  result = processor.decode(generated_ids[0], skip_special_tokens=True)
58
  return result
59
 
 
7
  import string
8
  import functools
9
  import re
 
 
 
10
  import numpy as np
11
  import spaces
12
  from PIL import Image
 
47
  max_new_tokens: int
48
  ) -> str:
49
  inputs = processor(text=text, images=resize_and_pad(image, 448), return_tensors="pt", padding="longest", do_convert_rgb=True).to(device).to(dtype=model.dtype)
50
+ with torch.no_grad():
51
+ generated_ids = model.generate(
52
+ **inputs,
53
+ max_length=2048
54
+ )
55
  result = processor.decode(generated_ids[0], skip_special_tokens=True)
56
  return result
57
 
requirements.txt CHANGED
@@ -1,9 +1,7 @@
1
  huggingface_hub
2
  gradio
3
- pillow
4
- transformers
5
- torch
6
- flax
7
- jax
8
- numpy
9
  spaces
 
1
  huggingface_hub
2
  gradio
3
+ pillow==10.3.0
4
+ transformers==4.41.1
5
+ torch==2.3.0
6
+ numpy==1.26.4
 
 
7
  spaces