File size: 2,447 Bytes
f138a14
 
6f27821
 
f138a14
6f27821
 
 
 
 
f138a14
 
6f27821
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f138a14
6f27821
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f138a14
6f27821
 
 
 
 
 
 
 
 
 
 
 
 
f138a14
 
6f27821
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
# Creates unigram LM following KenLM
import math
import shutil, tempfile


def calculate_log_probabilities(word_counts, num_sentences, n_smoothing=0.01):
    """
    Return base-10 log probabilities for a smoothed unigram model.

    Every word in ``word_counts`` gets ``(count + n_smoothing) / denominator``.
    The sentence markers ``<s>`` and ``</s>`` are each treated as occurring
    ``num_sentences`` times, and ``<unk>`` receives a single pseudo-count.
    Entries for ``<unk>``, ``<s>`` and ``</s>`` already present in
    ``word_counts`` are overwritten by these special values.

    Args:
        word_counts: Mapping from word to its occurrence count.
        num_sentences: Number of sentences in the corpus.
        n_smoothing: Additive smoothing constant applied to word and
            sentence-marker counts, and used to scale the denominator.

    Returns:
        Dict mapping each token (words, then ``<unk>``, ``<s>``, ``</s>``)
        to ``log10`` of its probability.
    """
    # Corpus size including one <s> and one </s> per sentence.
    token_total = sum(word_counts.values()) + 2 * num_sentences

    # One extra slot for <unk>, then inflate the denominator by the smoothing factor.
    denominator = token_total + 1
    denominator = denominator + denominator * n_smoothing

    smoothed = {}
    for token, occurrences in word_counts.items():
        smoothed[token] = (occurrences + n_smoothing) / denominator

    # Special tokens are assigned last so they override any same-named corpus words.
    smoothed["<unk>"] = 1 / denominator
    boundary_prob = (num_sentences + n_smoothing) / denominator
    smoothed["<s>"] = boundary_prob
    smoothed["</s>"] = boundary_prob

    log_table = {}
    for token, probability in smoothed.items():
        log_table[token] = math.log10(probability)
    return log_table


def maybe_generate_pseudo_bigram_arpa(arpa_fpath):
    """
    Rewrite a unigram-only ARPA file in place so it declares one bigram.

    If any line of the file contains ``2-grams:`` the file is left untouched.
    Otherwise an ``ngram 2=1`` header line is inserted after the
    ``ngram 1=`` line, and a ``\\2-grams:`` section holding the single entry
    ``</s> <s>`` (score -9.9999999) is inserted before the ``\\end\\`` line.

    Args:
        arpa_fpath: Path of the ARPA file to inspect and possibly rewrite.
    """
    with open(arpa_fpath, "r") as source:
        original_lines = source.readlines()

    # A bigram section is already present: the model is order >= 2, keep it as is.
    for text in original_lines:
        if "2-grams:" in text:
            return

    rewritten = []
    for text in original_lines:
        stripped = text.strip()
        if stripped.startswith("ngram 1="):
            # Declare the extra bigram count right after the unigram count.
            rewritten.append(text)
            rewritten.append("ngram 2=1\n")
        elif stripped == "\\end\\":
            # Emit the pseudo bigram section just before the end marker.
            rewritten.append("\\2-grams:\n")
            rewritten.append("-9.9999999\t</s> <s>\n\n")
            rewritten.append(text)
        else:
            rewritten.append(text)

    with open(arpa_fpath, "w") as target:
        target.writelines(rewritten)


def save_log_probabilities(log_probabilities, file_path):
    """
    Write unigram log probabilities to ``file_path`` in an ARPA-style layout.

    The file consists of a ``\\data\\`` header, an ``ngram 1=<count>`` line,
    a ``\\1-grams:`` section with one ``<log_prob>\\t<word>`` line per entry
    (in the mapping's iteration order), and a closing ``\\end\\`` marker with
    no trailing newline.

    Args:
        log_probabilities: Mapping from token to its log probability.
        file_path: Destination path; an existing file is overwritten.
    """
    with open(file_path, "w") as file:
        # Backslashes are escaped explicitly: "\d" and "\e" are invalid escape
        # sequences and trigger a SyntaxWarning on recent Python versions.
        file.write("\\data\\\n")
        file.write(f"ngram 1={len(log_probabilities)}\n\n")
        file.write("\\1-grams:\n")
        for word, log_prob in log_probabilities.items():
            # The sentence-start token is always written with score 0,
            # whatever value the mapping holds for it.
            if word == "<s>":
                log_prob = 0
            file.write(f"{log_prob}\t{word}\n")
        file.write("\n")
        file.write("\\end\\")


def create_unigram_lm(word_counts, num_sentences, file_path, n_smoothing=0.01):
    """
    Compute unigram log probabilities and save them to ``file_path``.

    Args:
        word_counts: Mapping from word to its occurrence count.
        num_sentences: Number of sentences in the corpus.
        file_path: Destination path for the written model file.
        n_smoothing: Smoothing constant forwarded to the probability computation.
    """
    save_log_probabilities(
        calculate_log_probabilities(word_counts, num_sentences, n_smoothing),
        file_path,
    )