Command line args bounds checking (#424)
* command line args bounds checking * unknown and invalid param exit codes 0 -> 1
This commit is contained in:
		
							parent
							
								
									a18c19259a
								
							
						
					
					
						commit
						ea10d3ded2
					
				
							
								
								
									
										101
									
								
								utils.cpp
								
								
								
								
							
							
						
						
									
										101
									
								
								utils.cpp
								
								
								
								
							|  | @ -26,41 +26,95 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { | |||
|         params.n_threads = std::max(1, (int32_t) std::thread::hardware_concurrency()); | ||||
|     } | ||||
| 
 | ||||
|     bool invalid_param = false; | ||||
|     std::string arg; | ||||
|     for (int i = 1; i < argc; i++) { | ||||
|         std::string arg = argv[i]; | ||||
|         arg = argv[i]; | ||||
| 
 | ||||
|         if (arg == "-s" || arg == "--seed") { | ||||
|             params.seed = std::stoi(argv[++i]); | ||||
|             if (++i >= argc) { | ||||
|                 invalid_param = true; | ||||
|                 break; | ||||
|             } | ||||
|             params.seed = std::stoi(argv[i]); | ||||
|         } else if (arg == "-t" || arg == "--threads") { | ||||
|             params.n_threads = std::stoi(argv[++i]); | ||||
|             if (++i >= argc) { | ||||
|                 invalid_param = true; | ||||
|                 break; | ||||
|             } | ||||
|             params.n_threads = std::stoi(argv[i]); | ||||
|         } else if (arg == "-p" || arg == "--prompt") { | ||||
|             params.prompt = argv[++i]; | ||||
|             if (++i >= argc) { | ||||
|                 invalid_param = true; | ||||
|                 break; | ||||
|             } | ||||
|             params.prompt = argv[i]; | ||||
|         } else if (arg == "-f" || arg == "--file") { | ||||
|             std::ifstream file(argv[++i]); | ||||
|             if (++i >= argc) { | ||||
|                 invalid_param = true; | ||||
|                 break; | ||||
|             } | ||||
|             std::ifstream file(argv[i]); | ||||
|             std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(params.prompt)); | ||||
|             if (params.prompt.back() == '\n') { | ||||
|                 params.prompt.pop_back(); | ||||
|             } | ||||
|         } else if (arg == "-n" || arg == "--n_predict") { | ||||
|             params.n_predict = std::stoi(argv[++i]); | ||||
|             if (++i >= argc) { | ||||
|                 invalid_param = true; | ||||
|                 break; | ||||
|             } | ||||
|             params.n_predict = std::stoi(argv[i]); | ||||
|         } else if (arg == "--top_k") { | ||||
|             params.top_k = std::stoi(argv[++i]); | ||||
|             if (++i >= argc) { | ||||
|                 invalid_param = true; | ||||
|                 break; | ||||
|             } | ||||
|             params.top_k = std::stoi(argv[i]); | ||||
|         } else if (arg == "-c" || arg == "--ctx_size") { | ||||
|             params.n_ctx = std::stoi(argv[++i]); | ||||
|             if (++i >= argc) { | ||||
|                 invalid_param = true; | ||||
|                 break; | ||||
|             } | ||||
|             params.n_ctx = std::stoi(argv[i]); | ||||
|         } else if (arg == "--memory_f16") { | ||||
|             params.memory_f16 = true; | ||||
|         } else if (arg == "--top_p") { | ||||
|             params.top_p = std::stof(argv[++i]); | ||||
|             if (++i >= argc) { | ||||
|                 invalid_param = true; | ||||
|                 break; | ||||
|             } | ||||
|             params.top_p = std::stof(argv[i]); | ||||
|         } else if (arg == "--temp") { | ||||
|             params.temp = std::stof(argv[++i]); | ||||
|             if (++i >= argc) { | ||||
|                 invalid_param = true; | ||||
|                 break; | ||||
|             } | ||||
|             params.temp = std::stof(argv[i]); | ||||
|         } else if (arg == "--repeat_last_n") { | ||||
|             params.repeat_last_n = std::stoi(argv[++i]); | ||||
|             if (++i >= argc) { | ||||
|                 invalid_param = true; | ||||
|                 break; | ||||
|             } | ||||
|             params.repeat_last_n = std::stoi(argv[i]); | ||||
|         } else if (arg == "--repeat_penalty") { | ||||
|             params.repeat_penalty = std::stof(argv[++i]); | ||||
|             if (++i >= argc) { | ||||
|                 invalid_param = true; | ||||
|                 break; | ||||
|             } | ||||
|             params.repeat_penalty = std::stof(argv[i]); | ||||
|         } else if (arg == "-b" || arg == "--batch_size") { | ||||
|             params.n_batch = std::stoi(argv[++i]); | ||||
|             if (++i >= argc) { | ||||
|                 invalid_param = true; | ||||
|                 break; | ||||
|             } | ||||
|             params.n_batch = std::stoi(argv[i]); | ||||
|         } else if (arg == "-m" || arg == "--model") { | ||||
|             params.model = argv[++i]; | ||||
|             if (++i >= argc) { | ||||
|                 invalid_param = true; | ||||
|                 break; | ||||
|             } | ||||
|             params.model = argv[i]; | ||||
|         } else if (arg == "-i" || arg == "--interactive") { | ||||
|             params.interactive = true; | ||||
|         } else if (arg == "--interactive-first") { | ||||
|  | @ -70,13 +124,21 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { | |||
|         } else if (arg == "--color") { | ||||
|             params.use_color = true; | ||||
|         } else if (arg == "-r" || arg == "--reverse-prompt") { | ||||
|             params.antiprompt.push_back(argv[++i]); | ||||
|             if (++i >= argc) { | ||||
|                 invalid_param = true; | ||||
|                 break; | ||||
|             } | ||||
|             params.antiprompt.push_back(argv[i]); | ||||
|         } else if (arg == "--perplexity") { | ||||
|             params.perplexity = true; | ||||
|         } else if (arg == "--ignore-eos") { | ||||
|             params.ignore_eos = true; | ||||
|         } else if (arg == "--n_parts") { | ||||
|             params.n_parts = std::stoi(argv[++i]); | ||||
|             if (++i >= argc) { | ||||
|                 invalid_param = true; | ||||
|                 break; | ||||
|             } | ||||
|             params.n_parts = std::stoi(argv[i]); | ||||
|         } else if (arg == "-h" || arg == "--help") { | ||||
|             gpt_print_usage(argc, argv, params); | ||||
|             exit(0); | ||||
|  | @ -85,9 +147,14 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { | |||
|         } else { | ||||
|             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); | ||||
|             gpt_print_usage(argc, argv, params); | ||||
|             exit(0); | ||||
|             exit(1); | ||||
|         } | ||||
|     } | ||||
|     if (invalid_param) { | ||||
|         fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str()); | ||||
|         gpt_print_usage(argc, argv, params); | ||||
|         exit(1); | ||||
|     } | ||||
| 
 | ||||
|     return true; | ||||
| } | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue