From ea10d3ded2994106596ddf8e4ed02741b3e053e6 Mon Sep 17 00:00:00 2001
From: anzz1 <anzz1@live.com>
Date: Thu, 23 Mar 2023 19:54:28 +0200
Subject: [PATCH] Command line args bounds checking (#424)

* command line args bounds checking

* unknown and invalid param exit codes 0 -> 1
---
 utils.cpp | 101 +++++++++++++++++++++++++++++++++++++++++++++---------
 1 file changed, 84 insertions(+), 17 deletions(-)

diff --git a/utils.cpp b/utils.cpp
index 1d5309c..45c9cab 100644
--- a/utils.cpp
+++ b/utils.cpp
@@ -26,41 +26,95 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
         params.n_threads = std::max(1, (int32_t) std::thread::hardware_concurrency());
     }
 
+    bool invalid_param = false;
+    std::string arg;
     for (int i = 1; i < argc; i++) {
-        std::string arg = argv[i];
+        arg = argv[i];
 
         if (arg == "-s" || arg == "--seed") {
-            params.seed = std::stoi(argv[++i]);
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.seed = std::stoi(argv[i]);
         } else if (arg == "-t" || arg == "--threads") {
-            params.n_threads = std::stoi(argv[++i]);
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.n_threads = std::stoi(argv[i]);
         } else if (arg == "-p" || arg == "--prompt") {
-            params.prompt = argv[++i];
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.prompt = argv[i];
         } else if (arg == "-f" || arg == "--file") {
-            std::ifstream file(argv[++i]);
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            std::ifstream file(argv[i]);
             std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(params.prompt));
             if (params.prompt.back() == '\n') {
                 params.prompt.pop_back();
             }
         } else if (arg == "-n" || arg == "--n_predict") {
-            params.n_predict = std::stoi(argv[++i]);
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.n_predict = std::stoi(argv[i]);
         } else if (arg == "--top_k") {
-            params.top_k = std::stoi(argv[++i]);
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.top_k = std::stoi(argv[i]);
         } else if (arg == "-c" || arg == "--ctx_size") {
-            params.n_ctx = std::stoi(argv[++i]);
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.n_ctx = std::stoi(argv[i]);
         } else if (arg == "--memory_f16") {
             params.memory_f16 = true;
         } else if (arg == "--top_p") {
-            params.top_p = std::stof(argv[++i]);
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.top_p = std::stof(argv[i]);
         } else if (arg == "--temp") {
-            params.temp = std::stof(argv[++i]);
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.temp = std::stof(argv[i]);
         } else if (arg == "--repeat_last_n") {
-            params.repeat_last_n = std::stoi(argv[++i]);
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.repeat_last_n = std::stoi(argv[i]);
         } else if (arg == "--repeat_penalty") {
-            params.repeat_penalty = std::stof(argv[++i]);
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.repeat_penalty = std::stof(argv[i]);
         } else if (arg == "-b" || arg == "--batch_size") {
-            params.n_batch = std::stoi(argv[++i]);
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.n_batch = std::stoi(argv[i]);
         } else if (arg == "-m" || arg == "--model") {
-            params.model = argv[++i];
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.model = argv[i];
         } else if (arg == "-i" || arg == "--interactive") {
             params.interactive = true;
         } else if (arg == "--interactive-first") {
@@ -70,13 +124,21 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
         } else if (arg == "--color") {
             params.use_color = true;
         } else if (arg == "-r" || arg == "--reverse-prompt") {
-            params.antiprompt.push_back(argv[++i]);
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.antiprompt.push_back(argv[i]);
         } else if (arg == "--perplexity") {
             params.perplexity = true;
         } else if (arg == "--ignore-eos") {
             params.ignore_eos = true;
         } else if (arg == "--n_parts") {
-            params.n_parts = std::stoi(argv[++i]);
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.n_parts = std::stoi(argv[i]);
         } else if (arg == "-h" || arg == "--help") {
             gpt_print_usage(argc, argv, params);
             exit(0);
@@ -85,9 +147,14 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
         } else {
             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
             gpt_print_usage(argc, argv, params);
-            exit(0);
+            exit(1);
         }
     }
+    if (invalid_param) {
+        fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str());
+        gpt_print_usage(argc, argv, params);
+        exit(1);
+    }
 
     return true;
 }