1 | // |
2 | // WebSocket.cpp |
3 | // |
4 | // Library: Net |
5 | // Package: WebSocket |
6 | // Module: WebSocket |
7 | // |
8 | // Copyright (c) 2012, Applied Informatics Software Engineering GmbH. |
9 | // and Contributors. |
10 | // |
11 | // SPDX-License-Identifier: BSL-1.0 |
12 | // |
13 | |
14 | |
15 | #include "Poco/Net/WebSocket.h" |
16 | #include "Poco/Net/WebSocketImpl.h" |
17 | #include "Poco/Net/HTTPServerRequestImpl.h" |
18 | #include "Poco/Net/HTTPServerResponse.h" |
19 | #include "Poco/Net/HTTPClientSession.h" |
20 | #include "Poco/Net/HTTPServerSession.h" |
21 | #include "Poco/Net/NetException.h" |
22 | #include "Poco/MemoryStream.h" |
23 | #include "Poco/NullStream.h" |
24 | #include "Poco/BinaryWriter.h" |
25 | #include "Poco/SHA1Engine.h" |
26 | #include "Poco/Base64Encoder.h" |
27 | #include "Poco/String.h" |
28 | #include "Poco/Random.h" |
29 | #include "Poco/StreamCopier.h" |
30 | #include <sstream> |
31 | |
32 | |
33 | namespace Poco { |
34 | namespace Net { |
35 | |
36 | |
37 | const std::string WebSocket::WEBSOCKET_GUID("258EAFA5-E914-47DA-95CA-C5AB0DC85B11" ); |
38 | const std::string WebSocket::WEBSOCKET_VERSION("13" ); |
39 | HTTPCredentials WebSocket::_defaultCreds; |
40 | |
41 | |
42 | WebSocket::WebSocket(HTTPServerRequest& request, HTTPServerResponse& response): |
43 | StreamSocket(accept(request, response)) |
44 | { |
45 | } |
46 | |
47 | |
48 | WebSocket::WebSocket(HTTPClientSession& cs, HTTPRequest& request, HTTPResponse& response): |
49 | StreamSocket(connect(cs, request, response, _defaultCreds)) |
50 | { |
51 | } |
52 | |
53 | |
54 | WebSocket::WebSocket(HTTPClientSession& cs, HTTPRequest& request, HTTPResponse& response, HTTPCredentials& credentials): |
55 | StreamSocket(connect(cs, request, response, credentials)) |
56 | { |
57 | } |
58 | |
59 | |
60 | WebSocket::WebSocket(const Socket& socket): |
61 | StreamSocket(socket) |
62 | { |
63 | if (!dynamic_cast<WebSocketImpl*>(impl())) |
64 | throw InvalidArgumentException("Cannot assign incompatible socket" ); |
65 | } |
66 | |
67 | |
68 | WebSocket::~WebSocket() |
69 | { |
70 | } |
71 | |
72 | |
73 | WebSocket& WebSocket::operator = (const Socket& socket) |
74 | { |
75 | if (dynamic_cast<WebSocketImpl*>(socket.impl())) |
76 | Socket::operator = (socket); |
77 | else |
78 | throw InvalidArgumentException("Cannot assign incompatible socket" ); |
79 | return *this; |
80 | } |
81 | |
82 | |
83 | void WebSocket::shutdown() |
84 | { |
85 | shutdown(WS_NORMAL_CLOSE); |
86 | } |
87 | |
88 | |
89 | void WebSocket::shutdown(Poco::UInt16 statusCode, const std::string& statusMessage) |
90 | { |
91 | Poco::Buffer<char> buffer(statusMessage.size() + 2); |
92 | Poco::MemoryOutputStream ostr(buffer.begin(), buffer.size()); |
93 | Poco::BinaryWriter writer(ostr, Poco::BinaryWriter::NETWORK_BYTE_ORDER); |
94 | writer << statusCode; |
95 | writer.writeRaw(statusMessage); |
96 | sendFrame(buffer.begin(), static_cast<int>(ostr.charsWritten()), FRAME_FLAG_FIN | FRAME_OP_CLOSE); |
97 | } |
98 | |
99 | |
100 | int WebSocket::sendFrame(const void* buffer, int length, int flags) |
101 | { |
102 | flags |= FRAME_OP_SETRAW; |
103 | return static_cast<WebSocketImpl*>(impl())->sendBytes(buffer, length, flags); |
104 | } |
105 | |
106 | |
107 | int WebSocket::receiveFrame(void* buffer, int length, int& flags) |
108 | { |
109 | int n = static_cast<WebSocketImpl*>(impl())->receiveBytes(buffer, length, 0); |
110 | flags = static_cast<WebSocketImpl*>(impl())->frameFlags(); |
111 | return n; |
112 | } |
113 | |
114 | |
115 | int WebSocket::receiveFrame(Poco::Buffer<char>& buffer, int& flags) |
116 | { |
117 | int n = static_cast<WebSocketImpl*>(impl())->receiveBytes(buffer, 0); |
118 | flags = static_cast<WebSocketImpl*>(impl())->frameFlags(); |
119 | return n; |
120 | } |
121 | |
122 | |
123 | WebSocket::Mode WebSocket::mode() const |
124 | { |
125 | return static_cast<WebSocketImpl*>(impl())->mustMaskPayload() ? WS_CLIENT : WS_SERVER; |
126 | } |
127 | |
128 | |
129 | void WebSocket::setMaxPayloadSize(int maxPayloadSize) |
130 | { |
131 | static_cast<WebSocketImpl*>(impl())->setMaxPayloadSize(maxPayloadSize); |
132 | } |
133 | |
134 | |
135 | int WebSocket::getMaxPayloadSize() const |
136 | { |
137 | return static_cast<WebSocketImpl*>(impl())->getMaxPayloadSize(); |
138 | } |
139 | |
140 | |
141 | WebSocketImpl* WebSocket::accept(HTTPServerRequest& request, HTTPServerResponse& response) |
142 | { |
143 | if (request.hasToken("Connection" , "upgrade" ) && icompare(request.get("Upgrade" , "" ), "websocket" ) == 0) |
144 | { |
145 | std::string version = request.get("Sec-WebSocket-Version" , "" ); |
146 | if (version.empty()) throw WebSocketException("Missing Sec-WebSocket-Version in handshake request" , WS_ERR_HANDSHAKE_NO_VERSION); |
147 | if (version != WEBSOCKET_VERSION) throw WebSocketException("Unsupported WebSocket version requested" , version, WS_ERR_HANDSHAKE_UNSUPPORTED_VERSION); |
148 | std::string key = request.get("Sec-WebSocket-Key" , "" ); |
149 | Poco::trimInPlace(key); |
150 | if (key.empty()) throw WebSocketException("Missing Sec-WebSocket-Key in handshake request" , WS_ERR_HANDSHAKE_NO_KEY); |
151 | |
152 | response.setStatusAndReason(HTTPResponse::HTTP_SWITCHING_PROTOCOLS); |
153 | response.set("Upgrade" , "websocket" ); |
154 | response.set("Connection" , "Upgrade" ); |
155 | response.set("Sec-WebSocket-Accept" , computeAccept(key)); |
156 | response.setContentLength(HTTPResponse::UNKNOWN_CONTENT_LENGTH); |
157 | response.send().flush(); |
158 | |
159 | HTTPServerRequestImpl& requestImpl = static_cast<HTTPServerRequestImpl&>(request); |
160 | return new WebSocketImpl(static_cast<StreamSocketImpl*>(requestImpl.detachSocket().impl()), requestImpl.session(), false); |
161 | } |
162 | else throw WebSocketException("No WebSocket handshake" , WS_ERR_NO_HANDSHAKE); |
163 | } |
164 | |
165 | |
166 | WebSocketImpl* WebSocket::connect(HTTPClientSession& cs, HTTPRequest& request, HTTPResponse& response, HTTPCredentials& credentials) |
167 | { |
168 | if (!cs.getProxyHost().empty() && !cs.secure()) |
169 | { |
170 | cs.proxyTunnel(); |
171 | } |
172 | std::string key = createKey(); |
173 | request.set("Connection" , "Upgrade" ); |
174 | request.set("Upgrade" , "websocket" ); |
175 | request.set("Sec-WebSocket-Version" , WEBSOCKET_VERSION); |
176 | request.set("Sec-WebSocket-Key" , key); |
177 | request.setChunkedTransferEncoding(false); |
178 | cs.setKeepAlive(true); |
179 | cs.sendRequest(request); |
180 | std::istream& istr = cs.receiveResponse(response); |
181 | if (response.getStatus() == HTTPResponse::HTTP_SWITCHING_PROTOCOLS) |
182 | { |
183 | return completeHandshake(cs, response, key); |
184 | } |
185 | else if (response.getStatus() == HTTPResponse::HTTP_UNAUTHORIZED) |
186 | { |
187 | Poco::NullOutputStream null; |
188 | Poco::StreamCopier::copyStream(istr, null); |
189 | credentials.authenticate(request, response); |
190 | if (!cs.getProxyHost().empty() && !cs.secure()) |
191 | { |
192 | cs.reset(); |
193 | cs.proxyTunnel(); |
194 | } |
195 | cs.sendRequest(request); |
196 | cs.receiveResponse(response); |
197 | if (response.getStatus() == HTTPResponse::HTTP_SWITCHING_PROTOCOLS) |
198 | { |
199 | return completeHandshake(cs, response, key); |
200 | } |
201 | else if (response.getStatus() == HTTPResponse::HTTP_UNAUTHORIZED) |
202 | { |
203 | throw WebSocketException("Not authorized" , WS_ERR_UNAUTHORIZED); |
204 | } |
205 | } |
206 | if (response.getStatus() == HTTPResponse::HTTP_OK) |
207 | { |
208 | throw WebSocketException("The server does not understand the WebSocket protocol" , WS_ERR_NO_HANDSHAKE); |
209 | } |
210 | else |
211 | { |
212 | throw WebSocketException("Cannot upgrade to WebSocket connection" , response.getReason(), WS_ERR_NO_HANDSHAKE); |
213 | } |
214 | } |
215 | |
216 | |
217 | WebSocketImpl* WebSocket::completeHandshake(HTTPClientSession& cs, HTTPResponse& response, const std::string& key) |
218 | { |
219 | std::string connection = response.get("Connection" , "" ); |
220 | if (Poco::icompare(connection, "Upgrade" ) != 0) |
221 | throw WebSocketException("No Connection: Upgrade header in handshake response" , WS_ERR_NO_HANDSHAKE); |
222 | std::string upgrade = response.get("Upgrade" , "" ); |
223 | if (Poco::icompare(upgrade, "websocket" ) != 0) |
224 | throw WebSocketException("No Upgrade: websocket header in handshake response" , WS_ERR_NO_HANDSHAKE); |
225 | std::string accept = response.get("Sec-WebSocket-Accept" , "" ); |
226 | if (accept != computeAccept(key)) |
227 | throw WebSocketException("Invalid or missing Sec-WebSocket-Accept header in handshake response" , WS_ERR_HANDSHAKE_ACCEPT); |
228 | return new WebSocketImpl(static_cast<StreamSocketImpl*>(cs.detachSocket().impl()), cs, true); |
229 | } |
230 | |
231 | |
232 | std::string WebSocket::createKey() |
233 | { |
234 | Poco::Random rnd; |
235 | std::ostringstream ostr; |
236 | Poco::Base64Encoder base64(ostr); |
237 | Poco::BinaryWriter writer(base64); |
238 | writer << rnd.next() << rnd.next() << rnd.next() << rnd.next(); |
239 | base64.close(); |
240 | return ostr.str(); |
241 | } |
242 | |
243 | |
244 | std::string WebSocket::computeAccept(const std::string& key) |
245 | { |
246 | std::string accept(key); |
247 | accept += WEBSOCKET_GUID; |
248 | Poco::SHA1Engine sha1; |
249 | sha1.update(accept); |
250 | Poco::DigestEngine::Digest d = sha1.digest(); |
251 | std::ostringstream ostr; |
252 | Poco::Base64Encoder base64(ostr); |
253 | base64.write(reinterpret_cast<const char*>(&d[0]), d.size()); |
254 | base64.close(); |
255 | return ostr.str(); |
256 | } |
257 | |
258 | |
259 | } } // namespace Poco::Net |
260 | |