Support for multiple reverse prompts. (#299)
Co-authored-by: Johnman <> Co-authored-by: Johnman <tjohnman@github>
This commit is contained in:
		
							parent
							
								
									7392f1cd2c
								
							
						
					
					
						commit
						24568371ae
					
				
							
								
								
									
										34
									
								
								main.cpp
								
								
								
								
							
							
						
						
									
										34
									
								
								main.cpp
								
								
								
								
							|  | @ -855,14 +855,18 @@ int main(int argc, char ** argv) { | |||
|     // in instruct mode, we inject a prefix and a suffix to each input by the user
 | ||||
|     if (params.instruct) { | ||||
|         params.interactive = true; | ||||
|         params.antiprompt = "### Instruction:\n\n"; | ||||
|         params.antiprompt.push_back("### Instruction:\n\n"); | ||||
|     } | ||||
| 
 | ||||
|     // tokenize the reverse prompt
 | ||||
|     std::vector<gpt_vocab::id> antiprompt_inp = ::llama_tokenize(vocab, params.antiprompt, false); | ||||
|     std::vector<std::vector<gpt_vocab::id>> antipromptv_inp; | ||||
|      | ||||
|     for (auto antiprompt : params.antiprompt) { | ||||
|         antipromptv_inp.push_back(::llama_tokenize(vocab, antiprompt, false)); | ||||
|     } | ||||
| 
 | ||||
|     // enable interactive mode if reverse prompt is specified
 | ||||
|     if (!antiprompt_inp.empty()) { | ||||
|     if (!antipromptv_inp.size()) { | ||||
|         params.interactive = true; | ||||
|     } | ||||
| 
 | ||||
|  | @ -886,13 +890,16 @@ int main(int argc, char ** argv) { | |||
| 
 | ||||
|         fprintf(stderr, "%s: interactive mode on.\n", __func__); | ||||
| 
 | ||||
|         if (antiprompt_inp.size()) { | ||||
|             fprintf(stderr, "%s: reverse prompt: '%s'\n", __func__, params.antiprompt.c_str()); | ||||
|             fprintf(stderr, "%s: number of tokens in reverse prompt = %zu\n", __func__, antiprompt_inp.size()); | ||||
|             for (int i = 0; i < (int) antiprompt_inp.size(); i++) { | ||||
|                 fprintf(stderr, "%6d -> '%s'\n", antiprompt_inp[i], vocab.id_to_token.at(antiprompt_inp[i]).c_str()); | ||||
|         if(antipromptv_inp.size()) { | ||||
|             for (size_t apindex = 0; apindex < antipromptv_inp.size(); ++apindex) { | ||||
|                 auto antiprompt_inp = antipromptv_inp.at(apindex); | ||||
|                 fprintf(stderr, "%s: reverse prompt: '%s'\n", __func__, params.antiprompt.at(apindex).c_str()); | ||||
|                 fprintf(stderr, "%s: number of tokens in reverse prompt = %zu\n", __func__, antiprompt_inp.size()); | ||||
|                 for (int i = 0; i < (int) antiprompt_inp.size(); i++) { | ||||
|                     fprintf(stderr, "%6d -> '%s'\n", antiprompt_inp[i], vocab.id_to_token.at(antiprompt_inp[i]).c_str()); | ||||
|                 } | ||||
|                 fprintf(stderr, "\n"); | ||||
|             } | ||||
|             fprintf(stderr, "\n"); | ||||
|         } | ||||
|     } | ||||
|     fprintf(stderr, "sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty); | ||||
|  | @ -1009,9 +1016,12 @@ int main(int argc, char ** argv) { | |||
|         // check if we should prompt the user for more
 | ||||
|         if (params.interactive && embd_inp.size() <= input_consumed) { | ||||
|             // check for reverse prompt
 | ||||
|             if (antiprompt_inp.size() && std::equal(antiprompt_inp.rbegin(), antiprompt_inp.rend(), last_n_tokens.rbegin())) { | ||||
|                 // reverse prompt found
 | ||||
|                 is_interacting = true; | ||||
|             for (auto antiprompt_inp : antipromptv_inp) { | ||||
|                 if (antiprompt_inp.size() && std::equal(antiprompt_inp.rbegin(), antiprompt_inp.rend(), last_n_tokens.rbegin())) { | ||||
|                     // reverse prompt found
 | ||||
|                     is_interacting = true; | ||||
|                     break; | ||||
|                 } | ||||
|             } | ||||
|             if (is_interacting) { | ||||
|                 if (params.instruct) { | ||||
|  |  | |||
|  | @ -70,7 +70,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { | |||
|         } else if (arg == "--color") { | ||||
|             params.use_color = true; | ||||
|         } else if (arg == "-r" || arg == "--reverse-prompt") { | ||||
|             params.antiprompt = argv[++i]; | ||||
|             params.antiprompt.push_back(argv[++i]); | ||||
|         } else if (arg == "--ignore-eos") { | ||||
|             params.ignore_eos = true; | ||||
|         } else if (arg == "-h" || arg == "--help") { | ||||
|  | @ -96,7 +96,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { | |||
|     fprintf(stderr, "  -i, --interactive     run in interactive mode\n"); | ||||
|     fprintf(stderr, "  -ins, --instruct      run in instruction mode (use with Alpaca models)\n"); | ||||
|     fprintf(stderr, "  -r PROMPT, --reverse-prompt PROMPT\n"); | ||||
|     fprintf(stderr, "                        in interactive mode, poll user input upon seeing PROMPT\n"); | ||||
|     fprintf(stderr, "                        in interactive mode, poll user input upon seeing PROMPT (can be\n"); | ||||
|     fprintf(stderr, "                        specified more than once for multiple prompts).\n"); | ||||
|     fprintf(stderr, "  --color               colorise output to distinguish prompt and user input from generations\n"); | ||||
|     fprintf(stderr, "  -s SEED, --seed SEED  RNG seed (default: -1)\n"); | ||||
|     fprintf(stderr, "  -t N, --threads N     number of threads to use during computation (default: %d)\n", params.n_threads); | ||||
|  |  | |||
							
								
								
									
										4
									
								
								utils.h
								
								
								
								
							
							
						
						
									
										4
									
								
								utils.h
								
								
								
								
							|  | @ -30,15 +30,15 @@ struct gpt_params { | |||
| 
 | ||||
|     std::string model      = "models/lamma-7B/ggml-model.bin"; // model path
 | ||||
|     std::string prompt     = ""; | ||||
|     std::string antiprompt = ""; // string upon seeing which more user input is prompted
 | ||||
| 
 | ||||
|     bool random_prompt = false; | ||||
| 
 | ||||
|     bool use_color = false; // use color to distinguish generations and inputs
 | ||||
| 
 | ||||
|     bool interactive = false; // interactive mode
 | ||||
|     bool interactive_start = false; // reverse prompt immediately
 | ||||
|     std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
 | ||||
|     bool instruct    = false; // instruction mode (used for Alpaca models)
 | ||||
|      | ||||
|     bool ignore_eos = false; // do not stop generating after eos
 | ||||
| }; | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue