diff --git a/main.cpp b/main.cpp
index a11d755..0155614 100644
--- a/main.cpp
+++ b/main.cpp
@@ -825,6 +825,7 @@ int main(int argc, char ** argv) {
 
         if (i >= embd_inp.size()) {
             // sample next token
+            const float top_k = params.top_k;
             const float top_p = params.top_p;
             const float temp  = params.temp;
             const float repeat_penalty = params.repeat_penalty;
@@ -836,7 +837,7 @@ int main(int argc, char ** argv) {
             {
                 const int64_t t_start_sample_us = ggml_time_us();
 
-                id = llama_sample_top_p(vocab, logits.data() + (logits.size() - n_vocab), last_n_tokens, repeat_penalty, top_p, temp, rng);
+                id = llama_sample_top_p_top_k(vocab, logits.data() + (logits.size() - n_vocab), last_n_tokens, repeat_penalty, top_k, top_p, temp, rng);
 
                 last_n_tokens.erase(last_n_tokens.begin());
                 last_n_tokens.push_back(id);
diff --git a/utils.cpp b/utils.cpp
index 58e7070..5435d47 100644
--- a/utils.cpp
+++ b/utils.cpp
@@ -301,25 +301,8 @@ bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab) {
     return true;
 }
 
-gpt_vocab::id gpt_sample_top_k_top_p(
-        const gpt_vocab & vocab,
-        const float * logits,
-        int    top_k,
-        double top_p,
-        double temp,
-        std::mt19937 & rng) {
-    int n_logits = vocab.id_to_token.size();
-
-    std::vector<std::pair<double, gpt_vocab::id>> logits_id;
-    logits_id.reserve(n_logits);
-
-    {
-        const double scale = 1.0/temp;
-        for (int i = 0; i < n_logits; ++i) {
-            logits_id.push_back(std::make_pair(logits[i]*scale, i));
-        }
-    }
 
+void sample_top_k(std::vector<std::pair<double, gpt_vocab::id>> & logits_id, int top_k) {
     // find the top K tokens
     std::partial_sort(
             logits_id.begin(),
@@ -329,63 +312,14 @@ gpt_vocab::id gpt_sample_top_k_top_p(
     });
 
     logits_id.resize(top_k);
-
-    double maxl = -INFINITY;
-    for (const auto & kv : logits_id) {
-        maxl = std::max(maxl, kv.first);
-    }
-
-    // compute probs for the top K tokens
-    std::vector<double> probs;
-    probs.reserve(logits_id.size());
-
-    double sum = 0.0;
-    for (const auto & kv : logits_id) {
-        double p = exp(kv.first - maxl);
-        probs.push_back(p);
-        sum += p;
-    }
-
-    // normalize the probs
-    for (auto & p : probs) {
-        p /= sum;
-    }
-
-    if (top_p < 1.0f) {
-        double cumsum = 0.0f;
-        for (int i = 0; i < top_k; i++) {
-            cumsum += probs[i];
-            if (cumsum >= top_p) {
-                top_k = i + 1;
-                probs.resize(top_k);
-                logits_id.resize(top_k);
-                break;
-            }
-        }
-
-        cumsum = 1.0/cumsum;
-        for (int i = 0; i < (int) probs.size(); i++) {
-            probs[i] *= cumsum;
-        }
-    }
-
-    //printf("\n");
-    //for (int i = 0; i < (int) probs.size(); i++) {
-    //    printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]);
-    //}
-    //exit(0);
-
-    std::discrete_distribution<> dist(probs.begin(), probs.end());
-    int idx = dist(rng);
-
-    return logits_id[idx].second;
 }
 
-gpt_vocab::id llama_sample_top_p(
+gpt_vocab::id llama_sample_top_p_top_k(
         const gpt_vocab & vocab,
         const float * logits,
         std::vector<gpt_vocab::id> & last_n_tokens,
         double repeat_penalty,
+        int top_k,
         double top_p,
         double temp,
         std::mt19937 & rng) {
@@ -412,12 +346,7 @@ gpt_vocab::id llama_sample_top_p(
         }
     }
 
-    std::sort(
-            logits_id.begin(),
-            logits_id.end(),
-            [](const std::pair<double, gpt_vocab::id> & a, const std::pair<double, gpt_vocab::id> & b) {
-        return a.first > b.first;
-    });
+    sample_top_k(logits_id, top_k);
 
     double maxl = -INFINITY;
     for (const auto & kv : logits_id) {
diff --git a/utils.h b/utils.h
index e331904..5b3d736 100644
--- a/utils.h
+++ b/utils.h
@@ -19,7 +19,7 @@ struct gpt_params {
     int32_t repeat_last_n = 64;  // last n tokens to penalize
 
     // sampling parameters
-    int32_t top_k = 40; // unused
+    int32_t top_k = 40;
     float   top_p = 0.95f;
     float   temp  = 0.80f;
     float   repeat_penalty  = 1.30f;
@@ -77,26 +77,19 @@ bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab);
 //   - consider only the top K tokens
 //   - from them, consider only the top tokens with cumulative probability > P
 //
-// TODO: not sure if this implementation is correct
-// TODO: temperature is not implemented
-//
-gpt_vocab::id gpt_sample_top_k_top_p(
-        const gpt_vocab & vocab,
-        const float * logits,
-        int    top_k,
-        double top_p,
-        double temp,
-        std::mt19937 & rng);
-
-gpt_vocab::id llama_sample_top_p(
+gpt_vocab::id llama_sample_top_p_top_k(
         const gpt_vocab & vocab,
         const float * logits,
         std::vector<gpt_vocab::id> & last_n_tokens,
         double repeat_penalty,
+        int top_k,
         double top_p,
         double temp,
         std::mt19937 & rng);
 
+// filer to top K tokens from list of logits
+void sample_top_k(std::vector<std::pair<double, gpt_vocab::id>> & logits_id, int top_k);
+
 //
 // Quantization
 //