---
license: apache-2.0
pipeline_tag: text-generation
language:
- en
tags:
- finance
---

# THaLLE: Text Hyperlocally Augmented Large Language Extension

**❗NOTICE❗**: `KBTG-Labs/THaLLE-0.1-7B-fa` is a WIP model checkpoint distributed for reproducing the results in our [Technical Report](https://arxiv.org/abs/2406.07505).

## Training details

This model is [Qwen2-7B-Instruct](https://huggingface.co/Qwen/Qwen2-7B-Instruct) fine-tuned with LoRA on our internal CFA mock exams (2009-2019), containing 9,426 questions.

### Vocab Config Patching

Prior to training, we patched the `bos_token` field in Qwen/Qwen2-7B-Instruct's `tokenizer_config.json` from `null` to the start token `"<|im_start|>"`:

```json
{
  ...
  "bos_token": "<|im_start|>"
  ...
}
```
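If you want to reproduce this patch on a local copy of the base model's config, a minimal sketch is shown below. The local path is hypothetical and the snippet is only an illustration of the change described above, not the exact script used in our training pipeline.

```python
# Illustrative sketch: patch bos_token in a locally downloaded copy of the
# base Qwen/Qwen2-7B-Instruct tokenizer_config.json (path is an assumption,
# e.g. after `huggingface-cli download Qwen/Qwen2-7B-Instruct`).
import json
from pathlib import Path

config_path = Path("Qwen2-7B-Instruct/tokenizer_config.json")  # hypothetical local path
config = json.loads(config_path.read_text(encoding="utf-8"))

config["bos_token"] = "<|im_start|>"  # previously null in the base config

config_path.write_text(json.dumps(config, indent=2, ensure_ascii=False), encoding="utf-8")
```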
## Results

For more details, see our [Technical Report](https://arxiv.org/abs/2406.07505).

| Model                                  | Internal 2020 | Internal 2024 | Flare CFA* |
| -------------------------------------- | ------------- | ------------- | ---------- |
| APIs                                   |               |               |            |
| `gpt-3.5-turbo-0125`                   | 0.5458        | 0.5027        | 0.6366     |
| `gemini-1.5-flash-001`                 | 0.6271        | 0.6278        | 0.7355     |
| `gemini-1.5-pro-001`                   | 0.6780        | 0.6444        | 0.7829     |
| `gpt-4o-2024-05-13`                    | **0.8000**    | **0.8055**    | **0.8789** |
| HF models                              |               |               |            |
| `meta-llama/Llama-2-7b-chat-hf`        | 0.3774        | 0.3639        | 0.4264     |
| `google/gemma-7b-it`                   | 0.5107        | 0.5333        | 0.6027     |
| `meta-llama/Meta-Llama-3-8B-Instruct`  | 0.5424        | 0.5222        | 0.6386     |
| `Qwen/Qwen2-7B-Instruct`               | 0.5740        | 0.5583        | 0.6831     |
| `KBTG-Labs/THaLLE-0.1-7B-fa`           | **0.6678**    | **0.6500**    | **0.7171** |

[*] Flare CFA is `ChanceFocus/flare-cfa`.

## Usage

### Requirements

Since `KBTG-Labs/THaLLE-0.1-7B-fa` is fine-tuned from Qwen2-7B-Instruct, you will need `transformers>=4.37.0`.

### Reproducing results

Running the script below should give you this output:

```
Progress: 1032/1032 | Correct: 740 (71.71%)
```

```python
import re
from typing import Literal, Optional

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID: str = "KBTG-Labs/THaLLE-0.1-7B-fa"

SYSTEM_PROMPT: str = """You are a CFA (chartered financial analyst) taking a test to evaluate your knowledge of finance. You will be given a question along with three possible answers (A, B, and C). Indicate the correct answer (A, B, or C)."""

QUESTION_TEMPLATE: str = """Question: {question}
A. {choice_a}
B. {choice_b}
C. {choice_c}"""


def format_flare_cfa(text: str) -> dict[str, str]:
    """Parse a raw flare-cfa example into its question and three choices."""
    text = re.sub(r"\s+", " ", text)
    pattern = r"Q:\s*(.*?),\s*CHOICES:\s*A:\s*(.*?),\s*B:\s*(.*?),\s*C:\s*(.*)"
    match = re.search(pattern, text)
    if match:
        question, choice_a, choice_b, choice_c = match.groups()
        return {
            "question": question.strip(),
            "choice_a": choice_a.strip(),
            "choice_b": choice_b.strip(),
            "choice_c": choice_c.strip(),
        }
    else:
        raise ValueError("Input text does not match the expected format.")


def load_benchmark_dataset() -> list[dict[str, str]]:
    dataset = load_dataset("ChanceFocus/flare-cfa")["test"]
    prepared_dataset = []
    for d in dataset:
        entry = format_flare_cfa(d["text"])
        entry["answer"] = str(d["answer"]).upper()
        prepared_dataset.append(entry)
    return prepared_dataset


def extract_choice(
    response_text: str, choice_a: str, choice_b: str, choice_c: str
) -> Optional[Literal["A", "B", "C"]]:
    """Extract the predicted choice (A, B, or C) from the model response."""

    def clean(text: str) -> str:
        return text.replace("–", "-").strip().replace("\n", "")

    # 1. Explicit phrasing such as "The correct answer is A".
    find_choice = re.findall(
        r"([T|t]he correct answer is[.|:]? [ABC]|[A|a]nswer[.|:]?[is]?\W+?\n?[ABC]\s)",
        response_text,
    )
    if find_choice:
        return clean(find_choice[0])[-1]
    # 2. The response is just a single letter.
    if len(response_text) == 1 and response_text in "ABC":
        return response_text
    # 3. A letter followed by a period, e.g. "B. ...".
    find_choice = re.findall(r"[ABC][.]\s?", response_text)
    if find_choice:
        return find_choice[0][0]
    # 4. Fall back to matching the text of one of the choices.
    choice = {"A": choice_a, "B": choice_b, "C": choice_c}
    for ch, content in choice.items():
        if clean(content) in clean(response_text):
            return ch
    return None


def inference(messages: list[dict[str, str]], model, tokenizer) -> str:
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    # Greedy decoding; sampling parameters are explicitly disabled.
    generated_ids = model.generate(
        model_inputs.input_ids,
        max_new_tokens=768,
        do_sample=False,
        temperature=None,
        top_p=None,
        top_k=None,
    )
    # Keep only the newly generated tokens.
    generated_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return response


def run_benchmark(dataset: list[dict[str, str]], model, tokenizer):
    total_correct = 0
    for i, problem in enumerate(dataset, start=1):
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": QUESTION_TEMPLATE.format(**problem)},
        ]
        output_text = inference(messages, model, tokenizer)
        prediction = extract_choice(
            output_text,
            problem["choice_a"],
            problem["choice_b"],
            problem["choice_c"],
        )
        correct = problem["answer"] == prediction
        total_correct += correct
        percent = total_correct / i * 100
        print(
            f"Progress: {i}/{len(dataset)} | Correct: {total_correct} ({percent:.2f}%)",
            end="\r",
        )


if __name__ == "__main__":
    dataset = load_benchmark_dataset()
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    run_benchmark(dataset, model, tokenizer)
```

## Citation

If you find our work useful, please cite:

```
@misc{labs2024thalle,
      title={THaLLE: Text Hyperlocally Augmented Large Language Extension -- Technical Report},
      author={KBTG Labs and Danupat Khamnuansin and Atthakorn Petchsod and Anuruth Lertpiya and Pornchanan Balee and Thanawat Lodkaew and Tawunrat Chalothorn and Thadpong Pongthawornkamol and Monchai Lertsutthiwong},
      year={2024},
      eprint={2406.07505},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
```