zzz99 committed on
Commit
65520f9
1 Parent(s): 15b45df

add handler and requirements

Browse files
Files changed (2) hide show
  1. handler.py +49 -0
  2. requirements.txt +156 -0
handler.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict
2
+
3
+ import torch
4
+ from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
5
+
6
+ # from peft import PeftConfig, PeftModel
7
+
8
+
9
class EndpointHandler:
    """Callable request handler that wraps a causal language model.

    The handler is constructed once with the model directory and is then
    invoked with one request payload per call.
    """

    # Generation length used when the request does not set ``max_new_tokens``.
    DEFAULT_MAX_NEW_TOKENS = 880

    def __init__(self, path=""):
        """Load the tokenizer and the model stored at ``path``.

        Args:
            path: Directory or model identifier passed unchanged to
                ``AutoTokenizer.from_pretrained`` and
                ``AutoModelForCausalLM.from_pretrained``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            device_map="auto",
            torch_dtype=torch.float16,
        )
        # Device the tokenized inputs are moved to before generation.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def __call__(self, data: Dict[str, Any]) -> list:
        """Generate text for a single request.

        Args:
            data: Request payload. ``data["inputs"]`` is the prompt; when the
                key is absent the whole payload is used as the prompt.
                ``data["parameters"]`` is an optional dict of keyword
                arguments forwarded to ``model.generate``. Both keys are
                popped from ``data``.

        Returns:
            A one-element list ``[{"generated_text": <decoded text>}]``, where
            the text is the first returned sequence decoded with special
            tokens skipped.
        """
        prompt = data.pop("inputs", data)
        parameters = data.pop("parameters", None)

        encoded = self.tokenizer(f"User: {prompt}\n\n", return_tensors="pt")

        # Merge caller parameters over the defaults. Passing the default as an
        # explicit keyword next to ``**parameters`` raised ``TypeError`` as
        # soon as a request supplied its own ``max_new_tokens``.
        generate_kwargs = {"max_new_tokens": self.DEFAULT_MAX_NEW_TOKENS}
        if parameters:
            generate_kwargs.update(parameters)

        outputs = self.model.generate(**encoded.to(self.device), **generate_kwargs)

        prediction = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return [{"generated_text": prediction}]
requirements.txt ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==0.26.1
2
+ aiohttp==3.9.3
3
+ aiosignal==1.3.1
4
+ annotated-types==0.6.0
5
+ anyio==4.2.0
6
+ appdirs==1.4.4
7
+ argon2-cffi==23.1.0
8
+ argon2-cffi-bindings==21.2.0
9
+ arrow==1.3.0
10
+ asttokens==2.4.1
11
+ async-lru==2.0.4
12
+ attrs==23.2.0
13
+ Babel==2.14.0
14
+ beautifulsoup4==4.12.3
15
+ bitsandbytes==0.42.0
16
+ bleach==6.1.0
17
+ certifi==2024.2.2
18
+ cffi==1.16.0
19
+ charset-normalizer==3.3.2
20
+ click==8.1.7
21
+ comm==0.2.1
22
+ datasets==2.16.1
23
+ debugpy==1.8.0
24
+ decorator==5.1.1
25
+ deepspeed==0.13.1
26
+ defusedxml==0.7.1
27
+ dill==0.3.7
28
+ docker-pycreds==0.4.0
29
+ executing==2.0.1
30
+ fastjsonschema==2.19.1
31
+ filelock==3.13.1
32
+ fqdn==1.5.1
33
+ frozenlist==1.4.1
34
+ fsspec==2023.10.0
35
+ gitdb==4.0.11
36
+ GitPython==3.1.41
37
+ h11==0.14.0
38
+ hf_transfer==0.1.5
39
+ hjson==3.1.0
40
+ httpcore==1.0.2
41
+ httpx==0.26.0
42
+ huggingface-hub==0.20.3
43
+ idna==3.6
44
+ ipykernel==6.29.1
45
+ ipython==8.21.0
46
+ ipywidgets==8.1.1
47
+ isoduration==20.11.0
48
+ jedi==0.19.1
49
+ Jinja2==3.1.3
50
+ json5==0.9.14
51
+ jsonpointer==2.4
52
+ jsonschema==4.21.1
53
+ jsonschema-specifications==2023.12.1
54
+ jupyter-events==0.9.0
55
+ jupyter-lsp==2.2.2
56
+ jupyter_client==8.6.0
57
+ jupyter_core==5.7.1
58
+ jupyter_server==2.12.5
59
+ jupyter_server_terminals==0.5.2
60
+ jupyterlab==4.1.0
61
+ jupyterlab-widgets==3.0.9
62
+ jupyterlab_pygments==0.3.0
63
+ jupyterlab_server==2.25.2
64
+ MarkupSafe==2.1.5
65
+ matplotlib-inline==0.1.6
66
+ mistune==3.0.2
67
+ mpmath==1.3.0
68
+ multidict==6.0.5
69
+ multiprocess==0.70.15
70
+ nbclient==0.9.0
71
+ nbconvert==7.15.0
72
+ nbformat==5.9.2
73
+ nest-asyncio==1.6.0
74
+ networkx==3.2.1
75
+ ninja==1.11.1.1
76
+ notebook==7.0.7
77
+ notebook_shim==0.2.3
78
+ numpy==1.26.4
79
+ nvidia-cublas-cu12==12.1.3.1
80
+ nvidia-cuda-cupti-cu12==12.1.105
81
+ nvidia-cuda-nvrtc-cu12==12.1.105
82
+ nvidia-cuda-runtime-cu12==12.1.105
83
+ nvidia-cudnn-cu12==8.9.2.26
84
+ nvidia-cufft-cu12==11.0.2.54
85
+ nvidia-curand-cu12==10.3.2.106
86
+ nvidia-cusolver-cu12==11.4.5.107
87
+ nvidia-cusparse-cu12==12.1.0.106
88
+ nvidia-nccl-cu12==2.19.3
89
+ nvidia-nvjitlink-cu12==12.3.101
90
+ nvidia-nvtx-cu12==12.1.105
91
+ overrides==7.7.0
92
+ packaging==23.2
93
+ pandas==2.2.0
94
+ pandocfilters==1.5.1
95
+ parso==0.8.3
96
+ peft==0.8.2
97
+ pexpect==4.9.0
98
+ platformdirs==4.2.0
99
+ prometheus-client==0.19.0
100
+ prompt-toolkit==3.0.43
101
+ protobuf==4.25.2
102
+ psutil==5.9.8
103
+ ptyprocess==0.7.0
104
+ pure-eval==0.2.2
105
+ py-cpuinfo==9.0.0
106
+ pyarrow==15.0.0
107
+ pyarrow-hotfix==0.6
108
+ pycparser==2.21
109
+ pydantic==2.6.1
110
+ pydantic_core==2.16.2
111
+ Pygments==2.17.2
112
+ pynvml==11.5.0
113
+ python-dateutil==2.8.2
114
+ python-json-logger==2.0.7
115
+ pytz==2024.1
116
+ PyYAML==6.0.1
117
+ pyzmq==25.1.2
118
+ referencing==0.33.0
119
+ regex==2023.12.25
120
+ requests==2.31.0
121
+ rfc3339-validator==0.1.4
122
+ rfc3986-validator==0.1.1
123
+ rpds-py==0.17.1
124
+ safetensors==0.4.2
125
+ scipy==1.12.0
126
+ Send2Trash==1.8.2
127
+ sentry-sdk==1.40.2
128
+ setproctitle==1.3.3
129
+ six==1.16.0
130
+ smmap==5.0.1
131
+ sniffio==1.3.0
132
+ soupsieve==2.5
133
+ stack-data==0.6.3
134
+ sympy==1.12
135
+ terminado==0.18.0
136
+ tinycss2==1.2.1
137
+ tokenizers==0.15.1
138
+ torch==2.2.0
139
+ tornado==6.4
140
+ tqdm==4.66.1
141
+ traitlets==5.14.1
142
+ transformers==4.37.2
143
+ triton==2.2.0
144
+ types-python-dateutil==2.8.19.20240106
145
+ typing_extensions==4.9.0
146
+ tzdata==2023.4
147
+ uri-template==1.3.0
148
+ urllib3==2.2.0
149
+ wandb==0.16.3
150
+ wcwidth==0.2.13
151
+ webcolors==1.13
152
+ webencodings==0.5.1
153
+ websocket-client==1.7.0
154
+ widgetsnbextension==4.0.9
155
+ xxhash==3.4.1
156
+ yarl==1.9.4