Add interactive mode (#61)
* Initial work on interactive mode.
* Improve interactive mode; make the reverse prompt optional.
* Update README to explain interactive mode.
* Fix OS X build.
commit 96ea727f47
parent 9661954835

README.md | 23
@@ -183,6 +183,29 @@ The number of files generated for each model is as follows:
 
 When running the larger models, make sure you have enough disk space to store all the intermediate files.
 
+### Interactive mode
+
+If you want a more ChatGPT-like experience, you can run in interactive mode by passing `-i` as a parameter.
+In this mode, you can always interrupt generation by pressing Ctrl+C and entering one or more lines of text, which will be converted into tokens and appended to the current context. You can also specify a *reverse prompt* with the parameter `-r "reverse prompt string"`. This will result in user input being prompted whenever the exact tokens of the reverse prompt string are encountered in the generation. A typical use is a prompt that makes LLaMa emulate a chat between multiple users, say Alice and Bob, combined with `-r "Alice:"`.
+
+Here is an example few-shot interaction, invoked with the command
+```
+./main -m ./models/13B/ggml-model-q4_0.bin -t 8 --repeat_penalty 1.2 --temp 0.9 --top_p 0.9 -n 256 \
+                                           --color -i -r "User:" \
+                                           -p \
+"Transcript of a dialog, where the User interacts with an Assistant named Bob. Bob is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.
+
+User: Hello, Bob.
+Bob: Hello. How may I help you today?
+User: Please tell me the largest city in Europe.
+Bob: Sure. The largest city in Europe is London, the capital of the United Kingdom.
+User:"
+```
+Note the use of `--color` to distinguish between user input and generated text.
+
+
+
+
 ## Limitations
 
 - Not sure if my tokenizer is correct. There are a few places where we might have a mistake:
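A note on the "exact tokens" wording above: the reverse prompt is matched on token ids, not on raw text, by comparing the tail of the recently generated tokens against the tokenized reverse prompt. Below is a minimal standalone sketch of that suffix check; the token values are made up for illustration, and it mirrors the `std::equal` call added to main.cpp further down, plus a size guard that the original omits because its token window is always large enough.

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
    // hypothetical token ids standing in for a tokenized reverse prompt, e.g. "User:"
    std::vector<int> antiprompt_inp = {612, 29};
    // recent output tokens, newest at the back, as in main.cpp's last_n_tokens
    std::vector<int> last_n_tokens = {88, 7, 612, 29};

    // fire only when the *most recent* tokens spell out the reverse prompt
    // exactly, so compare both sequences back to front
    bool found = !antiprompt_inp.empty() &&
                 last_n_tokens.size() >= antiprompt_inp.size() &&
                 std::equal(antiprompt_inp.rbegin(), antiprompt_inp.rend(),
                            last_n_tokens.rbegin());

    printf("reverse prompt %s\n", found ? "found" : "not found");
    return 0;
}
```

One consequence: if generation produces a different tokenization of the same visible string, the reverse prompt will not trigger, which is why the README stresses the *exact tokens*.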
main.cpp | 137
@@ -11,6 +11,18 @@
 #include <string>
 #include <vector>
 
+#include <signal.h>
+#include <unistd.h>
+
+#define ANSI_COLOR_RED     "\x1b[31m"
+#define ANSI_COLOR_GREEN   "\x1b[32m"
+#define ANSI_COLOR_YELLOW  "\x1b[33m"
+#define ANSI_COLOR_BLUE    "\x1b[34m"
+#define ANSI_COLOR_MAGENTA "\x1b[35m"
+#define ANSI_COLOR_CYAN    "\x1b[36m"
+#define ANSI_COLOR_RESET   "\x1b[0m"
+#define ANSI_BOLD          "\x1b[1m"
+
 // determine number of model parts based on the dimension
 static const std::map<int, int> LLAMA_N_PARTS = {
     { 4096, 1 },
@@ -733,6 +745,18 @@ bool llama_eval(
     return true;
 }
 
+static bool is_interacting = false;
+
+void sigint_handler(int signo) {
+    if (signo == SIGINT) {
+        if (!is_interacting) {
+            is_interacting = true;
+        } else {
+            _exit(130);
+        }
+    }
+}
+
 int main(int argc, char ** argv) {
     ggml_time_init();
     const int64_t t_main_start_us = ggml_time_us();
@@ -787,6 +811,9 @@ int main(int argc, char ** argv) {
 
     params.n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) embd_inp.size());
 
+    // tokenize the reverse prompt
+    std::vector<gpt_vocab::id> antiprompt_inp = ::llama_tokenize(vocab, params.antiprompt, false);
+
     printf("\n");
     printf("%s: prompt: '%s'\n", __func__, params.prompt.c_str());
     printf("%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
@@ -794,6 +821,24 @@ int main(int argc, char ** argv) {
         printf("%6d -> '%s'\n", embd_inp[i], vocab.id_to_token.at(embd_inp[i]).c_str());
     }
     printf("\n");
+    if (params.interactive) {
+        struct sigaction sigint_action;
+        sigint_action.sa_handler = sigint_handler;
+        sigemptyset(&sigint_action.sa_mask);
+        sigint_action.sa_flags = 0;
+        sigaction(SIGINT, &sigint_action, NULL);
+
+        printf("%s: interactive mode on.\n", __func__);
+
+        if (antiprompt_inp.size()) {
+            printf("%s: reverse prompt: '%s'\n", __func__, params.antiprompt.c_str());
+            printf("%s: number of tokens in reverse prompt = %zu\n", __func__, antiprompt_inp.size());
+            for (int i = 0; i < (int) antiprompt_inp.size(); i++) {
+                printf("%6d -> '%s'\n", antiprompt_inp[i], vocab.id_to_token.at(antiprompt_inp[i]).c_str());
+            }
+            printf("\n");
+        }
+    }
     printf("sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
     printf("\n\n");
 
@@ -807,7 +852,28 @@ int main(int argc, char ** argv) {
     std::vector<gpt_vocab::id> last_n_tokens(last_n_size);
     std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
 
-    for (int i = embd.size(); i < embd_inp.size() + params.n_predict; i++) {
+
+    if (params.interactive) {
+        printf("== Running in interactive mode. ==\n"
+               " - Press Ctrl+C to interject at any time.\n"
+               " - Press Return to return control to LLaMa.\n"
+               " - If you want to submit another line, end your input in '\\'.\n");
+    }
+
+    int remaining_tokens = params.n_predict;
+    int input_consumed = 0;
+    bool input_noecho = false;
+
+    // prompt user immediately after the starting prompt has been loaded
+    if (params.interactive_start) {
+        is_interacting = true;
+    }
+
+    if (params.use_color) {
+        printf(ANSI_COLOR_YELLOW);
+    }
+
+    while (remaining_tokens > 0) {
         // predict
         if (embd.size() > 0) {
             const int64_t t_start_us = ggml_time_us();
@@ -823,8 +889,8 @@ int main(int argc, char ** argv) {
         n_past += embd.size();
         embd.clear();
 
-        if (i >= embd_inp.size()) {
-            // sample next token
+        if (embd_inp.size() <= input_consumed) {
+            // out of input, sample next token
             const float top_k = params.top_k;
             const float top_p = params.top_p;
             const float temp  = params.temp;
@@ -847,24 +913,74 @@ int main(int argc, char ** argv) {
 
             // add it to the context
             embd.push_back(id);
 
+            // echo this to console
+            input_noecho = false;
+
+            // decrement remaining sampling budget
+            --remaining_tokens;
         } else {
             // if here, it means we are still processing the input prompt
-            for (int k = i; k < embd_inp.size(); k++) {
-                embd.push_back(embd_inp[k]);
+            while (embd_inp.size() > input_consumed) {
+                embd.push_back(embd_inp[input_consumed]);
                 last_n_tokens.erase(last_n_tokens.begin());
-                last_n_tokens.push_back(embd_inp[k]);
+                last_n_tokens.push_back(embd_inp[input_consumed]);
+                ++input_consumed;
                 if (embd.size() > params.n_batch) {
                     break;
                 }
             }
-            i += embd.size() - 1;
+
+            if (params.use_color && embd_inp.size() <= input_consumed) {
+                printf(ANSI_COLOR_RESET);
+            }
         }
 
         // display text
-        for (auto id : embd) {
-            printf("%s", vocab.id_to_token[id].c_str());
+        if (!input_noecho) {
+            for (auto id : embd) {
+                printf("%s", vocab.id_to_token[id].c_str());
+            }
+            fflush(stdout);
         }
+
+        // in interactive mode, and not currently processing queued inputs;
+        // check if we should prompt the user for more
+        if (params.interactive && embd_inp.size() <= input_consumed) {
+            // check for reverse prompt
+            if (antiprompt_inp.size() && std::equal(antiprompt_inp.rbegin(), antiprompt_inp.rend(), last_n_tokens.rbegin())) {
+                // reverse prompt found
+                is_interacting = true;
+            }
+            if (is_interacting) {
+                // currently being interactive
+                bool another_line = true;
+                while (another_line) {
+                    char buf[256] = {0};
+                    int n_read;
+                    if (params.use_color) printf(ANSI_BOLD ANSI_COLOR_GREEN);
+                    scanf("%255[^\n]%n%*c", buf, &n_read);
+                    if (params.use_color) printf(ANSI_COLOR_RESET);
+
+                    if (n_read > 0 && buf[n_read-1] == '\\') {
+                        another_line = true;
+                        buf[n_read-1] = '\n';
+                        buf[n_read] = 0;
+                    } else {
+                        another_line = false;
+                        buf[n_read] = '\n';
+                        buf[n_read+1] = 0;
+                    }
+
+                    std::vector<gpt_vocab::id> line_inp = ::llama_tokenize(vocab, buf, false);
+                    embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
+
+                    input_noecho = true; // do not echo this again
+                }
+
+                is_interacting = false;
+            }
+        }
         fflush(stdout);
 
         // end of text token
         if (embd.back() == 2) {
@@ -873,6 +989,7 @@ int main(int argc, char ** argv) {
         }
     }
 
+
     // report timing
     {
         const int64_t t_main_end_us = ggml_time_us();
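Two details of the loop above are worth spelling out. First, the SIGINT handler gives Ctrl+C two stages: the first press only sets `is_interacting`, so the main loop pauses for input at the next opportunity, while a second press calls `_exit(130)` (the conventional 128 + SIGINT exit status). Second, user input is read line by line with `scanf("%255[^\n]%n%*c", ...)`, where a trailing backslash marks a continuation line. Here is a rough standalone sketch of that input convention, not the project's code: it uses a slightly larger buffer and an explicit EOF/empty-line guard, whereas the original leaves `n_read` uninitialized when `scanf` matches nothing.

```cpp
#include <cstdio>
#include <string>

int main() {
    std::string input;
    bool another_line = true;
    while (another_line) {
        char buf[258] = {0}; // room for up to 255 chars + '\n' + '\0'
        int n_read = 0;
        // %255[^\n] reads one line, %n records its length, %*c consumes the '\n'
        if (scanf("%255[^\n]%n%*c", buf, &n_read) <= 0) {
            break; // EOF or an empty line: stop reading
        }
        if (n_read > 0 && buf[n_read - 1] == '\\') {
            buf[n_read - 1] = '\n'; // trailing '\' continues on the next line
        } else {
            buf[n_read] = '\n'; // keep a newline so tokenization sees the line break
            another_line = false;
        }
        input += buf;
    }
    printf("collected:\n%s", input.c_str());
    return 0;
}
```

In main.cpp the collected text is then tokenized with `llama_tokenize(vocab, buf, false)` and appended to `embd_inp`, so the model consumes it exactly like prompt tokens.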
utils.cpp | 14
@@ -49,6 +49,15 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             params.n_batch = std::stoi(argv[++i]);
         } else if (arg == "-m" || arg == "--model") {
             params.model = argv[++i];
+        } else if (arg == "-i" || arg == "--interactive") {
+            params.interactive = true;
+        } else if (arg == "--interactive-start") {
+            params.interactive = true;
+            params.interactive_start = true;
+        } else if (arg == "--color") {
+            params.use_color = true;
+        } else if (arg == "-r" || arg == "--reverse-prompt") {
+            params.antiprompt = argv[++i];
         } else if (arg == "-h" || arg == "--help") {
             gpt_print_usage(argc, argv, params);
             exit(0);
@@ -67,6 +76,11 @@ void gpt_print_usage(int argc, char ** argv, const gpt_params & params) {
     fprintf(stderr, "\n");
     fprintf(stderr, "options:\n");
     fprintf(stderr, "  -h, --help            show this help message and exit\n");
+    fprintf(stderr, "  -i, --interactive     run in interactive mode\n");
+    fprintf(stderr, "  --interactive-start   run in interactive mode and poll user input at startup\n");
+    fprintf(stderr, "  -r PROMPT, --reverse-prompt PROMPT\n");
+    fprintf(stderr, "                        in interactive mode, poll user input upon seeing PROMPT\n");
+    fprintf(stderr, "  --color               colorise output to distinguish prompt and user input from generations\n");
     fprintf(stderr, "  -s SEED, --seed SEED  RNG seed (default: -1)\n");
     fprintf(stderr, "  -t N, --threads N     number of threads to use during computation (default: %d)\n", params.n_threads);
     fprintf(stderr, "  -p PROMPT, --prompt PROMPT\n");
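The parser consumes the next `argv` entry for valued options via `argv[++i]`, matching the file's existing style. Below is a hypothetical, self-contained sketch of the same pattern; the struct and function names are illustrative, not the project's, and it adds a bounds check that `argv[++i]` itself does not perform.

```cpp
#include <string>

struct cli_opts {
    bool interactive       = false;
    bool interactive_start = false;
    bool use_color         = false;
    std::string antiprompt; // reverse prompt; empty means disabled
};

// parse only the options added in this commit; other args are ignored here
cli_opts parse_interactive_opts(int argc, char ** argv) {
    cli_opts o;
    for (int i = 1; i < argc; i++) {
        std::string arg = argv[i];
        if (arg == "-i" || arg == "--interactive") {
            o.interactive = true;
        } else if (arg == "--interactive-start") {
            o.interactive = true; // --interactive-start implies interactive mode
            o.interactive_start = true;
        } else if (arg == "--color") {
            o.use_color = true;
        } else if ((arg == "-r" || arg == "--reverse-prompt") && i + 1 < argc) {
            o.antiprompt = argv[++i]; // valued option: take the next argument
        }
    }
    return o;
}
```

Note that `--interactive-start` sets both flags, so `-i` never needs to be passed alongside it.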
utils.h | 6
@@ -28,6 +28,12 @@ struct gpt_params {
 
     std::string model = "models/lamma-7B/ggml-model.bin"; // model path
     std::string prompt;
+
+    bool use_color = false; // use color to distinguish generations and inputs
+
+    bool interactive = false; // interactive mode
+    bool interactive_start = false; // prompt for user input immediately at startup
+    std::string antiprompt = ""; // string upon seeing which more user input is prompted
 };
 
 bool gpt_params_parse(int argc, char ** argv, gpt_params & params);