1 | /* |
 *  IXWebSocketHandshake.cpp
3 | * Author: Benjamin Sergeant |
4 | * Copyright (c) 2019 Machine Zone, Inc. All rights reserved. |
5 | */ |
6 | |
7 | #include "IXWebSocketHandshake.h" |
8 | |
9 | #include "IXHttp.h" |
10 | #include "IXSocketConnect.h" |
11 | #include "IXStrCaseCompare.h" |
12 | #include "IXUrlParser.h" |
13 | #include "IXUserAgent.h" |
14 | #include "IXWebSocketHandshakeKeyGen.h" |
15 | #include <algorithm> |
16 | #include <iostream> |
17 | #include <random> |
18 | #include <sstream> |
19 | |
20 | |
21 | namespace ix |
22 | { |
23 | WebSocketHandshake::WebSocketHandshake( |
24 | std::atomic<bool>& requestInitCancellation, |
25 | std::unique_ptr<Socket>& socket, |
26 | WebSocketPerMessageDeflatePtr& perMessageDeflate, |
27 | WebSocketPerMessageDeflateOptions& perMessageDeflateOptions, |
28 | std::atomic<bool>& enablePerMessageDeflate) |
29 | : _requestInitCancellation(requestInitCancellation) |
30 | , _socket(socket) |
31 | , _perMessageDeflate(perMessageDeflate) |
32 | , _perMessageDeflateOptions(perMessageDeflateOptions) |
33 | , _enablePerMessageDeflate(enablePerMessageDeflate) |
34 | { |
35 | } |
36 | |
37 | bool WebSocketHandshake::insensitiveStringCompare(const std::string& a, const std::string& b) |
38 | { |
39 | return CaseInsensitiveLess::cmp(a, b) == 0; |
40 | } |
41 | |
42 | std::string WebSocketHandshake::genRandomString(const int len) |
43 | { |
44 | std::string alphanum = "0123456789" |
45 | "ABCDEFGH" |
46 | "abcdefgh" ; |
47 | |
48 | std::random_device r; |
49 | std::default_random_engine e1(r()); |
50 | std::uniform_int_distribution<int> dist(0, (int) alphanum.size() - 1); |
51 | |
52 | std::string s; |
53 | s.resize(len); |
54 | |
55 | for (int i = 0; i < len; ++i) |
56 | { |
57 | int x = dist(e1); |
58 | s[i] = alphanum[x]; |
59 | } |
60 | |
61 | return s; |
62 | } |
63 | |
64 | WebSocketInitResult WebSocketHandshake::sendErrorResponse(int code, const std::string& reason) |
65 | { |
66 | std::stringstream ss; |
67 | ss << "HTTP/1.1 " ; |
68 | ss << code; |
69 | ss << " " ; |
70 | ss << reason; |
71 | ss << "\r\n" ; |
72 | ss << "Server: " << userAgent() << "\r\n" ; |
73 | |
74 | // Socket write can only be cancelled through a timeout here, not manually. |
75 | static std::atomic<bool> requestInitCancellation(false); |
76 | auto isCancellationRequested = |
77 | makeCancellationRequestWithTimeout(1, requestInitCancellation); |
78 | |
79 | if (!_socket->writeBytes(ss.str(), isCancellationRequested)) |
80 | { |
81 | return WebSocketInitResult(false, 500, "Timed out while sending error response" ); |
82 | } |
83 | |
84 | return WebSocketInitResult(false, code, reason); |
85 | } |
86 | |
87 | WebSocketInitResult WebSocketHandshake::clientHandshake( |
88 | const std::string& url, |
89 | const WebSocketHttpHeaders& , |
90 | const std::string& host, |
91 | const std::string& path, |
92 | int port, |
93 | int timeoutSecs) |
94 | { |
95 | _requestInitCancellation = false; |
96 | |
97 | auto isCancellationRequested = |
98 | makeCancellationRequestWithTimeout(timeoutSecs, _requestInitCancellation); |
99 | |
100 | std::string errMsg; |
101 | bool success = _socket->connect(host, port, errMsg, isCancellationRequested); |
102 | if (!success) |
103 | { |
104 | std::stringstream ss; |
105 | ss << "Unable to connect to " << host << " on port " << port << ", error: " << errMsg; |
106 | return WebSocketInitResult(false, 0, ss.str()); |
107 | } |
108 | |
109 | // |
110 | // Generate a random 24 bytes string which looks like it is base64 encoded |
111 | // y3JJHMbDL1EzLkh9GBhXDw== |
112 | // 0cb3Vd9HkbpVVumoS3Noka== |
113 | // |
114 | // See https://stackoverflow.com/questions/18265128/what-is-sec-websocket-key-for |
115 | // |
116 | std::string secWebSocketKey = genRandomString(22); |
117 | secWebSocketKey += "==" ; |
118 | |
119 | std::stringstream ss; |
120 | ss << "GET " << path << " HTTP/1.1\r\n" ; |
121 | ss << "Host: " << host << ":" << port << "\r\n" ; |
122 | ss << "Upgrade: websocket\r\n" ; |
123 | ss << "Connection: Upgrade\r\n" ; |
124 | ss << "Sec-WebSocket-Version: 13\r\n" ; |
125 | ss << "Sec-WebSocket-Key: " << secWebSocketKey << "\r\n" ; |
126 | |
127 | // User-Agent can be customized by users |
128 | if (extraHeaders.find("User-Agent" ) == extraHeaders.end()) |
129 | { |
130 | ss << "User-Agent: " << userAgent() << "\r\n" ; |
131 | } |
132 | |
133 | for (auto& it : extraHeaders) |
134 | { |
135 | ss << it.first << ": " << it.second << "\r\n" ; |
136 | } |
137 | |
138 | if (_enablePerMessageDeflate) |
139 | { |
140 | ss << _perMessageDeflateOptions.generateHeader(); |
141 | } |
142 | |
143 | ss << "\r\n" ; |
144 | |
145 | if (!_socket->writeBytes(ss.str(), isCancellationRequested)) |
146 | { |
147 | return WebSocketInitResult( |
148 | false, 0, std::string("Failed sending GET request to " ) + url); |
149 | } |
150 | |
151 | // Read HTTP status line |
152 | auto lineResult = _socket->readLine(isCancellationRequested); |
153 | auto lineValid = lineResult.first; |
154 | auto line = lineResult.second; |
155 | |
156 | if (!lineValid) |
157 | { |
158 | return WebSocketInitResult( |
159 | false, 0, std::string("Failed reading HTTP status line from " ) + url); |
160 | } |
161 | |
162 | // Validate status |
163 | auto statusLine = Http::parseStatusLine(line); |
164 | std::string httpVersion = statusLine.first; |
165 | int status = statusLine.second; |
166 | |
167 | // HTTP/1.0 is too old. |
168 | if (httpVersion != "HTTP/1.1" ) |
169 | { |
170 | std::stringstream ss; |
171 | ss << "Expecting HTTP/1.1, got " << httpVersion << ". " |
172 | << "Rejecting connection to " << url << ", status: " << status |
173 | << ", HTTP Status line: " << line; |
174 | return WebSocketInitResult(false, status, ss.str()); |
175 | } |
176 | |
177 | auto result = parseHttpHeaders(_socket, isCancellationRequested); |
178 | auto = result.first; |
179 | auto = result.second; |
180 | |
181 | if (!headersValid) |
182 | { |
183 | return WebSocketInitResult(false, status, "Error parsing HTTP headers" ); |
184 | } |
185 | |
186 | // We want an 101 HTTP status for websocket, otherwise it could be |
187 | // a redirection (like 301) |
188 | if (status != 101) |
189 | { |
190 | std::stringstream ss; |
191 | ss << "Expecting status 101 (Switching Protocol), got " << status |
192 | << " status connecting to " << url << ", HTTP Status line: " << line; |
193 | |
194 | return WebSocketInitResult(false, status, ss.str(), headers, path); |
195 | } |
196 | |
197 | // Check the presence of the connection field |
198 | if (headers.find("connection" ) == headers.end()) |
199 | { |
200 | std::string errorMsg("Missing connection value" ); |
201 | return WebSocketInitResult(false, status, errorMsg); |
202 | } |
203 | |
204 | // Check the value of the connection field |
205 | // Some websocket servers (Go/Gorilla?) send lowercase values for the |
206 | // connection header, so do a case insensitive comparison |
207 | // |
208 | // See https://github.com/apache/thrift/commit/7c4bdf9914fcba6c89e0f69ae48b9675578f084a |
209 | // |
210 | if (!insensitiveStringCompare(headers["connection" ], "Upgrade" )) |
211 | { |
212 | std::stringstream ss; |
213 | ss << "Invalid connection value: " << headers["connection" ]; |
214 | return WebSocketInitResult(false, status, ss.str()); |
215 | } |
216 | |
217 | char output[29] = {}; |
218 | WebSocketHandshakeKeyGen::generate(secWebSocketKey, output); |
219 | if (std::string(output) != headers["sec-websocket-accept" ]) |
220 | { |
221 | std::string errorMsg("Invalid Sec-WebSocket-Accept value" ); |
222 | return WebSocketInitResult(false, status, errorMsg); |
223 | } |
224 | |
225 | if (_enablePerMessageDeflate) |
226 | { |
227 | // Parse the server response. Does it support deflate ? |
228 | std::string = headers["sec-websocket-extensions" ]; |
229 | WebSocketPerMessageDeflateOptions webSocketPerMessageDeflateOptions(header); |
230 | |
231 | // If the server does not support that extension, disable it. |
232 | if (!webSocketPerMessageDeflateOptions.enabled()) |
233 | { |
234 | _enablePerMessageDeflate = false; |
235 | } |
236 | // Otherwise try to initialize the deflate engine (zlib) |
237 | else if (!_perMessageDeflate->init(webSocketPerMessageDeflateOptions)) |
238 | { |
239 | return WebSocketInitResult( |
240 | false, 0, "Failed to initialize per message deflate engine" ); |
241 | } |
242 | } |
243 | |
244 | return WebSocketInitResult(true, status, "" , headers, path); |
245 | } |
246 | |
247 | WebSocketInitResult WebSocketHandshake::serverHandshake(int timeoutSecs, |
248 | bool enablePerMessageDeflate) |
249 | { |
250 | _requestInitCancellation = false; |
251 | |
252 | auto isCancellationRequested = |
253 | makeCancellationRequestWithTimeout(timeoutSecs, _requestInitCancellation); |
254 | |
255 | // Read first line |
256 | auto lineResult = _socket->readLine(isCancellationRequested); |
257 | auto lineValid = lineResult.first; |
258 | auto line = lineResult.second; |
259 | |
260 | if (!lineValid) |
261 | { |
262 | return sendErrorResponse(400, "Error reading HTTP request line" ); |
263 | } |
264 | |
265 | // Validate request line (GET /foo HTTP/1.1\r\n) |
266 | auto requestLine = Http::parseRequestLine(line); |
267 | auto method = std::get<0>(requestLine); |
268 | auto uri = std::get<1>(requestLine); |
269 | auto httpVersion = std::get<2>(requestLine); |
270 | |
271 | if (method != "GET" ) |
272 | { |
273 | return sendErrorResponse(400, "Invalid HTTP method, need GET, got " + method); |
274 | } |
275 | |
276 | if (httpVersion != "HTTP/1.1" ) |
277 | { |
278 | return sendErrorResponse(400, |
279 | "Invalid HTTP version, need HTTP/1.1, got: " + httpVersion); |
280 | } |
281 | |
282 | // Retrieve and validate HTTP headers |
283 | auto result = parseHttpHeaders(_socket, isCancellationRequested); |
284 | auto = result.first; |
285 | auto = result.second; |
286 | |
287 | if (!headersValid) |
288 | { |
289 | return sendErrorResponse(400, "Error parsing HTTP headers" ); |
290 | } |
291 | |
292 | if (headers.find("sec-websocket-key" ) == headers.end()) |
293 | { |
294 | return sendErrorResponse(400, "Missing Sec-WebSocket-Key value" ); |
295 | } |
296 | |
297 | if (headers.find("upgrade" ) == headers.end()) |
298 | { |
299 | return sendErrorResponse(400, "Missing Upgrade header" ); |
300 | } |
301 | |
302 | if (!insensitiveStringCompare(headers["upgrade" ], "WebSocket" ) && |
303 | headers["Upgrade" ] != "keep-alive, Upgrade" ) // special case for firefox |
304 | { |
305 | return sendErrorResponse(400, |
306 | "Invalid Upgrade header, " |
307 | "need WebSocket, got " + |
308 | headers["upgrade" ]); |
309 | } |
310 | |
311 | if (headers.find("sec-websocket-version" ) == headers.end()) |
312 | { |
313 | return sendErrorResponse(400, "Missing Sec-WebSocket-Version value" ); |
314 | } |
315 | |
316 | { |
317 | std::stringstream ss; |
318 | ss << headers["sec-websocket-version" ]; |
319 | int version; |
320 | ss >> version; |
321 | |
322 | if (version != 13) |
323 | { |
324 | return sendErrorResponse(400, |
325 | "Invalid Sec-WebSocket-Version, " |
326 | "need 13, got " + |
327 | ss.str()); |
328 | } |
329 | } |
330 | |
331 | char output[29] = {}; |
332 | WebSocketHandshakeKeyGen::generate(headers["sec-websocket-key" ], output); |
333 | |
334 | std::stringstream ss; |
335 | ss << "HTTP/1.1 101 Switching Protocols\r\n" ; |
336 | ss << "Sec-WebSocket-Accept: " << std::string(output) << "\r\n" ; |
337 | ss << "Upgrade: websocket\r\n" ; |
338 | ss << "Connection: Upgrade\r\n" ; |
339 | ss << "Server: " << userAgent() << "\r\n" ; |
340 | |
341 | // Parse the client headers. Does it support deflate ? |
342 | std::string = headers["sec-websocket-extensions" ]; |
343 | WebSocketPerMessageDeflateOptions webSocketPerMessageDeflateOptions(header); |
344 | |
345 | // If the client has requested that extension, |
346 | if (webSocketPerMessageDeflateOptions.enabled() && enablePerMessageDeflate) |
347 | { |
348 | _enablePerMessageDeflate = true; |
349 | |
350 | if (!_perMessageDeflate->init(webSocketPerMessageDeflateOptions)) |
351 | { |
352 | return WebSocketInitResult( |
353 | false, 0, "Failed to initialize per message deflate engine" ); |
354 | } |
355 | ss << webSocketPerMessageDeflateOptions.generateHeader(); |
356 | } |
357 | |
358 | ss << "\r\n" ; |
359 | |
360 | if (!_socket->writeBytes(ss.str(), isCancellationRequested)) |
361 | { |
362 | return WebSocketInitResult( |
363 | false, 0, std::string("Failed sending response to remote end" )); |
364 | } |
365 | |
366 | return WebSocketInitResult(true, 200, "" , headers, uri); |
367 | } |
368 | } // namespace ix |
369 | |