From 683852b2501ac7207e1dee048e5492da6568bc63 Mon Sep 17 00:00:00 2001
From: namark <namark@disroot.org>
Date: Thu, 9 Jul 2020 03:04:48 +0400
Subject: [PATCH] WIP clever midpoint, so far just unsigned,

with tests stolen from libc++.
---
 source/simple/support/algorithm.hpp |  26 ++++-
 unit_tests/libcpp_midpoint.cpp      | 144 ++++++++++++++++++++++++++++
 2 files changed, 169 insertions(+), 1 deletion(-)
 create mode 100644 unit_tests/libcpp_midpoint.cpp

diff --git a/source/simple/support/algorithm.hpp b/source/simple/support/algorithm.hpp
index 696d9dc..2811321 100644
--- a/source/simple/support/algorithm.hpp
+++ b/source/simple/support/algorithm.hpp
@@ -6,6 +6,7 @@
 #include <algorithm>
 
 #include "range.hpp"
+#include "arithmetic.hpp"
 
 namespace simple::support
 {
@@ -273,6 +274,7 @@ namespace simple::support
 	[[nodiscard]] constexpr
 	auto average(Numbers... n)
 	//TODO: noexcept account for return value construction
+	//TODO: cast size to result type of sum instead of int
 	noexcept(noexcept((n + ...) / int(sizeof...(n))))
 	{
 		return (n + ...) / int(sizeof...(n));
@@ -281,11 +283,33 @@ namespace simple::support
 	template <typename Number>
 	[[nodiscard]] constexpr
 	Number midpoint(Number a, Number b)
-	noexcept(noexcept(Number(a + (b - a) / 2)))
+	noexcept(noexcept(Number(a + (b - a)/2)))
 	{
 		return a + (b - a)/2;
 	}
 
+	template <typename Unsigned>
+	[[nodiscard]] constexpr
+	Unsigned midpoint_overflow(Unsigned a, Unsigned b)
+	// noexcept(noexcept(TODO))
+	{
+
+		Unsigned diff{};
+		bool overflew = sub_overflow(diff,b,a);
+
+		// manual idiv
+		Unsigned idiff = -diff; // neg
+		idiff /= Unsigned{2}; // div
+		idiff = -idiff; // neg
+		// or... 0 - (0 - diff)/2, this is midpointception! TODO: -_-
+
+		diff /= Unsigned{2};
+
+		return a + (overflew ? idiff : diff);
+		// return a + overflew * idiff + ~overflew * diff;
+
+	}
+
 	// std::swap is not constexpr >.<
 	template <typename T, std::enable_if_t<
 		std::is_move_constructible_v<T> &&
diff --git a/unit_tests/libcpp_midpoint.cpp b/unit_tests/libcpp_midpoint.cpp
new file mode 100644
index 0000000..c745d50
--- /dev/null
+++ b/unit_tests/libcpp_midpoint.cpp
@@ -0,0 +1,144 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// UNSUPPORTED: c++98, c++03, c++11, c++14, c++17
+// <numeric>
+
+// template <class _Tp>
+// _Tp midpoint_overflow(_Tp __a, _Tp __b) noexcept
+//
+
+#include <stdint.h>
+#include <limits>
+#include <numeric>
+#include <cassert>
+#include "simple/support/algorithm.hpp"
+
+using namespace simple::support;
+
+template <typename T>
+void signed_test()
+{
+    constexpr T zero{0};
+    constexpr T one{1};
+    constexpr T two{2};
+    constexpr T three{3};
+    constexpr T four{4};
+
+    // ASSERT_SAME_TYPE(decltype(midpoint_overflow(T(), T())), T);
+    // ASSERT_NOEXCEPT(          midpoint_overflow(T(), T()));
+    using limits = std::numeric_limits<T>;
+
+    static_assert(midpoint_overflow(one, three) == two, "");
+    static_assert(midpoint_overflow(three, one) == two, "");
+
+    assert(midpoint_overflow(zero, zero) == zero);
+    assert(midpoint_overflow(zero, two)  == one);
+    assert(midpoint_overflow(two, zero)  == one);
+    assert(midpoint_overflow(two, two)   == two);
+
+    assert(midpoint_overflow(one, four)    == two);
+    assert(midpoint_overflow(four, one)    == three);
+    assert(midpoint_overflow(three, four)  == three);
+    assert(midpoint_overflow(four, three)  == four);
+
+    assert(midpoint_overflow(T( 3), T( 4)) == T(3));
+    assert(midpoint_overflow(T( 4), T( 3)) == T(4));
+    assert(midpoint_overflow(T(-3), T( 4)) == T(0));
+    assert(midpoint_overflow(T(-4), T( 3)) == T(-1));
+    assert(midpoint_overflow(T( 3), T(-4)) == T(0));
+    assert(midpoint_overflow(T( 4), T(-3)) == T(1));
+    assert(midpoint_overflow(T(-3), T(-4)) == T(-3));
+    assert(midpoint_overflow(T(-4), T(-3)) == T(-4));
+
+    static_assert(midpoint_overflow(limits::min(), limits::max()) == T(-1), "");
+    static_assert(midpoint_overflow(limits::max(), limits::min()) == T( 0), "");
+
+    static_assert(midpoint_overflow(limits::min(), T(6)) == limits::min()/2 + 3, "");
+    assert(       midpoint_overflow(T(6), limits::min()) == limits::min()/2 + 3);
+    assert(       midpoint_overflow(limits::max(), T(6)) == limits::max()/2 + 4);
+    static_assert(midpoint_overflow(T(6), limits::max()) == limits::max()/2 + 3, "");
+
+    assert(       midpoint_overflow(limits::min(), T(-6)) == limits::min()/2 - 3);
+    static_assert(midpoint_overflow(T(-6), limits::min()) == limits::min()/2 - 3, "");
+    static_assert(midpoint_overflow(limits::max(), T(-6)) == limits::max()/2 - 2, "");
+    assert(       midpoint_overflow(T(-6), limits::max()) == limits::max()/2 - 3);
+}
+
+template <typename T>
+void unsigned_test()
+{
+    constexpr T zero{0};
+    constexpr T one{1};
+    constexpr T two{2};
+    constexpr T three{3};
+    constexpr T four{4};
+
+    // ASSERT_SAME_TYPE(decltype(midpoint_overflow(T(), T())), T);
+    // ASSERT_NOEXCEPT(          midpoint_overflow(T(), T()));
+    using limits = std::numeric_limits<T>;
+    const T half_way = (limits::max() - limits::min())/2;
+
+    static_assert(midpoint_overflow(one, three) == two, "");
+    static_assert(midpoint_overflow(three, one) == two, "");
+
+    assert(midpoint_overflow(zero, zero) == zero);
+    assert(midpoint_overflow(zero, two)  == one);
+    assert(midpoint_overflow(two, zero)  == one);
+    assert(midpoint_overflow(two, two)   == two);
+
+    assert(midpoint_overflow(one, four)    == two);
+    assert(midpoint_overflow(four, one)    == three);
+    assert(midpoint_overflow(three, four)  == three);
+    assert(midpoint_overflow(four, three)  == four);
+
+    assert(midpoint_overflow(limits::min(), limits::max()) == T(half_way));
+    assert(midpoint_overflow(limits::max(), limits::min()) == T(half_way + 1));
+
+    static_assert(midpoint_overflow(limits::min(), T(6)) == limits::min()/2 + 3, "");
+    assert(       midpoint_overflow(T(6), limits::min()) == limits::min()/2 + 3);
+    assert(       midpoint_overflow(limits::max(), T(6)) == half_way + 4);
+    static_assert(midpoint_overflow(T(6), limits::max()) == half_way + 3, "");
+}
+
+
+int main(int, char**)
+{
+    // signed_test<signed char>();
+    // signed_test<short>();
+    // signed_test<int>();
+    // signed_test<long>();
+    // signed_test<long long>();
+    //
+    // signed_test<int8_t>();
+    // signed_test<int16_t>();
+    // signed_test<int32_t>();
+    // signed_test<int64_t>();
+
+    unsigned_test<unsigned char>();
+    unsigned_test<unsigned short>();
+    unsigned_test<unsigned int>();
+    unsigned_test<unsigned long>();
+    unsigned_test<unsigned long long>();
+
+    unsigned_test<uint8_t>();
+    unsigned_test<uint16_t>();
+    unsigned_test<uint32_t>();
+    unsigned_test<uint64_t>();
+
+#ifndef _LIBCPP_HAS_NO_INT128
+    // unsigned_test<__uint128_t>();
+    // signed_test<__int128_t>();
+#endif
+
+//     int_test<char>();
+    // signed_test<ptrdiff_t>();
+    unsigned_test<size_t>();
+
+    return 0;
+}
-- 
GitLab