using UnityEngine;
using Microsoft.ML.Tokenizers;
using Unity.Sentis;
using System.IO;
using System.Linq;
using System.Collections.Generic;
using System.Collections;

/// <summary>
/// Runs greedy text generation with a Sentis model and a Llama tokenizer, both loaded
/// from the <c>Phi35</c> folder under <see cref="Application.streamingAssetsPath"/>.
/// Generation runs in a coroutine; every step feeds the full token context back
/// into the model and appends the arg-max token of the last position.
/// </summary>
public class Phi3Claude : MonoBehaviour
{
    IWorker worker;
    LlamaTokenizer tokenizer;

    // Prompt tokens followed by every token generated so far.
    List<int> tokens = new();

    TensorInt inputTensor, attentionMaskTensor, positionIdsTensor;

    // Reference to the worker's "logits" output obtained through PeekOutput.
    TensorFloat outputLogits;

    int maxTokens = 100; // Maximum number of tokens to generate
    List<int> eosTokens; // End of sequence tokens
    private IBackend backend;

    // Handle of the running generation coroutine, or null when idle.
    private Coroutine generationRoutine;

    /// <summary>
    /// Loads the model, the tokenizer and the end-of-sequence token list, creates the
    /// GPU worker/backend, then starts a first generation with a fixed prompt.
    /// </summary>
    private void Start()
    {
        var tokenizerModelPath = Path.Combine(Application.streamingAssetsPath, "Phi35/tokenizer.model");
        var sentisModelPath = Path.Combine(Application.streamingAssetsPath, "Phi35/model_Uint8.sentis");
        var configPath = Path.Combine(Application.streamingAssetsPath, "Phi35/generation_config.json");

        var model = ModelLoader.Load(sentisModelPath);
        worker = WorkerFactory.CreateWorker(BackendType.GPUCompute, model);

        var specialTokens = TokenizerUtils.LoadSpecialTokens(
            Path.Combine(Application.streamingAssetsPath, "Phi35/added_tokens.json"));

        // The stream is only needed while the tokenizer is being created.
        using (Stream tokenizerModelStream = new FileStream(tokenizerModelPath, FileMode.Open, FileAccess.Read))
        {
            tokenizer = LlamaTokenizer.Create(
                tokenizerModelStream,
                addBeginOfSentence: true,
                addEndOfSentence: false,
                specialTokens: specialTokens
            );
        }

        eosTokens = TokenizerUtils.IdentifyEOSTokens(configPath);
        backend = WorkerFactory.CreateBackend(BackendType.GPUCompute);

        Generate("Hello, how is your day?");
    }

    /// <summary>
    /// Formats and tokenizes the prompt, resets the token context and starts the
    /// generation coroutine. A generation that is still running is stopped first,
    /// because it shares the token list and the tensor fields with the new one.
    /// </summary>
    /// <param name="userPrompt">User message passed to the chat formatter.</param>
    /// <param name="systemPrompt">System message passed to the chat formatter.</param>
    public void Generate(string userPrompt, string systemPrompt = "You are a helpful assistant.")
    {
        if (generationRoutine != null)
        {
            Debug.LogWarning("Generate called while a generation was running; the previous one is cancelled.");
            StopCoroutine(generationRoutine);
            generationRoutine = null;
            CleanupTensors();
        }

        string completePrompt = Phi3InputFormatter.FormatChatInput(systemPrompt, userPrompt);
        Debug.Log("Complete prompt : " + completePrompt);

        int[] inputIds = tokenizer.EncodeToIds(completePrompt).ToArray();
        Debug.Log($"Tokenized input: [{string.Join(", ", inputIds)}]");
        Debug.Log($"Decoded tokens: [{string.Join(", ", tokenizer.Decode(inputIds, true))}]");

        tokens.Clear();
        tokens.AddRange(inputIds);

        generationRoutine = StartCoroutine(GenerateSequence());
    }

    /// <summary>
    /// Generates up to <c>maxTokens</c> tokens, one model execution per token, and
    /// logs the decoded sequence. Stops early when an end-of-sequence token is produced.
    /// </summary>
    private IEnumerator GenerateSequence()
    {
        for (int i = 0; i < maxTokens; i++)
        {
            RefreshTensors(tokens.ToArray());

            worker.Execute(new Dictionary<string, Tensor>
            {
                { "input_ids", inputTensor },
                { "attention_mask", attentionMaskTensor },
                { "position_ids", positionIdsTensor }
            });

            outputLogits = worker.PeekOutput("logits") as TensorFloat;
            if (outputLogits == null)
            {
                Debug.LogError("Model output \"logits\" is missing or is not a TensorFloat.");
                CleanupTensors();
                break;
            }

            outputLogits.ReadbackRequest();

            // Poll once per frame until the readback has completed. Yielding the bool
            // itself would only skip a single frame regardless of its value.
            do
            {
                yield return null;
            }
            while (!outputLogits.IsReadbackRequestDone());

            int nextToken = ProcessLogits();
            tokens.Add(nextToken);
            CleanupTensors();

            if (eosTokens.Contains(nextToken))
            {
                break;
            }
        }

        string generatedText = tokenizer.Decode(tokens.ToArray(), true);
        Debug.Log($"Generated sequence: {generatedText}");

        generationRoutine = null;
    }

    /// <summary>
    /// Greedy sampling: computes the arg-max along axis 2 of the current logits and
    /// returns the entry at the last index of dimension 1.
    /// </summary>
    /// <returns>The id of the selected next token.</returns>
    private int ProcessLogits()
    {
        int sequenceLength = outputLogits.shape[1];

        using var argMaxTensor = TensorInt.AllocNoData(new TensorShape(1, sequenceLength));
        backend.ArgMax(outputLogits, argMaxTensor, axis: 2, selectLastIndex: false);

        // NOTE(review): the original author measured this read as slow; presumably it
        // blocks until the backend work has finished - confirm with a profiler.
        var argMaxTensorArray = argMaxTensor.ToReadOnlyArray();

        int nextToken = argMaxTensorArray[sequenceLength - 1];
        Debug.Log($"Next token: [ID = {nextToken}, STR = \"{tokenizer.Decode(new[] { nextToken }, true)}\"]");
        return nextToken;
    }

    /// <summary>
    /// Rebuilds the three input tensors from the full token context: the ids themselves,
    /// an attention mask of ones, and position ids 0..Length-1, each with shape (1, Length).
    /// </summary>
    /// <param name="ids">Every token id of the current context.</param>
    private void RefreshTensors(int[] ids)
    {
        var shape = new TensorShape(1, ids.Length);

        inputTensor = new TensorInt(shape, ids);
        attentionMaskTensor = new TensorInt(shape, Enumerable.Repeat(1, ids.Length).ToArray());
        positionIdsTensor = new TensorInt(shape, Enumerable.Range(0, ids.Length).ToArray());
    }

    /// <summary>
    /// Disposes the input tensors created by <see cref="RefreshTensors"/> and clears all
    /// tensor references so that a later call never disposes the same tensor twice.
    /// </summary>
    private void CleanupTensors()
    {
        inputTensor?.Dispose();
        inputTensor = null;

        attentionMaskTensor?.Dispose();
        attentionMaskTensor = null;

        positionIdsTensor?.Dispose();
        positionIdsTensor = null;

        // The logits tensor was obtained with PeekOutput, so it stays owned by the
        // worker and is released with it; only the reference is dropped here.
        outputLogits = null;
    }

    /// <summary>Releases the tensors, the worker and the backend.</summary>
    private void OnDestroy()
    {
        CleanupTensors();
        worker?.Dispose();
        backend?.Dispose();
    }
}