File size: 1,868 Bytes
a74e89c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5169ca2
a74e89c
 
 
8b74918
 
a74e89c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import math
import json
import regex
import inspect
from ast import literal_eval
from transformers import Tool

from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build


class DataExtractorTool(Tool):
    """Base class for tools that pull raw data out of an external source.

    Subclasses override :meth:`fetch_data`; calling the tool instance simply
    delegates to that method.
    """

    def __init__(self):
        super().__init__()

    def fetch_data(self, source_link: str) -> str:
        """Fetch raw data from the source identified by ``source_link``.

        This base implementation is a placeholder that subclasses must
        override.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError("Subclasses must implement this method.")

    def __call__(self, source_link: str) -> str:
        """Return whatever :meth:`fetch_data` produces for ``source_link``."""
        # The result is passed through untouched, so the return type is the
        # ``str`` that fetch_data declares, not a dict.
        return self.fetch_data(source_link)


class GoogleSheetExtractorTool(DataExtractorTool):
    """Tool that reads a range of a Google Sheet and returns it as text.

    Rows are joined with newlines and cells with commas (no CSV quoting is
    applied, so a cell containing a comma is emitted verbatim).
    """

    name = "google_sheet_extractor_tool"
    description = """
    Tool to extract data from Google Sheets.
    Input is source_link which is a str of a url or google sheets id
    e.x. source_link='https://docs.google.com/spreadsheets/d/SHEETS_ID/'.
    Output is a string.
    """
    inputs = ["text"]
    outputs = ["text"]

    @staticmethod
    def _extract_spreadsheet_id(source_link: str) -> str:
        """Return the spreadsheet ID contained in a URL or bare ID string.

        Handles both ``.../spreadsheets/d/<ID>/`` and the form copied from a
        browser, ``.../spreadsheets/d/<ID>/edit#gid=0``. Anything without a
        ``/d/`` marker falls back to the last ``/``-separated segment, so a
        bare ID is returned unchanged.
        """
        link = source_link.strip()
        _, marker, tail = link.partition("/d/")
        if not marker:
            # No "/d/" in the input: treat it as a bare ID (or an unknown URL
            # shape) and keep the last path segment.
            return link.rstrip("/").split("/")[-1]

        # The ID is the path segment right after "/d/"; cut it off at the
        # next path separator, query string, or fragment.
        spreadsheet_id = tail
        for separator in ("/", "?", "#"):
            spreadsheet_id = spreadsheet_id.split(separator, 1)[0]
        return spreadsheet_id

    def fetch_data(self, source_link: str) -> str:
        """Download ``Sheet1!A1:Z1000`` of the given sheet as comma-joined text.

        Args:
            source_link: A Google Sheets URL or a bare spreadsheet ID.

        Returns:
            One line per row, cells separated by commas. An empty string when
            the response carries no ``values``.
        """
        # Set up the credentials (read-only scope, token read from the
        # current working directory).
        scope = ["https://www.googleapis.com/auth/spreadsheets.readonly"]
        creds = Credentials.from_authorized_user_file('token.json', scope)
        service = build('sheets', 'v4', credentials=creds)

        # Open the spreadsheet and get all values in the fixed range.
        spreadsheet_id = self._extract_spreadsheet_id(source_link)
        range_name = 'Sheet1!A1:Z1000'
        sheet = service.spreadsheets()
        result = sheet.values().get(spreadsheetId=spreadsheet_id,
                                    range=range_name).execute()
        data = result.get('values', [])

        # Convert the data to a string representation. Cells are coerced with
        # str() so a non-string cell value cannot make join() raise.
        raw_data = '\n'.join(
            ','.join(str(cell) for cell in row) for row in data
        )
        print(raw_data)
        return raw_data