From 436e56193199a1625f8c561069f702e8840a9e08 Mon Sep 17 00:00:00 2001
From: Stephan Walter <stephan@walter.name>
Date: Tue, 28 Mar 2023 16:48:20 +0000
Subject: [PATCH] all : be more strict about converting float to double (#458)

* Be more strict about converting float to double

* Test equivalence of round, SILU implementations

Test module is commented out in CMakeLists.txt because the tests may
take a long time, depending on how much the compiler optimizes.

* Fix softmax in perplexity.cpp

* all : prefer float over double where appropriate

* perplexity : add <cmath>

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
---
 CMakeLists.txt                     |   5 +-
 Makefile                           |   4 +
 examples/common.cpp                |   6 +-
 examples/main/main.cpp             |  11 +--
 examples/perplexity/perplexity.cpp |  20 +++--
 examples/quantize/quantize.cpp     |   4 +-
 ggml.c                             | 138 +++++++++++++++--------------
 llama.cpp                          |  52 +++++------
 llama.h                            |   8 +-
 tests/CMakeLists.txt               |   1 +
 tests/test-double-float.c          |  53 +++++++++++
 11 files changed, 185 insertions(+), 117 deletions(-)
 create mode 100644 tests/test-double-float.c

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 27a222a..241be4c 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -124,17 +124,18 @@ if (LLAMA_ALL_WARNINGS)
             -Wall
             -Wextra
             -Wpedantic
-            -Wshadow
             -Wcast-qual
+            -Wdouble-promotion
+            -Wshadow
             -Wstrict-prototypes
             -Wpointer-arith
-            -Wno-unused-function
         )
         set(cxx_flags
             -Wall
             -Wextra
             -Wpedantic
             -Wcast-qual
+            -Wdouble-promotion
         )
     else()
         # todo : msvc
diff --git a/Makefile b/Makefile
index 973b295..9cfa89f 100644
--- a/Makefile
+++ b/Makefile
@@ -35,6 +35,10 @@ CFLAGS   = -I.              -O3 -DNDEBUG -std=c11   -fPIC
 CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC
 LDFLAGS  =
 
+# warnings
+CFLAGS   += -Wall -Wextra -Wpedantic -Wcast-qual -Wdouble-promotion -Wshadow -Wstrict-prototypes -Wpointer-arith -Wno-unused-function
+CXXFLAGS += -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function
+
 # OS specific
 # TODO: support Windows
 ifeq ($(UNAME_S),Linux)
diff --git a/examples/common.cpp b/examples/common.cpp
index 880ebe9..af3ad9e 100644
--- a/examples/common.cpp
+++ b/examples/common.cpp
@@ -215,13 +215,13 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stderr, "                        prompt file to start generation.\n");
     fprintf(stderr, "  -n N, --n_predict N   number of tokens to predict (default: %d, -1 = infinity)\n", params.n_predict);
     fprintf(stderr, "  --top_k N             top-k sampling (default: %d)\n", params.top_k);
-    fprintf(stderr, "  --top_p N             top-p sampling (default: %.1f)\n", params.top_p);
+    fprintf(stderr, "  --top_p N             top-p sampling (default: %.1f)\n", (double)params.top_p);
     fprintf(stderr, "  --repeat_last_n N     last n tokens to consider for penalize (default: %d)\n", params.repeat_last_n);
-    fprintf(stderr, "  --repeat_penalty N    penalize repeat sequence of tokens (default: %.1f)\n", params.repeat_penalty);
+    fprintf(stderr, "  --repeat_penalty N    penalize repeat sequence of tokens (default: %.1f)\n", (double)params.repeat_penalty);
     fprintf(stderr, "  -c N, --ctx_size N    size of the prompt context (default: %d)\n", params.n_ctx);
     fprintf(stderr, "  --ignore-eos          ignore end of stream token and continue generating\n");
     fprintf(stderr, "  --memory_f32          use f32 instead of f16 for memory key+value\n");
-    fprintf(stderr, "  --temp N              temperature (default: %.1f)\n", params.temp);
+    fprintf(stderr, "  --temp N              temperature (default: %.1f)\n", (double)params.temp);
     fprintf(stderr, "  --n_parts N           number of model parts (default: -1 = determine from dimensions)\n");
     fprintf(stderr, "  -b N, --batch_size N  batch size for prompt processing (default: %d)\n", params.n_batch);
     fprintf(stderr, "  --perplexity          compute perplexity over the prompt\n");
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index d5ab2cf..3130aef 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -209,7 +209,8 @@ int main(int argc, char ** argv) {
             fprintf(stderr, "Input prefix: '%s'\n", params.input_prefix.c_str());
         }
     }
-    fprintf(stderr, "sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
+    fprintf(stderr, "sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n",
+        params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
     fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
     fprintf(stderr, "\n\n");
 
@@ -274,10 +275,10 @@ int main(int argc, char ** argv) {
 
         if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
             // out of user input, sample next token
-            const float top_k          = params.top_k;
-            const float top_p          = params.top_p;
-            const float temp           = params.temp;
-            const float repeat_penalty = params.repeat_penalty;
+            const int32_t top_k          = params.top_k;
+            const float   top_p          = params.top_p;
+            const float   temp           = params.temp;
+            const float   repeat_penalty = params.repeat_penalty;
 
             llama_token id = 0;
 
diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp
index 75d526d..07ed0a8 100644
--- a/examples/perplexity/perplexity.cpp
+++ b/examples/perplexity/perplexity.cpp
@@ -1,15 +1,17 @@
 #include "common.h"
 #include "llama.h"
 
-std::vector<double> softmax(const std::vector<float>& logits) {
-    std::vector<double> probs(logits.size());
+#include <cmath>
+
+std::vector<float> softmax(const std::vector<float>& logits) {
+    std::vector<float> probs(logits.size());
     float max_logit = logits[0];
     for (float v : logits) max_logit = std::max(max_logit, v);
     double sum_exp = 0.0;
     for (size_t i = 0; i < logits.size(); i++) {
         // Subtract the maximum logit value from the current logit value for numerical stability
-        float logit = logits[i] - max_logit;
-        double exp_logit = std::exp(logit);
+        const float logit = logits[i] - max_logit;
+        const float exp_logit = expf(logit);
         sum_exp += exp_logit;
         probs[i] = exp_logit;
     }
@@ -24,14 +26,16 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
     auto tokens = ::llama_tokenize(ctx, params.prompt, true);
 
     int count = 0;
-    double nll = 0.0;
     int seq_count = tokens.size() / params.n_ctx;
 
+    double nll = 0.0;
+
     fprintf(stderr, "%s : calculating perplexity over %d chunks\n", __func__, seq_count);
 
     for (int i = 0; i < seq_count; ++i) {
         int start = i * params.n_ctx;
-        int end = start + params.n_ctx - 1;
+        int end = start + params.n_ctx - 1; // TODO: this is not optimal, e.g. it makes the batch 511 instead of 512
+                                            //       it is better to always be power of 2 for better performance
         std::vector<llama_token> embd(tokens.begin() + start, tokens.begin() + end);
         auto start_t = std::chrono::high_resolution_clock::now();
         if (llama_eval(ctx, embd.data(), embd.size(), 0, params.n_threads)) {
@@ -40,7 +44,7 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
         }
         auto end_t = std::chrono::high_resolution_clock::now();
         if (i == 0) {
-            double seconds = std::chrono::duration<double>(end_t - start_t).count();
+            const float seconds = std::chrono::duration<float>(end_t - start_t).count();
             printf("%.2f seconds per pass - ETA %.2f hours\n", seconds, (seconds * seq_count) / (60.0*60.0));
         }
         // We get the logits for all the tokens in the context window (params.n_ctx)
@@ -63,7 +67,7 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
             std::vector<float> tok_logits(
                 logits + j * n_vocab,
                 logits + (j + 1) * n_vocab);
-            double prob = softmax(tok_logits)[tokens[start + j + 1]];
+            const float prob = softmax(tok_logits)[tokens[start + j + 1]];
             nll += -std::log(prob);
             ++count;
         }
diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp
index 3888ff5..b444328 100644
--- a/examples/quantize/quantize.cpp
+++ b/examples/quantize/quantize.cpp
@@ -50,8 +50,8 @@ int main(int argc, char ** argv) {
         const int64_t t_main_end_us = ggml_time_us();
 
         printf("\n");
-        printf("%s: quantize time = %8.2f ms\n", __func__, t_quantize_us/1000.0f);
-        printf("%s:    total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
+        printf("%s: quantize time = %8.2f ms\n", __func__, t_quantize_us/1000.0);
+        printf("%s:    total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0);
     }
 
     return 0;
diff --git a/ggml.c b/ggml.c
index bf8ec8a..83395a7 100644
--- a/ggml.c
+++ b/ggml.c
@@ -150,10 +150,10 @@ typedef double ggml_float;
 //
 #include <arm_neon.h>
 
-#define GGML_COMPUTE_FP16_TO_FP32(x) (x)
+#define GGML_COMPUTE_FP16_TO_FP32(x) ((float) (x))
 #define GGML_COMPUTE_FP32_TO_FP16(x) (x)
 
-#define GGML_FP16_TO_FP32(x) (x)
+#define GGML_FP16_TO_FP32(x) ((float) (x))
 #define GGML_FP32_TO_FP16(x) (x)
 
 #else
@@ -322,7 +322,7 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
 // note: do not use these inside ggml.c
 // these are meant to be used via the ggml.h API
 float ggml_fp16_to_fp32(ggml_fp16_t x) {
-    return GGML_FP16_TO_FP32(x);
+    return (float) GGML_FP16_TO_FP32(x);
 }
 
 ggml_fp16_t ggml_fp32_to_fp16(float x) {
@@ -488,8 +488,8 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
             const float v0 = x[i*QK + l + 0]*id;
             const float v1 = x[i*QK + l + 1]*id;
 
-            const uint8_t vi0 = ((int8_t) (round(v0))) + 8;
-            const uint8_t vi1 = ((int8_t) (round(v1))) + 8;
+            const uint8_t vi0 = (int8_t)roundf(v0) + 8;
+            const uint8_t vi1 = (int8_t)roundf(v1) + 8;
 
             assert(vi0 >= 0 && vi0 < 16);
             assert(vi1 >= 0 && vi1 < 16);
@@ -566,7 +566,7 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
                 MAX(vgetq_lane_f32(amaxv[0], 2), vgetq_lane_f32(amaxv[0], 3)));
 
         const float d = amax / ((1 << 3) - 1);
-        const float id = d ? 1.0/d : 0.0;
+        const float id = d ? 1.0f/d : 0.0f;
 
         y[i].d = d;
 
@@ -716,8 +716,8 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
             const float v0 = (x[i*QK + l + 0] - min)*id;
             const float v1 = (x[i*QK + l + 1] - min)*id;
 
-            const uint8_t vi0 = round(v0);
-            const uint8_t vi1 = round(v1);
+            const uint8_t vi0 = roundf(v0);
+            const uint8_t vi1 = roundf(v1);
 
             assert(vi0 >= 0 && vi0 < 16);
             assert(vi1 >= 0 && vi1 < 16);
@@ -1001,7 +1001,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
         }                                                         \
         const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 (x[0])); \
         const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0])); \
-        res = vaddvq_f32(vaddq_f32(t0, t1));                      \
+        res = (ggml_float) vaddvq_f32(vaddq_f32(t0, t1));         \
     }
 
     #define GGML_F16_VEC                GGML_F16x8
@@ -1437,9 +1437,8 @@ inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, co
 inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i]/y[i];   }
 
 inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y) {
-    ggml_float sumf = 0.0;
-
 #ifdef GGML_SIMD
+    float sumf = 0.0f;
     const int np = (n & ~(GGML_F32_STEP - 1));
 
     GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
@@ -1465,8 +1464,9 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
     }
 #else
     // scalar
+    ggml_float sumf = 0.0;
     for (int i = 0; i < n; ++i) {
-        sumf += x[i]*y[i];
+        sumf += (ggml_float)(x[i]*y[i]);
     }
 #endif
 
@@ -1529,11 +1529,11 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
 
     // leftovers
     for (int i = np; i < n; ++i) {
-        sumf += GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]);
+        sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]));
     }
 #else
     for (int i = 0; i < n; ++i) {
-        sumf += GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]);
+        sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]));
     }
 #endif
 
@@ -1549,7 +1549,7 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
     const block_q4_0 * restrict x = vx;
     const block_q4_0 * restrict y = vy;
 
-    float sumf = 0.0;
+    ggml_float sumf = 0.0;
 
 #if defined(__ARM_NEON)
     float sum0 = 0.0f;
@@ -1644,7 +1644,7 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
 #endif
     }
 
-    sumf = sum0 + sum1;
+    sumf = (ggml_float)(sum0 + sum1);
 #elif defined(__AVX512F__)
     // Initialize accumulator with zeros
     __m512 acc0 = _mm512_setzero_ps();
@@ -1972,13 +1972,13 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * re
     // leftovers
     for (int i = np; i < n; ++i) {
         for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
-            sumf[j] += GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]);
+            sumf[j] += (ggml_float)(GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]));
         }
     }
 #else
     for (int i = 0; i < n; ++i) {
         for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
-            sumf[j] += GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]);
+            sumf[j] += (ggml_float)(GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]));
         }
     }
 #endif
@@ -2049,19 +2049,19 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float   v) {
 #endif
 }
 
-inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, x, x); *s = sqrt(*s);   }
+inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, x, x); *s = sqrtf(*s);   }
 inline static void ggml_vec_sqr_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i];   }
-inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrt(x[i]); }
+inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); }
 inline static void ggml_vec_abs_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); }
 inline static void ggml_vec_sgn_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? -1.f : 0.f); }
 inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; }
 inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
 
-static const ggml_float GELU_COEF_A    = 0.044715;
-static const ggml_float SQRT_2_OVER_PI = 0.79788456080286535587989211986876;
+static const float GELU_COEF_A    = 0.044715f;
+static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
 
 inline static float ggml_gelu_f32(float x) {
-    return 0.5*x*(1.0 + tanh(SQRT_2_OVER_PI*x*(1.0 + GELU_COEF_A*x*x)));
+    return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
 }
 
 inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
@@ -2090,7 +2090,7 @@ inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
 
 // Sigmoid Linear Unit (SiLU) function
 inline static float ggml_silu_f32(float x) {
-    return x/(1.0 + exp(-x));
+    return x/(1.0f + expf(-x));
 }
 
 inline static void ggml_vec_silu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
@@ -2121,7 +2121,7 @@ inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
 #ifndef GGML_USE_ACCELERATE
     ggml_float sum = 0.0;
     for (int i = 0; i < n; ++i) {
-        sum += x[i];
+        sum += (ggml_float)x[i];
     }
     *s = sum;
 #else
@@ -2131,7 +2131,7 @@ inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
 
 inline static void ggml_vec_max_f32(const int n, float * s, const float * x) {
 #ifndef GGML_USE_ACCELERATE
-    ggml_float max = -INFINITY;
+    float max = -INFINITY;
     for (int i = 0; i < n; ++i) {
         max = MAX(max, x[i]);
     }
@@ -2141,7 +2141,10 @@ inline static void ggml_vec_max_f32(const int n, float * s, const float * x) {
 #endif
 }
 
-inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x) { ggml_vec_norm_f32(n, s, x); *s = 1./(*s); }
+inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x) {
+    ggml_vec_norm_f32(n, s, x);
+    *s = 1.f/(*s);
+}
 
 //
 // logging
@@ -2540,7 +2543,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
                 const float f = table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(ii);
                 table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f));
                 table_silu_f16[i] = GGML_FP32_TO_FP16(ggml_silu_f32(f));
-                table_exp_f16[i]  = GGML_FP32_TO_FP16(exp(f));
+                table_exp_f16[i]  = GGML_FP32_TO_FP16(expf(f));
             }
 
             const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
@@ -5583,7 +5586,7 @@ static void ggml_compute_forward_norm_f32(
     const size_t nb2 = dst->nb[2];
     const size_t nb3 = dst->nb[3];
 
-    const ggml_float eps = 1e-5f; // TODO: make this a parameter
+    const float eps = 1e-5f; // TODO: make this a parameter
 
     // TODO: optimize
     for (int i03 = 0; i03 < ne03; i03++) {
@@ -5591,23 +5594,24 @@ static void ggml_compute_forward_norm_f32(
             for (int i01 = ith; i01 < ne01; i01 += nth) {
                 const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
 
-                ggml_float mean = 0.0;
+                ggml_float sum = 0.0;
                 for (int i00 = 0; i00 < ne00; i00++) {
-                    mean += x[i00];
+                    sum += (ggml_float)x[i00];
                 }
 
-                mean /= ne00;
+                float mean = sum/ne00;
 
                 float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
 
                 ggml_float sum2 = 0.0;
                 for (int i00 = 0; i00 < ne00; i00++) {
-                    ggml_float v = x[i00] - mean;
+                    float v = x[i00] - mean;
                     y[i00] = v;
-                    sum2 += v*v;
+                    sum2 += (ggml_float)(v*v);
                 }
 
-                const float scale = 1.0/sqrt(sum2/ne00 + eps);
+                float variance = sum2/ne00;
+                const float scale = 1.0f/sqrtf(variance + eps);
 
                 ggml_vec_scale_f32(ne00, y, scale);
             }
@@ -5665,7 +5669,7 @@ static void ggml_compute_forward_rms_norm_f32(
     const size_t nb2 = dst->nb[2];
     const size_t nb3 = dst->nb[3];
 
-    const ggml_float eps = 1e-6f; // TODO: make this a parameter
+    const float eps = 1e-6f; // TODO: make this a parameter
 
     // TODO: optimize
     for (int i03 = 0; i03 < ne03; i03++) {
@@ -5673,12 +5677,12 @@ static void ggml_compute_forward_rms_norm_f32(
             for (int i01 = ith; i01 < ne01; i01 += nth) {
                 const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
 
-                ggml_float mean = 0.0;
+                ggml_float sum = 0.0;
                 for (int i00 = 0; i00 < ne00; i00++) {
-                    mean += x[i00] * x[i00];
+                    sum += (ggml_float)(x[i00] * x[i00]);
                 }
 
-                mean /= ne00;
+                float mean = sum/ne00;
 
                 float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
 
@@ -5687,7 +5691,7 @@ static void ggml_compute_forward_rms_norm_f32(
                 //     y[i00] = x[i00];
                 // }
 
-                const float scale = 1.0/sqrt(mean + eps);
+                const float scale = 1.0f/sqrtf(mean + eps);
 
                 ggml_vec_scale_f32(ne00, y, scale);
             }
@@ -6913,12 +6917,12 @@ static void ggml_compute_forward_soft_max_f32(
                 ggml_fp16_t s = GGML_FP32_TO_FP16(p[i] - max);
                 memcpy(&scvt, &s, sizeof(scvt));
                 const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]);
-                sum += val;
+                sum += (ggml_float)val;
                 p[i] = val;
             }
         }
 
-        assert(sum > 0.0f);
+        assert(sum > 0.0);
 
         sum = 1.0/sum;
         ggml_vec_scale_f32(nc, p, sum);
@@ -6994,16 +6998,16 @@ static void ggml_compute_forward_rope_f32(
             const int p = (mode == 0 ? n_past + i2 : i2);
             for (int i1 = 0; i1 < ne1; i1++) {
                 for (int i0 = 0; i0 < n_dims; i0 += 2) {
-                    const double theta = pow(10000.0, ((double)-i0)/n_dims);
+                    const float theta = powf(10000.0, ((float)-i0)/n_dims);
 
-                    const double cos_theta = cos(p*theta);
-                    const double sin_theta = sin(p*theta);
+                    const float cos_theta = cosf(p*theta);
+                    const float sin_theta = sinf(p*theta);
 
                     const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
                           float * dst_data  = (float *)((char *)  dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
 
-                    double x0 = src[0];
-                    double x1 = src[1];
+                    const float x0 = src[0];
+                    const float x1 = src[1];
 
                     dst_data[0] = x0*cos_theta - x1*sin_theta;
                     dst_data[1] = x0*sin_theta + x1*cos_theta;
@@ -7050,16 +7054,16 @@ static void ggml_compute_forward_rope_f16(
             const int p = (mode == 0 ? n_past + i2 : i2);
             for (int i1 = 0; i1 < ne1; i1++) {
                 for (int i0 = 0; i0 < n_dims; i0 += 2) {
-                    const double theta = pow(10000.0, ((double)-i0)/n_dims);
+                    const float theta = powf(10000.0, ((float)-i0)/n_dims);
 
-                    const double cos_theta = cos(p*theta);
-                    const double sin_theta = sin(p*theta);
+                    const float cos_theta = cosf(p*theta);
+                    const float sin_theta = sinf(p*theta);
 
                     const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
                           ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
 
-                    double x0 = ggml_fp16_to_fp32(src[0]);
-                    double x1 = ggml_fp16_to_fp32(src[1]);
+                    const float x0 = ggml_fp16_to_fp32(src[0]);
+                    const float x1 = ggml_fp16_to_fp32(src[1]);
 
                     dst_data[0] = ggml_fp32_to_fp16(x0*cos_theta - x1*sin_theta);
                     dst_data[1] = ggml_fp32_to_fp16(x0*sin_theta + x1*cos_theta);
@@ -7735,7 +7739,7 @@ static void ggml_compute_forward_flash_attn_f32(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
-    const float scale = 1.0/sqrt((double) D);
+    const float scale = 1.0f/sqrtf(D);
 
     //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
 
@@ -7782,7 +7786,7 @@ static void ggml_compute_forward_flash_attn_f32(
             float max = -INFINITY;
             ggml_vec_max_f32(M, &max, S);
 
-            float sum = 0.0f;
+            ggml_float sum = 0.0;
             {
 #ifdef GGML_SOFT_MAX_ACCELERATE
                 max = -max;
@@ -7803,7 +7807,7 @@ static void ggml_compute_forward_flash_attn_f32(
                             ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max);
                             memcpy(&scvt[j], &s, sizeof(uint16_t));
                             const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
-                            sump[j] += val;
+                            sump[j] += (ggml_float)val;
                             SS[j] = val;
                         }
                     }
@@ -7815,7 +7819,7 @@ static void ggml_compute_forward_flash_attn_f32(
 #endif
             }
 
-            assert(sum > 0.0f);
+            assert(sum > 0.0);
 
             sum = 1.0/sum;
             ggml_vec_scale_f32(M, S, sum);
@@ -7944,7 +7948,7 @@ static void ggml_compute_forward_flash_attn_f16(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
-    const float scale = 1.0/sqrt((double) D);
+    const float scale = 1.0f/sqrtf(D);
 
     //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
 
@@ -8008,7 +8012,7 @@ static void ggml_compute_forward_flash_attn_f16(
             float max = -INFINITY;
             ggml_vec_max_f32(M, &max, S);
 
-            float sum = 0.0f;
+            ggml_float sum = 0.0;
             {
 #ifdef GGML_SOFT_MAX_ACCELERATE
                 max = -max;
@@ -8029,7 +8033,7 @@ static void ggml_compute_forward_flash_attn_f16(
                             ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max);
                             memcpy(&scvt[j], &s, sizeof(uint16_t));
                             const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
-                            sump[j] += val;
+                            sump[j] += (ggml_float)val;
                             SS[j] = val;
                         }
                     }
@@ -8041,7 +8045,7 @@ static void ggml_compute_forward_flash_attn_f16(
 #endif
             }
 
-            assert(sum > 0.0f);
+            assert(sum > 0.0);
 
             sum = 1.0/sum;
             ggml_vec_scale_f32(M, S, sum);
@@ -9566,7 +9570,7 @@ label=\"%d [%d, %d] | <x>%s",
             fprintf(fp, "  \"%p\" [ \
 style = filled; fillcolor = %s; shape = record; \
 label=\"<x>%.1e\"; ]\n",
-                    (void *) node, color, ggml_get_f32_1d(node, 0));
+                    (void *) node, color, (double)ggml_get_f32_1d(node, 0));
         } else {
             fprintf(fp, "  \"%p\" [ \
 style = filled; fillcolor = %s; shape = record; \
@@ -9804,7 +9808,7 @@ static enum ggml_opt_result ggml_opt_adam(
             if (params.past <= t) {
                 const float rate = (pf[t%params.past] - fx)/fx;
 
-                if (fabs(rate) < params.delta) {
+                if (fabsf(rate) < params.delta) {
                     return GGML_OPT_OK;
                 }
             }
@@ -9883,7 +9887,7 @@ static enum ggml_opt_result linesearch_backtracking(
     const float dec = 0.5f;
     const float inc = 2.1f;
 
-    if (*step <= 0.) {
+    if (*step <= 0.f) {
         return GGML_LINESEARCH_INVALID_PARAMETERS;
     }
 
@@ -9971,7 +9975,7 @@ static enum ggml_opt_result ggml_opt_lbfgs(
         struct ggml_cgraph * gb) {
     if (params.lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_WOLFE ||
         params.lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE) {
-        if (params.lbfgs.wolfe <= params.lbfgs.ftol || 1. <= params.lbfgs.wolfe) {
+        if (params.lbfgs.wolfe <= params.lbfgs.ftol || 1.f <= params.lbfgs.wolfe) {
             return GGML_OPT_INVALID_WOLFE;
         }
     }
@@ -10092,8 +10096,8 @@ static enum ggml_opt_result ggml_opt_lbfgs(
 
         GGML_PRINT_DEBUG("f = %10.6f\n", ggml_get_f32_1d(f, 0));
 
-        if (xnorm < 1.0) {
-            xnorm = 1.0;
+        if (xnorm < 1.0f) {
+            xnorm = 1.0f;
         }
         if (gnorm/xnorm <= params.lbfgs.eps) {
             // converged
@@ -10106,7 +10110,7 @@ static enum ggml_opt_result ggml_opt_lbfgs(
             if (params.past <= k) {
                 const float rate = (pf[k%params.past] - fx)/fx;
 
-                if (fabs(rate) < params.delta) {
+                if (fabsf(rate) < params.delta) {
                     return GGML_OPT_OK;
                 }
             }
diff --git a/llama.cpp b/llama.cpp
index b0eab2e..ee7eb8e 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -779,8 +779,8 @@ static bool llama_model_load(
 
                 // progress
                 if (progress_callback) {
-                    double current_file_progress = double(size_t(fin.tellg()) - file_offset) / double(file_size - file_offset);
-                    double current_progress = (double(i) + current_file_progress) / double(n_parts);
+                    float current_file_progress = float(size_t(fin.tellg()) - file_offset) / float(file_size - file_offset);
+                    float current_progress = (float(i) + current_file_progress) / float(n_parts);
                     progress_callback(current_progress, progress_callback_user_data);
                 }
                 if (model.n_loaded % 8 == 0) {
@@ -922,7 +922,7 @@ static bool llama_eval_internal(
             struct ggml_tensor * KQ_scaled =
                 ggml_scale(ctx0,
                         KQ,
-                        ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd)/n_head)));
+                        ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head)));
 
             // KQ_masked = mask_past(KQ_scaled)
             struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
@@ -1240,12 +1240,12 @@ static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, co
 // sampling
 //
 
-static void sample_top_k(std::vector<std::pair<double, llama_vocab::id>> & logits_id, int top_k) {
+static void sample_top_k(std::vector<std::pair<float, llama_vocab::id>> & logits_id, int top_k) {
     // find the top k tokens
     std::partial_sort(
             logits_id.begin(),
             logits_id.begin() + top_k, logits_id.end(),
-            [](const std::pair<double, llama_vocab::id> & a, const std::pair<double, llama_vocab::id> & b) {
+            [](const std::pair<float, llama_vocab::id> & a, const std::pair<float, llama_vocab::id> & b) {
         return a.first > b.first;
     });
 
@@ -1256,9 +1256,9 @@ static llama_vocab::id llama_sample_top_p_top_k(
         llama_context & lctx,
         const std::vector<llama_vocab::id> & last_n_tokens,
         int top_k,
-        double top_p,
-        double temp,
-        double repeat_penalty) {
+        float top_p,
+        float temp,
+        float repeat_penalty) {
     auto & rng = lctx.rng;
 
     const int n_logits = lctx.model.hparams.n_vocab;
@@ -1266,17 +1266,17 @@ static llama_vocab::id llama_sample_top_p_top_k(
     const auto & logits = lctx.logits;
     const auto * plogits = logits.data() + logits.size() - n_logits;
 
-    std::vector<std::pair<double, llama_vocab::id>> logits_id;
+    std::vector<std::pair<float, llama_vocab::id>> logits_id;
     logits_id.reserve(n_logits);
 
     {
-        const double scale = 1.0/temp;
+        const float scale = 1.0f/temp;
         for (int i = 0; i < n_logits; ++i) {
             // repetition penalty from ctrl paper (https://arxiv.org/abs/1909.05858)
             // credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main
             if (std::find(last_n_tokens.begin(), last_n_tokens.end(), i) != last_n_tokens.end()) {
                 // if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
-                if (plogits[i] < 0.0) {
+                if (plogits[i] < 0.0f) {
                     logits_id.push_back(std::make_pair(plogits[i]*scale*repeat_penalty, i));
                 } else {
                     logits_id.push_back(std::make_pair(plogits[i]*scale/repeat_penalty, i));
@@ -1289,18 +1289,18 @@ static llama_vocab::id llama_sample_top_p_top_k(
 
     sample_top_k(logits_id, top_k);
 
-    double maxl = -std::numeric_limits<double>::infinity();
+    float maxl = -std::numeric_limits<float>::infinity();
     for (const auto & kv : logits_id) {
         maxl = std::max(maxl, kv.first);
     }
 
     // compute probs for the top k tokens
-    std::vector<double> probs;
+    std::vector<float> probs;
     probs.reserve(logits_id.size());
 
     double sum = 0.0;
     for (const auto & kv : logits_id) {
-        double p = exp(kv.first - maxl);
+        const float p = expf(kv.first - maxl);
         probs.push_back(p);
         sum += p;
     }
@@ -1310,8 +1310,8 @@ static llama_vocab::id llama_sample_top_p_top_k(
         p /= sum;
     }
 
-    if (top_p < 1.0f) {
-        double cumsum = 0.0f;
+    if (top_p < 1.0) {
+        double cumsum = 0.0;
         for (int i = 0; i < (int) probs.size(); i++) {
             cumsum += probs[i];
             if (cumsum >= top_p) {
@@ -1590,7 +1590,7 @@ static bool llama_model_quantize_internal(const std::string & fname_inp, const s
                 }
 
                 for (int i = 0; i < (int) hist_cur.size(); ++i) {
-                    printf("%5.3f ", hist_cur[i] / (float)nelements);
+                    printf("%5.3f ", hist_cur[i] / float(nelements));
                 }
                 printf("\n");
             } else {
@@ -1613,7 +1613,7 @@ static bool llama_model_quantize_internal(const std::string & fname_inp, const s
 
             printf("%s: hist: ", __func__);
             for (int i = 0; i < (int) hist_all.size(); ++i) {
-                printf("%5.3f ", hist_all[i] / (float)sum_all);
+                printf("%5.3f ", hist_all[i] / float(sum_all));
             }
             printf("\n");
         }
@@ -1795,9 +1795,9 @@ llama_token llama_sample_top_p_top_k(
       const llama_token * last_n_tokens_data,
                     int   last_n_tokens_size,
                     int   top_k,
-                 double   top_p,
-                 double   temp,
-                 double   repeat_penalty) {
+                  float   top_p,
+                  float   temp,
+                  float   repeat_penalty) {
     const int64_t t_start_sample_us = ggml_time_us();
 
     llama_token result = 0;
@@ -1828,11 +1828,11 @@ void llama_print_timings(struct llama_context * ctx) {
     const int32_t n_p_eval = std::max(1, ctx->n_p_eval);
 
     fprintf(stderr, "\n");
-    fprintf(stderr, "%s:        load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f);
-    fprintf(stderr, "%s:      sample time = %8.2f ms / %5d runs   (%8.2f ms per run)\n",   __func__, 1e-3f * ctx->t_sample_us, n_sample, 1e-3f * ctx->t_sample_us / n_sample);
-    fprintf(stderr, "%s: prompt eval time = %8.2f ms / %5d tokens (%8.2f ms per token)\n", __func__, 1e-3f * ctx->t_p_eval_us, n_p_eval, 1e-3f * ctx->t_p_eval_us / n_p_eval);
-    fprintf(stderr, "%s:        eval time = %8.2f ms / %5d runs   (%8.2f ms per run)\n",   __func__, 1e-3f * ctx->t_eval_us,   n_eval,   1e-3f * ctx->t_eval_us   / n_eval);
-    fprintf(stderr, "%s:       total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f);
+    fprintf(stderr, "%s:        load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0);
+    fprintf(stderr, "%s:      sample time = %8.2f ms / %5d runs   (%8.2f ms per run)\n",   __func__, 1e-3 * ctx->t_sample_us, n_sample, 1e-3 * ctx->t_sample_us / n_sample);
+    fprintf(stderr, "%s: prompt eval time = %8.2f ms / %5d tokens (%8.2f ms per token)\n", __func__, 1e-3 * ctx->t_p_eval_us, n_p_eval, 1e-3 * ctx->t_p_eval_us / n_p_eval);
+    fprintf(stderr, "%s:        eval time = %8.2f ms / %5d runs   (%8.2f ms per run)\n",   __func__, 1e-3 * ctx->t_eval_us,   n_eval,   1e-3 * ctx->t_eval_us   / n_eval);
+    fprintf(stderr, "%s:       total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0);
 }
 
 void llama_reset_timings(struct llama_context * ctx) {
diff --git a/llama.h b/llama.h
index d3f4cae..f5a576c 100644
--- a/llama.h
+++ b/llama.h
@@ -45,7 +45,7 @@ extern "C" {
 
     } llama_token_data;
 
-    typedef void (*llama_progress_callback)(double progress, void *ctx);
+    typedef void (*llama_progress_callback)(float progress, void *ctx);
 
     struct llama_context_params {
         int n_ctx;   // text context
@@ -134,9 +134,9 @@ extern "C" {
           const llama_token * last_n_tokens_data,
                         int   last_n_tokens_size,
                         int   top_k,
-                     double   top_p,
-                     double   temp,
-                     double   repeat_penalty);
+                      float   top_p,
+                      float   temp,
+                      float   repeat_penalty);
 
     // Performance information
     LLAMA_API void llama_print_timings(struct llama_context * ctx);
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index b44d7fe..157d733 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -5,5 +5,6 @@ function(llama_add_test source)
     add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}> ${ARGN})
 endfunction()
 
+# llama_add_test(test-double-float.c) # SLOW
 llama_add_test(test-quantize.c)
 llama_add_test(test-tokenizer-0.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab.bin)
diff --git a/tests/test-double-float.c b/tests/test-double-float.c
new file mode 100644
index 0000000..89dafc9
--- /dev/null
+++ b/tests/test-double-float.c
@@ -0,0 +1,53 @@
+// These tests may take a long time!
+// They are to prove that conversion from double to float of various functions in ggml.c doesn't affect the result.
+// This is done by checking all finite (non-NaN, non-infinite) floats.
+
+#undef NDEBUG
+#include <assert.h>
+#include <immintrin.h>
+#include <math.h>
+#include <stdint.h>
+
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wdouble-promotion"
+
+// ggml.c::quantize_row_q4_0_reference
+inline static uint8_t round_orig(float v0) { return ((int8_t) (round(v0))) + 8; }
+
+// ggml.c::ggml_silu_f32
+inline static float silu_orig(float x) {
+    return x/(1.0 + exp(-x));
+}
+
+#pragma GCC diagnostic pop
+
+// ggml.c::quantize_row_q4_0_reference
+inline static uint8_t round_float(float v0) { return (int8_t)roundf(v0) + 8; }
+
+// ggml.c::ggml_silu_f32
+inline static float silu_float(float x) {
+    return x/(1.0f + expf(-x));
+}
+
+int main(void) {
+    uint32_t x = UINT32_MAX;
+    do {
+        float f = *(float *)&x;
+        assert(!isfinite(f) || (round_orig(f) == round_float(f)));
+    } while (x--);
+
+#ifdef __F16C__
+    // GELU and SILU implementations are used with a FP16 lookup table.
+    // The original and float-only results are not equal for all inputs after converting to FP16.
+    // GELU is an approximation anyway (tanh), not tested here.
+    // For SILU, verify that the results are at least the closest floating point numbers, if the FP16 values don't match.
+    for (x = 0; x <= UINT16_MAX; x++) {
+        float f = _cvtsh_ss(x);
+        const float so = silu_orig(f);
+        const float sf = silu_float(f);
+        assert(   (_cvtss_sh(so, 0) == _cvtss_sh(sf, 0))
+               || (nextafterf(so, sf) == sf)
+               || (nextafterf(sf, so) == so));
+    }
+#endif
+}