ndebuhr committed on
Commit
3c30e6f
1 Parent(s): 0bb15e0

Setup the app code, requirements, and metadata

Browse files
Files changed (3) hide show
  1. README.md +5 -5
  2. app.py +139 -0
  3. requirements.txt +6 -0
README.md CHANGED
@@ -1,10 +1,10 @@
1
  ---
2
- title: Streaming Llm Weather Alerts
3
- emoji:
4
- colorFrom: green
5
- colorTo: red
6
  sdk: gradio
7
- sdk_version: 4.40.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
 
1
  ---
2
+ title: Streaming LLM Weather Alerts
3
+ emoji: 🌤️
4
+ colorFrom: pink
5
+ colorTo: gray
6
  sdk: gradio
7
+ sdk_version: 4.36.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
app.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import apache_beam as beam
2
+ import gradio as gr
3
+ import huggingface_hub
4
+ import pandas as pd
5
+ import plotly.graph_objects as go
6
+ import spaces
7
+ import textwrap
8
+ import torch
9
+ import us
10
+
11
+ from apache_beam.options.pipeline_options import PipelineOptions, SetupOptions
12
+ from transformers import AutoTokenizer, AutoModelForCausalLM
13
+
14
+ import json
15
+ import logging
16
+ import os
17
+ import requests
18
+
19
+ MODEL_NAME = "google/gemma-2-2b-it"
20
+ PROMPT_TEMPLATE = """Write a succinct summary of the following weather alerts. Do not comment on missing information - just summarize the information provided/available.
21
+
22
+ ```json
23
+ {}
24
+ ```
25
+
26
+ Summary (In the state...):
27
+ """
28
+
29
+ # Initialize an empty list to store weather alerts
30
+ alerts = []
31
+
32
+
33
# Define a transform for fetching weather alerts
class FetchWeatherAlerts(beam.DoFn):
    """Beam ``DoFn`` that fetches the active weather.gov alerts for one state.

    Each successful fetch appends one record to the module-level ``alerts``
    list; nothing is emitted to downstream transforms.
    """

    # Seconds to wait on weather.gov before giving up on a state. Without a
    # timeout, ``requests.get`` may block indefinitely on a stalled connection.
    REQUEST_TIMEOUT = 30

    def process(self, state):
        """Fetch the alerts for ``state`` and record them in ``alerts``.

        Args:
            state: Value sent as the ``area`` query parameter of the
                weather.gov active-alerts endpoint.

        The appended record has the shape
        ``{"features": [{"event", "headline", "instruction"}, ...],
        "state": state}`` and only keeps features whose ``messageType`` is
        ``"Alert"``. A state whose request fails (transport error or non-200
        status) is logged and skipped, so it gets no record.
        """
        logging.info(f"Fetching weather alerts for {state} from weather.gov")
        url = f"https://api.weather.gov/alerts/active?area={state}"
        try:
            response = requests.get(
                url,
                headers={
                    "User-Agent": "(Neal DeBuhr, https://huggingface.co/spaces/ndebuhr/streaming-llm-weather-alerts)",
                    "Accept": "application/geo+json",
                },
                timeout=self.REQUEST_TIMEOUT,
            )
        except requests.RequestException:
            # One unreachable state should not abort the whole pipeline.
            logging.exception(
                f"Request for {state} weather alerts from weather.gov failed"
            )
            return

        if response.status_code != 200:
            logging.warning(
                f"Skipping {state}: weather.gov returned HTTP {response.status_code}"
            )
            return

        logging.info(f"Fetched weather alerts for {state} from weather.gov")
        features = response.json()["features"]
        alerts.append(
            {
                "features": [
                    {
                        "event": feature["properties"]["event"],
                        "headline": feature["properties"]["headline"],
                        "instruction": feature["properties"]["instruction"],
                    }
                    for feature in features
                    if feature["properties"]["messageType"] == "Alert"
                ],
                "state": state,
            }
        )
62
+
63
+
64
# Ask Beam to save the main session so that pickled functions and classes
# defined in __main__ can be unpickled.
pipeline_options = PipelineOptions()
pipeline_options.view_as(SetupOptions).save_main_session = True

# Run the pipeline: one element per entry in us.states.STATES, each passed
# through FetchWeatherAlerts. Leaving the `with` block runs the pipeline.
with beam.Pipeline(options=pipeline_options) as p:
    (
        p
        | "Create States" >> beam.Create([entry.abbr for entry in us.states.STATES])
        | "Fetch Weather Alerts" >> beam.ParDo(FetchWeatherAlerts())
    )
75
+
76
+
77
# Define a function to generate alert summaries using transformers and ZeroGPU
@spaces.GPU(duration=300)
def generate_summaries(alerts):
    """Add an LLM-written ``"summary"`` string to every alert record.

    Args:
        alerts: List of JSON-serializable dicts. Each dict is rendered into
            ``PROMPT_TEMPLATE`` and gains a ``"summary"`` key.

    Returns:
        The same list, with ``"summary"`` set on each record.

    Raises:
        KeyError: If the ``HUGGINGFACE_TOKEN`` environment variable is unset.
    """
    huggingface_hub.login(token=os.environ["HUGGINGFACE_TOKEN"])
    device = torch.device("cuda")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(device)
    for alert in alerts:
        prompt = PROMPT_TEMPLATE.format(json.dumps(alert, indent=2))
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        prompt_length = inputs["input_ids"].shape[-1]

        with torch.no_grad():
            outputs = model.generate(
                **inputs, max_new_tokens=256, pad_token_id=tokenizer.eos_token_id
            )

        # The generated sequence starts with the prompt tokens. Slice them off
        # by position rather than string-replacing the prompt in the decoded
        # text, which silently fails whenever decoding does not reproduce the
        # prompt byte-for-byte.
        completion_ids = outputs[0][prompt_length:]
        alert["summary"] = tokenizer.decode(
            completion_ids, skip_special_tokens=True
        ).strip()
    return alerts
99
+
100
+
101
# Summarize every fetched alert record.
alerts = generate_summaries(alerts)

# One row per state. Naming the columns explicitly keeps "state" and
# "summary" present even when no alerts were fetched, so later column
# lookups on an empty frame do not raise KeyError.
df = pd.DataFrame(
    [{"state": alert["state"], "summary": alert["summary"]} for alert in alerts],
    columns=["state", "summary"],
)
106
+
107
+
108
def get_map():
    """Build the US choropleth whose hover text is each state's summary.

    Side effect: writes (or overwrites) the ``wrapped_summary`` column on the
    module-level ``df``.

    Returns:
        A Plotly ``Figure`` with one uniformly grey ``Choropleth`` trace,
        scoped to the USA.
    """
    # Break each summary into lines of at most 50 characters, joined by <br>.
    df["wrapped_summary"] = df["summary"].map(
        lambda summary: "<br>".join(textwrap.wrap(summary, width=50))
    )

    # Both ends of the scale are the same colour: single color for all states.
    uniform_grey = [
        [0, "lightgrey"],
        [1, "lightgrey"],
    ]
    state_layer = go.Choropleth(
        locations=df["state"],
        z=[1] * len(df["summary"]),
        locationmode="USA-states",
        colorscale=uniform_grey,
        showscale=False,
        text=df["wrapped_summary"],
        hoverinfo="text",
        hovertemplate="%{text}<extra></extra>",
    )

    fig = go.Figure(state_layer)
    fig.update_layout(title_text="Streaming LLM Weather Alerts", geo_scope="usa")
    return fig
133
+
134
+
135
# Gradio UI: no inputs, and a single plot output filled by get_map().
plot_output = gr.Plot()
iface = gr.Interface(fn=get_map, inputs=None, outputs=plot_output)

# Start serving the interface.
iface.launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ apache_beam==2.57.0
2
+ huggingface_hub==0.24.5
3
+ pandas==2.2.2
4
+ plotly==5.23.0
5
+ transformers==4.43.4
6
+ us==3.2.0