1//
2// WebSocket.cpp
3//
4// Library: Net
5// Package: WebSocket
6// Module: WebSocket
7//
8// Copyright (c) 2012, Applied Informatics Software Engineering GmbH.
9// and Contributors.
10//
11// SPDX-License-Identifier: BSL-1.0
12//
13
14
15#include "Poco/Net/WebSocket.h"
16#include "Poco/Net/WebSocketImpl.h"
17#include "Poco/Net/HTTPServerRequestImpl.h"
18#include "Poco/Net/HTTPServerResponse.h"
19#include "Poco/Net/HTTPClientSession.h"
20#include "Poco/Net/HTTPServerSession.h"
21#include "Poco/Net/NetException.h"
22#include "Poco/MemoryStream.h"
23#include "Poco/NullStream.h"
24#include "Poco/BinaryWriter.h"
25#include "Poco/SHA1Engine.h"
26#include "Poco/Base64Encoder.h"
27#include "Poco/String.h"
28#include "Poco/Random.h"
29#include "Poco/StreamCopier.h"
30#include <sstream>
31
32
33namespace Poco {
34namespace Net {
35
36
37const std::string WebSocket::WEBSOCKET_GUID("258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
38const std::string WebSocket::WEBSOCKET_VERSION("13");
39HTTPCredentials WebSocket::_defaultCreds;
40
41
42WebSocket::WebSocket(HTTPServerRequest& request, HTTPServerResponse& response):
43 StreamSocket(accept(request, response))
44{
45}
46
47
48WebSocket::WebSocket(HTTPClientSession& cs, HTTPRequest& request, HTTPResponse& response):
49 StreamSocket(connect(cs, request, response, _defaultCreds))
50{
51}
52
53
54WebSocket::WebSocket(HTTPClientSession& cs, HTTPRequest& request, HTTPResponse& response, HTTPCredentials& credentials):
55 StreamSocket(connect(cs, request, response, credentials))
56{
57}
58
59
60WebSocket::WebSocket(const Socket& socket):
61 StreamSocket(socket)
62{
63 if (!dynamic_cast<WebSocketImpl*>(impl()))
64 throw InvalidArgumentException("Cannot assign incompatible socket");
65}
66
67
68WebSocket::~WebSocket()
69{
70}
71
72
73WebSocket& WebSocket::operator = (const Socket& socket)
74{
75 if (dynamic_cast<WebSocketImpl*>(socket.impl()))
76 Socket::operator = (socket);
77 else
78 throw InvalidArgumentException("Cannot assign incompatible socket");
79 return *this;
80}
81
82
83void WebSocket::shutdown()
84{
85 shutdown(WS_NORMAL_CLOSE);
86}
87
88
89void WebSocket::shutdown(Poco::UInt16 statusCode, const std::string& statusMessage)
90{
91 Poco::Buffer<char> buffer(statusMessage.size() + 2);
92 Poco::MemoryOutputStream ostr(buffer.begin(), buffer.size());
93 Poco::BinaryWriter writer(ostr, Poco::BinaryWriter::NETWORK_BYTE_ORDER);
94 writer << statusCode;
95 writer.writeRaw(statusMessage);
96 sendFrame(buffer.begin(), static_cast<int>(ostr.charsWritten()), FRAME_FLAG_FIN | FRAME_OP_CLOSE);
97}
98
99
100int WebSocket::sendFrame(const void* buffer, int length, int flags)
101{
102 flags |= FRAME_OP_SETRAW;
103 return static_cast<WebSocketImpl*>(impl())->sendBytes(buffer, length, flags);
104}
105
106
107int WebSocket::receiveFrame(void* buffer, int length, int& flags)
108{
109 int n = static_cast<WebSocketImpl*>(impl())->receiveBytes(buffer, length, 0);
110 flags = static_cast<WebSocketImpl*>(impl())->frameFlags();
111 return n;
112}
113
114
115int WebSocket::receiveFrame(Poco::Buffer<char>& buffer, int& flags)
116{
117 int n = static_cast<WebSocketImpl*>(impl())->receiveBytes(buffer, 0);
118 flags = static_cast<WebSocketImpl*>(impl())->frameFlags();
119 return n;
120}
121
122
123WebSocket::Mode WebSocket::mode() const
124{
125 return static_cast<WebSocketImpl*>(impl())->mustMaskPayload() ? WS_CLIENT : WS_SERVER;
126}
127
128
129void WebSocket::setMaxPayloadSize(int maxPayloadSize)
130{
131 static_cast<WebSocketImpl*>(impl())->setMaxPayloadSize(maxPayloadSize);
132}
133
134
135int WebSocket::getMaxPayloadSize() const
136{
137 return static_cast<WebSocketImpl*>(impl())->getMaxPayloadSize();
138}
139
140
141WebSocketImpl* WebSocket::accept(HTTPServerRequest& request, HTTPServerResponse& response)
142{
143 if (request.hasToken("Connection", "upgrade") && icompare(request.get("Upgrade", ""), "websocket") == 0)
144 {
145 std::string version = request.get("Sec-WebSocket-Version", "");
146 if (version.empty()) throw WebSocketException("Missing Sec-WebSocket-Version in handshake request", WS_ERR_HANDSHAKE_NO_VERSION);
147 if (version != WEBSOCKET_VERSION) throw WebSocketException("Unsupported WebSocket version requested", version, WS_ERR_HANDSHAKE_UNSUPPORTED_VERSION);
148 std::string key = request.get("Sec-WebSocket-Key", "");
149 Poco::trimInPlace(key);
150 if (key.empty()) throw WebSocketException("Missing Sec-WebSocket-Key in handshake request", WS_ERR_HANDSHAKE_NO_KEY);
151
152 response.setStatusAndReason(HTTPResponse::HTTP_SWITCHING_PROTOCOLS);
153 response.set("Upgrade", "websocket");
154 response.set("Connection", "Upgrade");
155 response.set("Sec-WebSocket-Accept", computeAccept(key));
156 response.setContentLength(HTTPResponse::UNKNOWN_CONTENT_LENGTH);
157 response.send().flush();
158
159 HTTPServerRequestImpl& requestImpl = static_cast<HTTPServerRequestImpl&>(request);
160 return new WebSocketImpl(static_cast<StreamSocketImpl*>(requestImpl.detachSocket().impl()), requestImpl.session(), false);
161 }
162 else throw WebSocketException("No WebSocket handshake", WS_ERR_NO_HANDSHAKE);
163}
164
165
166WebSocketImpl* WebSocket::connect(HTTPClientSession& cs, HTTPRequest& request, HTTPResponse& response, HTTPCredentials& credentials)
167{
168 if (!cs.getProxyHost().empty() && !cs.secure())
169 {
170 cs.proxyTunnel();
171 }
172 std::string key = createKey();
173 request.set("Connection", "Upgrade");
174 request.set("Upgrade", "websocket");
175 request.set("Sec-WebSocket-Version", WEBSOCKET_VERSION);
176 request.set("Sec-WebSocket-Key", key);
177 request.setChunkedTransferEncoding(false);
178 cs.setKeepAlive(true);
179 cs.sendRequest(request);
180 std::istream& istr = cs.receiveResponse(response);
181 if (response.getStatus() == HTTPResponse::HTTP_SWITCHING_PROTOCOLS)
182 {
183 return completeHandshake(cs, response, key);
184 }
185 else if (response.getStatus() == HTTPResponse::HTTP_UNAUTHORIZED)
186 {
187 Poco::NullOutputStream null;
188 Poco::StreamCopier::copyStream(istr, null);
189 credentials.authenticate(request, response);
190 if (!cs.getProxyHost().empty() && !cs.secure())
191 {
192 cs.reset();
193 cs.proxyTunnel();
194 }
195 cs.sendRequest(request);
196 cs.receiveResponse(response);
197 if (response.getStatus() == HTTPResponse::HTTP_SWITCHING_PROTOCOLS)
198 {
199 return completeHandshake(cs, response, key);
200 }
201 else if (response.getStatus() == HTTPResponse::HTTP_UNAUTHORIZED)
202 {
203 throw WebSocketException("Not authorized", WS_ERR_UNAUTHORIZED);
204 }
205 }
206 if (response.getStatus() == HTTPResponse::HTTP_OK)
207 {
208 throw WebSocketException("The server does not understand the WebSocket protocol", WS_ERR_NO_HANDSHAKE);
209 }
210 else
211 {
212 throw WebSocketException("Cannot upgrade to WebSocket connection", response.getReason(), WS_ERR_NO_HANDSHAKE);
213 }
214}
215
216
217WebSocketImpl* WebSocket::completeHandshake(HTTPClientSession& cs, HTTPResponse& response, const std::string& key)
218{
219 std::string connection = response.get("Connection", "");
220 if (Poco::icompare(connection, "Upgrade") != 0)
221 throw WebSocketException("No Connection: Upgrade header in handshake response", WS_ERR_NO_HANDSHAKE);
222 std::string upgrade = response.get("Upgrade", "");
223 if (Poco::icompare(upgrade, "websocket") != 0)
224 throw WebSocketException("No Upgrade: websocket header in handshake response", WS_ERR_NO_HANDSHAKE);
225 std::string accept = response.get("Sec-WebSocket-Accept", "");
226 if (accept != computeAccept(key))
227 throw WebSocketException("Invalid or missing Sec-WebSocket-Accept header in handshake response", WS_ERR_HANDSHAKE_ACCEPT);
228 return new WebSocketImpl(static_cast<StreamSocketImpl*>(cs.detachSocket().impl()), cs, true);
229}
230
231
232std::string WebSocket::createKey()
233{
234 Poco::Random rnd;
235 std::ostringstream ostr;
236 Poco::Base64Encoder base64(ostr);
237 Poco::BinaryWriter writer(base64);
238 writer << rnd.next() << rnd.next() << rnd.next() << rnd.next();
239 base64.close();
240 return ostr.str();
241}
242
243
244std::string WebSocket::computeAccept(const std::string& key)
245{
246 std::string accept(key);
247 accept += WEBSOCKET_GUID;
248 Poco::SHA1Engine sha1;
249 sha1.update(accept);
250 Poco::DigestEngine::Digest d = sha1.digest();
251 std::ostringstream ostr;
252 Poco::Base64Encoder base64(ostr);
253 base64.write(reinterpret_cast<const char*>(&d[0]), d.size());
254 base64.close();
255 return ostr.str();
256}
257
258
259} } // namespace Poco::Net
260