1/*
2 * Copyright 2016-present Facebook, Inc.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include <folly/Math.h>

#include <algorithm>
#include <cassert>
#include <limits>
#include <type_traits>
#include <utility>
#include <vector>

#include <glog/logging.h>

#include <folly/Portability.h>
#include <folly/portability/GTest.h>

using namespace folly;
using namespace folly::detail;
31
namespace {

// Exact multiplication used by the reference division helpers below.
//
// Workaround for https://llvm.org/bugs/show_bug.cgi?id=16404,
// issues with __int128 multiplication and UBSAN: multiply via shift-and-add
// instead of operator*.
//
// Preconditions: for signed T, rhs must not be std::numeric_limits<T>::min()
// (negating it would overflow), and the true product lhs * rhs must fit in T.
template <typename T>
T mul(T lhs, T rhs) {
  // Move the sign onto lhs so the shift loop runs on a non-negative
  // multiplier; the product is unchanged.
  if (rhs < 0) {
    rhs = -rhs;
    lhs = -lhs;
  }
  T accum = 0;
  while (rhs != 0) {
    if ((rhs & 1) != 0) {
      accum += lhs;
    }
    rhs >>= 1;
    // Double lhs only while multiplier bits remain. This keeps every
    // intermediate bounded by |lhs * rhs|, so no signed overflow occurs
    // whenever the product itself fits in T.
    if (rhs != 0) {
      lhs += lhs;
    }
  }
  return accum;
}

// Reference floor division: returns the largest integer r such that
// r * denom <= numer (i.e. floor(numer / denom)). The computation is done in
// the intermediate type B, which must hold numer, denom and the products
// formed below exactly. Asserts that the result is representable in T.
template <typename T, typename B>
T referenceDivFloor(T numer, T denom) {
  // rv = largest integral value <= numer / denom
  B n = numer;
  B d = denom;
  // Normalize to a positive divisor so the comparisons below are monotone.
  if (d < 0) {
    d = -d;
    n = -n;
  }
  // Start from the truncated quotient and nudge it to the floor.
  B r = n / d;
  while (mul(r, d) > n) {
    --r;
  }
  while (mul(r + 1, d) <= n) {
    ++r;
  }
  T rv = static_cast<T>(r);
  assert(static_cast<B>(rv) == r);
  return rv;
}

// Reference ceiling division: returns the smallest integer r such that
// r * denom >= numer (i.e. ceil(numer / denom)). Same requirements on B as
// referenceDivFloor. Asserts that the result is representable in T.
template <typename T, typename B>
T referenceDivCeil(T numer, T denom) {
  // rv = smallest integral value >= numer / denom
  B n = numer;
  B d = denom;
  // Normalize to a positive divisor so the comparisons below are monotone.
  if (d < 0) {
    d = -d;
    n = -n;
  }
  // Start from the truncated quotient and nudge it to the ceiling.
  B r = n / d;
  while (mul(r, d) < n) {
    ++r;
  }
  while (mul(r - 1, d) >= n) {
    --r;
  }
  T rv = static_cast<T>(r);
  assert(static_cast<B>(rv) == r);
  return rv;
}

// Reference round-away-from-zero division: floor when the operands' signs
// differ (negative quotient), ceiling otherwise.
template <typename T, typename B>
T referenceDivRoundAway(T numer, T denom) {
  if ((numer < 0) != (denom < 0)) {
    return referenceDivFloor<T, B>(numer, denom);
  } else {
    return referenceDivCeil<T, B>(numer, denom);
  }
}

// Returns interesting operand values for T: small magnitudes 1..23, values
// just below max() and max()/2, and max()/i for small i. For signed T, it
// also returns the mirrored negatives near 0, min(), min()/2 and min()/i.
// Zero is not included.
template <typename T>
std::vector<T> cornerValues() {
  std::vector<T> rv;
  for (T i = 1; i < 24; ++i) {
    rv.push_back(i);
    rv.push_back(T(std::numeric_limits<T>::max() / i));
    rv.push_back(T(std::numeric_limits<T>::max() - i));
    rv.push_back(T(std::numeric_limits<T>::max() / T(2) - i));
    if (std::is_signed<T>::value) {
      rv.push_back(-i);
      rv.push_back(T(std::numeric_limits<T>::min() / i));
      rv.push_back(T(std::numeric_limits<T>::min() + i));
      rv.push_back(T(std::numeric_limits<T>::min() / T(2) + i));
    }
  }
  return rv;
}

// Checks divCeil/divFloor/divTrunc/divRoundAway for numerator type A and
// denominator type B against the reference implementations above, using C as
// their intermediate type. C must be wide enough to hold the operands and
// intermediate products exactly. Also checks that the branchless and
// branchful variants agree. T is the type of A / B after the usual
// arithmetic conversions.
template <typename A, typename B, typename C>
void runDivTests() {
  using T = decltype(static_cast<A>(1) / static_cast<B>(1));
  auto numers = cornerValues<A>();
  numers.push_back(0);
  auto denoms = cornerValues<B>();
  for (A n : numers) {
    for (B d : denoms) {
      if (std::is_signed<T>::value && n == std::numeric_limits<T>::min() &&
          d == static_cast<T>(-1)) {
        // n / d overflows in two's complement
        continue;
      }
      // Unary plus promotes int8_t/uint8_t to int so failure messages show
      // numbers instead of raw characters.
      EXPECT_EQ(divCeil(n, d), (referenceDivCeil<T, C>(n, d)))
          << +n << "/" << +d;
      EXPECT_EQ(divFloor(n, d), (referenceDivFloor<T, C>(n, d)))
          << +n << "/" << +d;
      EXPECT_EQ(divTrunc(n, d), n / d) << +n << "/" << +d;
      EXPECT_EQ(divRoundAway(n, d), (referenceDivRoundAway<T, C>(n, d)))
          << +n << "/" << +d;
      T nn = n;
      T dd = d;
      EXPECT_EQ(divCeilBranchless(nn, dd), divCeilBranchful(nn, dd))
          << +n << "/" << +d;
      EXPECT_EQ(divFloorBranchless(nn, dd), divFloorBranchful(nn, dd))
          << +n << "/" << +d;
      EXPECT_EQ(divRoundAwayBranchless(nn, dd), divRoundAwayBranchful(nn, dd))
          << +n << "/" << +d;
    }
  }
}
} // namespace
150
151TEST(Bits, divTestInt8) {
152 runDivTests<int8_t, int8_t, int64_t>();
153 runDivTests<int8_t, uint8_t, int64_t>();
154 runDivTests<int8_t, int16_t, int64_t>();
155 runDivTests<int8_t, uint16_t, int64_t>();
156 runDivTests<int8_t, int32_t, int64_t>();
157 runDivTests<int8_t, uint32_t, int64_t>();
158#if FOLLY_HAVE_INT128_T
159 runDivTests<int8_t, int64_t, __int128>();
160 runDivTests<int8_t, uint64_t, __int128>();
161#endif
162}
163TEST(Bits, divTestInt16) {
164 runDivTests<int16_t, int8_t, int64_t>();
165 runDivTests<int16_t, uint8_t, int64_t>();
166 runDivTests<int16_t, int16_t, int64_t>();
167 runDivTests<int16_t, uint16_t, int64_t>();
168 runDivTests<int16_t, int32_t, int64_t>();
169 runDivTests<int16_t, uint32_t, int64_t>();
170#if FOLLY_HAVE_INT128_T
171 runDivTests<int16_t, int64_t, __int128>();
172 runDivTests<int16_t, uint64_t, __int128>();
173#endif
174}
175TEST(Bits, divTestInt32) {
176 runDivTests<int32_t, int8_t, int64_t>();
177 runDivTests<int32_t, uint8_t, int64_t>();
178 runDivTests<int32_t, int16_t, int64_t>();
179 runDivTests<int32_t, uint16_t, int64_t>();
180 runDivTests<int32_t, int32_t, int64_t>();
181 runDivTests<int32_t, uint32_t, int64_t>();
182#if FOLLY_HAVE_INT128_T
183 runDivTests<int32_t, int64_t, __int128>();
184 runDivTests<int32_t, uint64_t, __int128>();
185#endif
186}
#if FOLLY_HAVE_INT128_T
TEST(Bits, divTestInt64) {
  // int64_t numerators always need a 128-bit reference type, so this whole
  // test is compiled only when __int128 is available.
  using Numer = int64_t;
  using Wide = __int128;
  runDivTests<Numer, int8_t, Wide>();
  runDivTests<Numer, uint8_t, Wide>();
  runDivTests<Numer, int16_t, Wide>();
  runDivTests<Numer, uint16_t, Wide>();
  runDivTests<Numer, int32_t, Wide>();
  runDivTests<Numer, uint32_t, Wide>();
  runDivTests<Numer, int64_t, Wide>();
  runDivTests<Numer, uint64_t, Wide>();
}
#endif
199TEST(Bits, divTestUint8) {
200 runDivTests<uint8_t, int8_t, int64_t>();
201 runDivTests<uint8_t, uint8_t, int64_t>();
202 runDivTests<uint8_t, int16_t, int64_t>();
203 runDivTests<uint8_t, uint16_t, int64_t>();
204 runDivTests<uint8_t, int32_t, int64_t>();
205 runDivTests<uint8_t, uint32_t, int64_t>();
206#if FOLLY_HAVE_INT128_T
207 runDivTests<uint8_t, int64_t, __int128>();
208 runDivTests<uint8_t, uint64_t, __int128>();
209#endif
210}
211TEST(Bits, divTestUint16) {
212 runDivTests<uint16_t, int8_t, int64_t>();
213 runDivTests<uint16_t, uint8_t, int64_t>();
214 runDivTests<uint16_t, int16_t, int64_t>();
215 runDivTests<uint16_t, uint16_t, int64_t>();
216 runDivTests<uint16_t, int32_t, int64_t>();
217 runDivTests<uint16_t, uint32_t, int64_t>();
218#if FOLLY_HAVE_INT128_T
219 runDivTests<uint16_t, int64_t, __int128>();
220 runDivTests<uint16_t, uint64_t, __int128>();
221#endif
222}
223TEST(Bits, divTestUint32) {
224 runDivTests<uint32_t, int8_t, int64_t>();
225 runDivTests<uint32_t, uint8_t, int64_t>();
226 runDivTests<uint32_t, int16_t, int64_t>();
227 runDivTests<uint32_t, uint16_t, int64_t>();
228 runDivTests<uint32_t, int32_t, int64_t>();
229 runDivTests<uint32_t, uint32_t, int64_t>();
230#if FOLLY_HAVE_INT128_T
231 runDivTests<uint32_t, int64_t, __int128>();
232 runDivTests<uint32_t, uint64_t, __int128>();
233#endif
234}
#if FOLLY_HAVE_INT128_T
TEST(Bits, divTestUint64) {
  // uint64_t numerators always need a 128-bit reference type, so this whole
  // test is compiled only when __int128 is available.
  using Numer = uint64_t;
  using Wide = __int128;
  runDivTests<Numer, int8_t, Wide>();
  runDivTests<Numer, uint8_t, Wide>();
  runDivTests<Numer, int16_t, Wide>();
  runDivTests<Numer, uint16_t, Wide>();
  runDivTests<Numer, int32_t, Wide>();
  runDivTests<Numer, uint32_t, Wide>();
  runDivTests<Numer, int64_t, Wide>();
  runDivTests<Numer, uint64_t, Wide>();
}
#endif
247