CodeVulnerabilityAI / finetune.py
import uuid
import chromadb
import pandas as pd
import os
from dotenv import load_dotenv
import json
from transformers import AutoModelForCausalLM
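# Load environment variables from a local .env file; the script expects CHROMA_DATA_PATH
# (used below to locate the persistent ChromaDB store).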
load_dotenv()
# ollama_ef = AutoModelForCausalLM.from_pretrained(
#     "nomic-embed-text-v1.5.Q5_K_S.gguf",
#     model_type='llama',
#     max_new_tokens=10960,
#     threads=3,
# )
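# NOTE: the commented-out loader above is unused; the AutoModelForCausalLM import is
# only needed if it is re-enabled.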
csv_files = []
root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
cve_csv_path = os.path.join(root_dir, 'codevulnerabilityai', 'data', 'cve')
csv_files.extend([os.path.join(cve_csv_path, f) for f in os.listdir(cve_csv_path) if f.endswith('.csv')])
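# Read every CVE column as a plain string so mixed-value columns don't trip pandas' type inference.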
dtype_dict = {
'Name': str,
'Status': str,
'Description': str,
'References': str,
'Phase': str,
'Votes': str,
'Comments': str
}
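# Persist the vector store under CHROMA_DATA_PATH/chroma_db/ and reuse the "CVE" collection across runs.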
chroma_data_path = os.getenv('CHROMA_DATA_PATH')
if not chroma_data_path:
    raise RuntimeError("CHROMA_DATA_PATH is not set; add it to your .env file.")
chroma_db_directory = "chroma_db/"
client = chromadb.PersistentClient(path=os.path.join(chroma_data_path, chroma_db_directory))
collection = client.get_or_create_collection(name="CVE")
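# Buffers for batched inserts into the collection.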
documents_to_add = []
ids_to_add = []
metadata_to_add = []
documents_to_add_string = []
batch_size = 10
current_batch = 0
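# Walk each CVE CSV, build one record per row, and flush to Chroma in batches of `batch_size`.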
if csv_files:
    for csv_file in csv_files:
        print(f"Processing {csv_file}...")
        df = pd.read_csv(csv_file, on_bad_lines='skip', dtype=dtype_dict)
        if not df.empty and 'Description' in df.columns:
            for _, row in df.iterrows():
                # The 'Name' field packs several ';'-separated values; map them onto
                # named metadata keys, padding missing ones with "".
                metadata_parts = row['Name'].split(';')
                metadata = {
                    "Name": str(metadata_parts[0].strip()),
                    "Status": str(metadata_parts[1].strip()) if len(metadata_parts) > 1 else "",
                    "Description": str(metadata_parts[2].strip()) if len(metadata_parts) > 2 else "",
                    "References": str(metadata_parts[3].strip()) if len(metadata_parts) > 3 else "",
                    "Phase": str(metadata_parts[4].strip()) if len(metadata_parts) > 4 else "",
                    "Votes": str(metadata_parts[5].strip()) if len(metadata_parts) > 5 else "",
                }
                document_id = str(uuid.uuid4())
                document_content = metadata["Description"]
                document = {'id': document_id, 'content': document_content}
                documents_to_add.append(document)
                # Serialize each document individually, not the whole running list.
                documents_to_add_string.append(json.dumps(document))
                ids_to_add.append(document_id)
                metadata_to_add.append(metadata)
                current_batch += 1
                if current_batch % batch_size == 0:
                    print(f"Batch {current_batch // batch_size} added to the collection.")
                    collection.add(documents=documents_to_add_string, ids=ids_to_add, metadatas=metadata_to_add)
                    documents_to_add = []
                    ids_to_add = []
                    metadata_to_add = []
                    documents_to_add_string = []
                    print(f"Batch {current_batch // batch_size} completed.")
        else:
            print(f"Skipping file {csv_file} due to empty DataFrame or missing 'Description' column")
else:
    print("No CSV files found in the directory. Skipping processing.")

# Flush any remaining documents from a partial batch (fewer than batch_size).
if documents_to_add:
    print(f"Adding remaining {len(documents_to_add)} documents to the collection.")
    collection.add(documents=documents_to_add_string, ids=ids_to_add, metadatas=metadata_to_add)
# results = collection.query(
# query_texts=["Dotnet"],
# n_results=3,
# )
# print(results)
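# Optional sanity check (a sketch, kept commented so it does not run during ingest):
# the same query with a metadata filter on one of the keys written above. The `where`
# argument is Chroma's metadata filter; "Assigned" is an illustrative value, adjust it
# to whatever actually appears in your Phase data.
# results = collection.query(
#     query_texts=["buffer overflow"],
#     n_results=3,
#     where={"Phase": "Assigned"},
# )
# print(results)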