using UnityEngine;
using Microsoft.ML.Tokenizers;
using Unity.Sentis;
using System.IO;
using System.Linq;
using System.Collections.Generic;
using System.Collections;

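/// <summary>
/// Minimal Phi-3.5 text generation in Unity with Sentis: tokenizes a chat
/// prompt with Microsoft.ML.Tokenizers' LlamaTokenizer, then greedily decodes
/// one token per coroutine step until an end-of-sequence token or the token
/// limit is reached.
/// </summary>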
public class Phi3Claude : MonoBehaviour
{
    IWorker worker;
    LlamaTokenizer tokenizer;

    List<int> tokens = new();
    TensorInt inputTensor, attentionMaskTensor, positionIdsTensor;
    TensorFloat outputLogits;

    int maxTokens = 100; // Maximum number of tokens to generate
    List<int> eosTokens; // End of sequence tokens

    private IBackend backend;

    private void Start()
    {
        var tokenizerModelPath = Path.Combine(Application.streamingAssetsPath, "Phi35/tokenizer.model");
        var sentisModelPath = Path.Combine(Application.streamingAssetsPath, "Phi35/model_Uint8.sentis");
        var configPath = Path.Combine(Application.streamingAssetsPath, "Phi35/generation_config.json");
        
        var model = ModelLoader.Load(sentisModelPath);
        
        worker = WorkerFactory.CreateWorker(BackendType.GPUCompute, model);
        Dictionary<string, int> specialTokens = TokenizerUtils.LoadSpecialTokens(Path.Combine(Application.streamingAssetsPath, "Phi35/added_tokens.json"));

        using (Stream tokenizerModelStream = new FileStream(tokenizerModelPath, FileMode.Open, FileAccess.Read))
        {
            tokenizer = LlamaTokenizer.Create(
                tokenizerModelStream,
                addBeginOfSentence: true,
                addEndOfSentence: false,
                specialTokens: specialTokens
            );
        }

        eosTokens = TokenizerUtils.IdentifyEOSTokens(configPath);
        backend = WorkerFactory.CreateBackend(BackendType.GPUCompute);

        Generate("Hello, how is your day?");
    }

    public void Generate(string userPrompt, string systemPrompt = "You are a helpful assistant.") 
    {
        string completePrompt = Phi3InputFormatter.FormatChatInput(systemPrompt, userPrompt);
        Debug.Log("Complete prompt : " + completePrompt);

        int[] inputIds = tokenizer.EncodeToIds(completePrompt).ToArray();
        Debug.Log($"Tokenized input: [{string.Join(", ", inputIds)}]");
        Debug.Log($"Decoded tokens: [{string.Join(", ", tokenizer.Decode(inputIds, true))}]");

        tokens.Clear();
        tokens.AddRange(inputIds);

        StartCoroutine(GenerateSequence());
    }
    
    private IEnumerator GenerateSequence()
    {
        for (int i = 0; i < maxTokens; i++)
        {
            // Re-run the model over the full context each step (no KV cache),
            // so each iteration gets more expensive as the sequence grows.
            RefreshTensors(tokens.ToArray());

            worker.Execute(new Dictionary<string, Tensor>()
            {
                {"input_ids", inputTensor},
                {"attention_mask", attentionMaskTensor},
                {"position_ids", positionIdsTensor}
            }); // > 15 ms and synchronous; should ideally be scheduled asynchronously

            // Kick off an asynchronous GPU-to-CPU readback of the logits, then
            // wait for it to finish. Note: yielding the raw bool from
            // IsReadbackRequestDone() would only ever wait a single frame.
            outputLogits = worker.PeekOutput("logits") as TensorFloat;
            outputLogits.ReadbackRequest();
            yield return new WaitUntil(() => outputLogits.IsReadbackRequestDone());

            int nextToken = ProcessLogits(); // > 200 ms
            tokens.Add(nextToken);

            CleanupTensors();

            if (eosTokens.Contains(nextToken))
                break;
        }

        string generatedText = tokenizer.Decode(tokens.ToArray(), true);
        Debug.Log($"Generated sequence: {generatedText}");
    }

    private int ProcessLogits()
    {
        // Greedy decoding: take the argmax over the vocabulary at every
        // position, then keep only the prediction for the last token.
        using var argMaxTensor = TensorInt.AllocNoData(new TensorShape(1, outputLogits.shape[1]));
        backend.ArgMax(outputLogits, argMaxTensor, axis: 2, selectLastIndex: false);

        // Synchronous GPU readback: this stalls until the compute finishes,
        // which is where most of the per-step latency above is spent.
        var argMaxTensorArray = argMaxTensor.ToReadOnlyArray();
        int nextToken = argMaxTensorArray[outputLogits.shape[1] - 1];

        Debug.Log($"<color=orange>Next token: [ID = {nextToken}, STR = \"{tokenizer.Decode(new[] { nextToken }, true)}\"]</color>");

        return nextToken;
    }

    private void RefreshTensors(int[] ids)
    {
        // Rebuild the input tensors for the full context: the token ids, an
        // all-ones attention mask, and position ids 0..n-1.
        inputTensor = new TensorInt(new TensorShape(1, ids.Length), ids);
        attentionMaskTensor = new TensorInt(new TensorShape(1, ids.Length), Enumerable.Repeat(1, ids.Length).ToArray());
        positionIdsTensor = new TensorInt(new TensorShape(1, ids.Length), Enumerable.Range(0, ids.Length).ToArray());
    }

    private void CleanupTensors()
    {
        inputTensor?.Dispose();
        attentionMaskTensor?.Dispose();
        positionIdsTensor?.Dispose();
        outputLogits?.Dispose();
    }

    private void OnDestroy()
    {
        CleanupTensors();

        worker?.Dispose();
        backend?.Dispose();
    }
}
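
// ---------------------------------------------------------------------------
// The helpers referenced above (Phi3InputFormatter, TokenizerUtils) live
// elsewhere in the project. Minimal sketches follow for reference; they are
// illustrative reconstructions, not the project's actual code. The chat
// template matches the published Phi-3 instruct format, and the JSON parsing
// assumes the Newtonsoft.Json package is installed.
// ---------------------------------------------------------------------------

public static class Phi3InputFormatter
{
    // Phi-3 chat template: each turn is tagged with a role token and closed
    // with <|end|>; the prompt ends with <|assistant|> so generation starts
    // with the model's reply.
    public static string FormatChatInput(string systemPrompt, string userPrompt)
    {
        return $"<|system|>\n{systemPrompt}<|end|>\n" +
               $"<|user|>\n{userPrompt}<|end|>\n" +
               "<|assistant|>\n";
    }
}

public static class TokenizerUtils
{
    // added_tokens.json is a flat { "token string": id } map.
    public static Dictionary<string, int> LoadSpecialTokens(string path)
    {
        string json = File.ReadAllText(path);
        return Newtonsoft.Json.JsonConvert.DeserializeObject<Dictionary<string, int>>(json);
    }

    // generation_config.json stores "eos_token_id" either as a single id or
    // as an array of ids (Phi-3.5 uses an array).
    public static List<int> IdentifyEOSTokens(string path)
    {
        var config = Newtonsoft.Json.Linq.JObject.Parse(File.ReadAllText(path));
        var eos = config["eos_token_id"];
        return eos is Newtonsoft.Json.Linq.JArray array
            ? array.Select(t => (int)t).ToList()
            : new List<int> { (int)eos };
    }
}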