Add repetition penalty (#20)
* Adding repeat penalization * Update utils.h * Update utils.cpp * Numeric fix Should probably still scale by temp even if penalized * Update comments, more proper application I see that numbers can go negative so a fix from a referenced commit * Minor formatting --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
		
							parent
							
								
									702fddf5c5
								
							
						
					
					
						commit
						129c7d1ea8
					
				
							
								
								
									
										14
									
								
								main.cpp
								
								
								
								
							
							
						
						
									
										14
									
								
								main.cpp
								
								
								
								
							|  | @ -792,7 +792,7 @@ int main(int argc, char ** argv) { | |||
|         printf("%6d -> '%s'\n", embd_inp[i], vocab.id_to_token.at(embd_inp[i]).c_str()); | ||||
|     } | ||||
|     printf("\n"); | ||||
|     printf("sampling parameters: temp = %f, top_k = %d, top_p = %f\n", params.temp, params.top_k, params.top_p); | ||||
|     printf("sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty); | ||||
|     printf("\n\n"); | ||||
| 
 | ||||
|     std::vector<gpt_vocab::id> embd; | ||||
|  | @ -801,6 +801,10 @@ int main(int argc, char ** argv) { | |||
|     size_t mem_per_token = 0; | ||||
|     llama_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token); | ||||
| 
 | ||||
|     int last_n_size = params.repeat_last_n; | ||||
|     std::vector<gpt_vocab::id> last_n_tokens(last_n_size); | ||||
|     std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0); | ||||
| 
 | ||||
|     for (int i = embd.size(); i < embd_inp.size() + params.n_predict; i++) { | ||||
|         // predict
 | ||||
|         if (embd.size() > 0) { | ||||
|  | @ -821,6 +825,7 @@ int main(int argc, char ** argv) { | |||
|             // sample next token
 | ||||
|             const float top_p = params.top_p; | ||||
|             const float temp  = params.temp; | ||||
|             const float repeat_penalty = params.repeat_penalty; | ||||
| 
 | ||||
|             const int n_vocab = model.hparams.n_vocab; | ||||
| 
 | ||||
|  | @ -829,7 +834,10 @@ int main(int argc, char ** argv) { | |||
|             { | ||||
|                 const int64_t t_start_sample_us = ggml_time_us(); | ||||
| 
 | ||||
|                 id = llama_sample_top_p(vocab, logits.data() + (logits.size() - n_vocab), top_p, temp, rng); | ||||
|                 id = llama_sample_top_p(vocab, logits.data() + (logits.size() - n_vocab), last_n_tokens, repeat_penalty, top_p, temp, rng); | ||||
| 
 | ||||
|                 last_n_tokens.erase(last_n_tokens.begin()); | ||||
|                 last_n_tokens.push_back(id); | ||||
| 
 | ||||
|                 t_sample_us += ggml_time_us() - t_start_sample_us; | ||||
|             } | ||||
|  | @ -840,6 +848,8 @@ int main(int argc, char ** argv) { | |||
|             // if here, it means we are still processing the input prompt
 | ||||
|             for (int k = i; k < embd_inp.size(); k++) { | ||||
|                 embd.push_back(embd_inp[k]); | ||||
|                 last_n_tokens.erase(last_n_tokens.begin()); | ||||
|                 last_n_tokens.push_back(embd_inp[k]); | ||||
|                 if (embd.size() > params.n_batch) { | ||||
|                     break; | ||||
|                 } | ||||
|  |  | |||
							
								
								
									
										21
									
								
								utils.cpp
								
								
								
								
							
							
						
						
									
										21
									
								
								utils.cpp
								
								
								
								
							|  | @ -23,6 +23,10 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { | |||
|             params.top_p = std::stof(argv[++i]); | ||||
|         } else if (arg == "--temp") { | ||||
|             params.temp = std::stof(argv[++i]); | ||||
|         } else if (arg == "--repeat_last_n") { | ||||
|             params.repeat_last_n = std::stoi(argv[++i]); | ||||
|         } else if (arg == "--repeat_penalty") { | ||||
|             params.repeat_penalty = std::stof(argv[++i]); | ||||
|         } else if (arg == "-b" || arg == "--batch_size") { | ||||
|             params.n_batch = std::stoi(argv[++i]); | ||||
|         } else if (arg == "-m" || arg == "--model") { | ||||
|  | @ -52,6 +56,8 @@ void gpt_print_usage(int argc, char ** argv, const gpt_params & params) { | |||
|     fprintf(stderr, "  -n N, --n_predict N   number of tokens to predict (default: %d)\n", params.n_predict); | ||||
|     fprintf(stderr, "  --top_k N             top-k sampling (default: %d)\n", params.top_k); | ||||
|     fprintf(stderr, "  --top_p N             top-p sampling (default: %.1f)\n", params.top_p); | ||||
|     fprintf(stderr, "  --repeat_last_n N     last n tokens to consider for penalize (default: %d)\n", params.repeat_last_n); | ||||
|     fprintf(stderr, "  --repeat_penalty N    penalize repeat sequence of tokens (default: %.1f)\n", params.repeat_penalty); | ||||
|     fprintf(stderr, "  --temp N              temperature (default: %.1f)\n", params.temp); | ||||
|     fprintf(stderr, "  -b N, --batch_size N  batch size for prompt processing (default: %d)\n", params.n_batch); | ||||
|     fprintf(stderr, "  -m FNAME, --model FNAME\n"); | ||||
|  | @ -372,6 +378,8 @@ gpt_vocab::id gpt_sample_top_k_top_p( | |||
| gpt_vocab::id llama_sample_top_p( | ||||
|         const gpt_vocab & vocab, | ||||
|         const float * logits, | ||||
|         std::vector<gpt_vocab::id> & last_n_tokens, | ||||
|         double repeat_penalty, | ||||
|         double top_p, | ||||
|         double temp, | ||||
|         std::mt19937 & rng) { | ||||
|  | @ -383,7 +391,18 @@ gpt_vocab::id llama_sample_top_p( | |||
|     { | ||||
|         const double scale = 1.0/temp; | ||||
|         for (int i = 0; i < n_logits; ++i) { | ||||
|             logits_id.push_back(std::make_pair(logits[i]*scale, i)); | ||||
|             // repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
 | ||||
|             // credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main
 | ||||
|             if (std::find(last_n_tokens.begin(), last_n_tokens.end(), i) != last_n_tokens.end()) { | ||||
|                 // if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
 | ||||
|                 if (logits[i] < 0.0) { | ||||
|                     logits_id.push_back(std::make_pair(logits[i]*scale*repeat_penalty, i)); | ||||
|                 } else { | ||||
|                     logits_id.push_back(std::make_pair(logits[i]*scale/repeat_penalty, i)); | ||||
|                 }                 | ||||
|             } else { | ||||
|                 logits_id.push_back(std::make_pair(logits[i]*scale, i)); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|  |  | |||
							
								
								
									
										4
									
								
								utils.h
								
								
								
								
							
							
						
						
									
										4
									
								
								utils.h
								
								
								
								
							|  | @ -16,11 +16,13 @@ struct gpt_params { | |||
|     int32_t seed      = -1; // RNG seed
 | ||||
|     int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); | ||||
|     int32_t n_predict = 128; // new tokens to predict
 | ||||
|     int32_t repeat_last_n = 64;  // last n tokens to penalize
 | ||||
| 
 | ||||
|     // sampling parameters
 | ||||
|     int32_t top_k = 40; // unused
 | ||||
|     float   top_p = 0.95f; | ||||
|     float   temp  = 0.80f; | ||||
|     float   repeat_penalty  = 1.30f; | ||||
| 
 | ||||
|     int32_t n_batch = 8; // batch size for prompt processing
 | ||||
| 
 | ||||
|  | @ -89,6 +91,8 @@ gpt_vocab::id gpt_sample_top_k_top_p( | |||
| gpt_vocab::id llama_sample_top_p( | ||||
|         const gpt_vocab & vocab, | ||||
|         const float * logits, | ||||
|         std::vector<gpt_vocab::id> & last_n_tokens, | ||||
|         double repeat_penalty, | ||||
|         double top_p, | ||||
|         double temp, | ||||
|         std::mt19937 & rng); | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue