File size: 4,915 Bytes
dc65e63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# fastapi_crud/app/main.py
from fastapi import FastAPI, File, UploadFile, HTTPException, Form, Depends
from fastapi.responses import JSONResponse
from sqlalchemy.orm import Session
from pydub import AudioSegment
import io
import spacy
import speech_recognition as sr

from app.database import engine, Base, get_db
from app.routers import user, device
from app import crud, schemas, auth, models

# Create the database tables for every model registered on Base's metadata
# (SQLAlchemy's create_all only creates tables that do not already exist).
# This runs at import time, before the application object is built.
Base.metadata.create_all(bind=engine)

app = FastAPI()

# Mount the user and device route groups from app.routers.
app.include_router(user.router)
app.include_router(device.router)

# Load spaCy models once at import time so every request reuses them.
# nlp:  its doc.ents ("location" / "device" labels) drive the device updates.
# nlp2: its doc.cats scores are read; the top-scoring category is used as the
#       device status passed to process_entities_and_update.
nlp = spacy.load("custom_nlp_model")
nlp2 = spacy.load("text_categorizer_model")

def convert_audio_to_text(audio_file: UploadFile) -> str:
    """Transcribe an uploaded audio file to text.

    The upload is decoded with pydub using its filename extension as the
    format, downmixed to mono, resampled to 16 kHz, re-encoded as WAV in
    memory, and passed to ``speech_recognition``'s ``recognize_google``.

    Args:
        audio_file: The uploaded file. Its (case-insensitive) filename
            extension must be one of wav, mp3, ogg or flac.

    Returns:
        The recognized transcript.

    Raises:
        HTTPException: 400 if the extension is unsupported (or the filename is
            missing) or the speech could not be understood; 500 if the
            recognition service fails or the audio cannot be decoded/processed.
    """
    supported_formats = ("wav", "mp3", "ogg", "flac")
    # A missing filename is treated like an unsupported format instead of
    # crashing on None.split(); lower() accepts e.g. "clip.WAV".
    filename = audio_file.filename or ""
    audio_format = filename.rsplit(".", 1)[-1].lower()

    # Validate OUTSIDE the try below: HTTPException subclasses Exception, so
    # the generic handler would otherwise rewrite this 400 into a 500.
    if audio_format not in supported_formats:
        raise HTTPException(status_code=400, detail="Unsupported audio format. Please upload a wav, mp3, ogg, or flac file.")

    try:
        audio = AudioSegment.from_file(io.BytesIO(audio_file.file.read()), format=audio_format)
        audio = audio.set_channels(1).set_frame_rate(16000)
        wav_io = io.BytesIO()
        audio.export(wav_io, format="wav")
        wav_io.seek(0)

        recognizer = sr.Recognizer()
        with sr.AudioFile(wav_io) as source:
            audio_data = recognizer.record(source)

        return recognizer.recognize_google(audio_data)
    except sr.UnknownValueError as e:
        raise HTTPException(status_code=400, detail="Speech recognition could not understand the audio.") from e
    except sr.RequestError as e:
        raise HTTPException(status_code=500, detail=f"Speech recognition service error: {e}") from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Audio processing error: {e}") from e

def update_device_status(db: Session, device_name: str, location: str, status: str):
    """Switch the device named *device_name* at *location* on or off.

    A *status* equal to "on" (case-insensitive) activates the device; any
    other value deactivates it.

    Returns:
        A dict with "device", "location" and "status" keys. "status" is
        "turned <status>" when a matching device was updated, otherwise
        "not found" (echoing the requested name and location).
    """
    match = (
        db.query(models.Device)
        .filter(models.Device.name == device_name, models.Device.location == location)
        .first()
    )
    turn_on = status.lower() == "on"

    if not match:
        return {"device": device_name, "location": location, "status": "not found"}

    crud.set_device_active(db, match.id, turn_on)
    return {"device": match.name, "location": match.location, "status": "turned " + status}

def find_location(db: Session, text: str):
    """Return the first whitespace-separated word of *text* that is a known device location.

    Words are checked in order, one query each, stopping at the first hit.
    The location value stored on the matching Device row is returned, or ''
    when no word matches.
    """
    # Lazy generator: each query only runs when next() asks for it, so the
    # lookup still stops at the first matching word.
    candidates = (
        db.query(models.Device).filter(models.Device.location == token).first()
        for token in text.split()
    )
    hit = next((device for device in candidates if device), None)
    return hit.location if hit else ''

def process_entities_and_update(db: Session, doc, text, status):
    """Apply *status* to every device mentioned in *text*.

    The location comes from the first entity labelled "location" (any case).
    If there is none, find_location's word-by-word database lookup is used.
    Each entity labelled "device" (any case) is then updated at that location.
    If *doc* has no device entities, every whitespace-separated word of
    *text* is tried as a device name instead.

    Args:
        db: SQLAlchemy session used for the lookups and updates.
        doc: spaCy Doc for *text*; only ``doc.ents`` is read.
        text: The raw command text.
        status: Status value forwarded to update_device_status.

    Returns:
        A list of update_device_status results, one per device candidate.
    """
    location_entity = next((ent for ent in doc.ents if ent.label_.lower() == 'location'), None)
    location = location_entity.text if location_entity else find_location(db, text)

    # Match labels case-insensitively, consistent with the location lookup
    # above; an exact 'device' comparison silently ignored e.g. 'DEVICE'
    # entities and dropped into the noisy every-word fallback.
    device_names = [ent.text for ent in doc.ents if ent.label_.lower() == 'device']
    if not device_names:
        # No device entities found: process all words as potential device names.
        device_names = text.split()

    updates = []
    for name in device_names:
        updates.append(update_device_status(db, name, location, status))
    return updates

@app.post("/predict/")
async def predict(audio_file: UploadFile = File(...), db: Session = Depends(get_db), current_user: schemas.User = Depends(auth.get_current_user)):
    """Transcribe an uploaded audio command, classify it and apply device updates.

    ``current_user`` is not used in the body; the dependency presumably
    enforces authentication (see auth.get_current_user).

    Returns:
        A JSONResponse with "text" (transcript), "predictions" (top category
        from nlp2), "entities" (nlp entities) and "updates" (device update
        results). If any step raises HTTPException, the response instead
        carries that status code and {"detail": ...}.
    """
    # fastapi is already a dependency of this module; imported locally so the
    # block stays self-contained.
    from fastapi.concurrency import run_in_threadpool

    def analyze():
        text = convert_audio_to_text(audio_file)
        doc = nlp(text)
        doc2 = nlp2(text)

        predictions = {"category": max(doc2.cats, key=doc2.cats.get)}
        entities = [{"text": ent.text, "label": ent.label_} for ent in doc.ents]

        updates = process_entities_and_update(db, doc, text, predictions['category'])
        return {"text": text, "predictions": predictions, "entities": entities, "updates": updates}

    try:
        # Audio decoding, the network call to the recognizer, spaCy inference
        # and the synchronous DB queries all block. Running them in the
        # threadpool keeps this async endpoint from stalling the event loop
        # (and every other in-flight request) while they run.
        content = await run_in_threadpool(analyze)
    except HTTPException as e:
        return JSONResponse(status_code=e.status_code, content={"detail": e.detail})
    return JSONResponse(content=content)

@app.post("/predict_text/")
async def predict_text(text: str = Form(...), db: Session = Depends(get_db), current_user: schemas.User = Depends(auth.get_current_user)):
    """Classify a text command, extract entities and apply device updates.

    ``current_user`` is not used in the body; the dependency presumably
    enforces authentication (see auth.get_current_user).

    Returns:
        A JSONResponse with "text" (the input), "predictions" (top category
        from nlp2), "entities" (nlp entities) and "updates" (device update
        results). If any step raises HTTPException, the response instead
        carries that status code and {"detail": ...}.
    """
    # fastapi is already a dependency of this module; imported locally so the
    # block stays self-contained.
    from fastapi.concurrency import run_in_threadpool

    def analyze():
        doc = nlp(text)
        doc2 = nlp2(text)

        predictions = {"category": max(doc2.cats, key=doc2.cats.get)}
        entities = [{"text": ent.text, "label": ent.label_} for ent in doc.ents]

        updates = process_entities_and_update(db, doc, text, predictions['category'])
        return {"text": text, "predictions": predictions, "entities": entities, "updates": updates}

    try:
        # spaCy inference and the synchronous DB queries block. Running them
        # in the threadpool keeps this async endpoint from stalling the event
        # loop for other requests.
        content = await run_in_threadpool(analyze)
    except HTTPException as e:
        return JSONResponse(status_code=e.status_code, content={"detail": e.detail})
    return JSONResponse(content=content)