1//
2// WebSocketTest.cpp
3//
4// Copyright (c) 2012, Applied Informatics Software Engineering GmbH.
5// and Contributors.
6//
7// SPDX-License-Identifier: BSL-1.0
8//
9
10
#include "WebSocketTest.h"
#include "Poco/CppUnit/TestCaller.h"
#include "Poco/CppUnit/TestSuite.h"
#include "Poco/Net/WebSocket.h"
#include "Poco/Net/SocketStream.h"
#include "Poco/Net/HTTPSClientSession.h"
#include "Poco/Net/HTTPServer.h"
#include "Poco/Net/HTTPServerParams.h"
#include "Poco/Net/HTTPRequestHandler.h"
#include "Poco/Net/HTTPRequestHandlerFactory.h"
#include "Poco/Net/HTTPServerRequest.h"
#include "Poco/Net/HTTPServerResponse.h"
#include "Poco/Net/SecureServerSocket.h"
#include "Poco/Net/NetException.h"
#include "Poco/Thread.h"
#include <iostream>
#include <memory>
#include <string>
#include <vector>

using Poco::Net::HTTPSClientSession;
using Poco::Net::HTTPRequest;
using Poco::Net::HTTPResponse;
using Poco::Net::HTTPServerRequest;
using Poco::Net::HTTPServerResponse;
using Poco::Net::SocketStream;
using Poco::Net::WebSocket;
using Poco::Net::WebSocketException;
36
37
38namespace
39{
40 class WebSocketRequestHandler: public Poco::Net::HTTPRequestHandler
41 {
42 public:
43 WebSocketRequestHandler(std::size_t bufSize = 1024): _bufSize(bufSize)
44 {
45 }
46
47 void handleRequest(HTTPServerRequest& request, HTTPServerResponse& response)
48 {
49 try
50 {
51 WebSocket ws(request, response);
52 std::unique_ptr<char[]> pBuffer(new char[_bufSize]);
53 int flags;
54 int n;
55 do
56 {
57 n = ws.receiveFrame(pBuffer.get(), static_cast<int>(_bufSize), flags);
58 if (n == 0)
59 break;
60 ws.sendFrame(pBuffer.get(), n, flags);
61 }
62 while ((flags & WebSocket::FRAME_OP_BITMASK) != WebSocket::FRAME_OP_CLOSE);
63 }
64 catch (WebSocketException& exc)
65 {
66 switch (exc.code())
67 {
68 case WebSocket::WS_ERR_HANDSHAKE_UNSUPPORTED_VERSION:
69 response.set("Sec-WebSocket-Version", WebSocket::WEBSOCKET_VERSION);
70 // fallthrough
71 case WebSocket::WS_ERR_NO_HANDSHAKE:
72 case WebSocket::WS_ERR_HANDSHAKE_NO_VERSION:
73 case WebSocket::WS_ERR_HANDSHAKE_NO_KEY:
74 response.setStatusAndReason(HTTPResponse::HTTP_BAD_REQUEST);
75 response.setContentLength(0);
76 response.send();
77 break;
78 }
79 }
80 }
81
82 private:
83 std::size_t _bufSize;
84 };
85
86 class WebSocketRequestHandlerFactory: public Poco::Net::HTTPRequestHandlerFactory
87 {
88 public:
89 WebSocketRequestHandlerFactory(std::size_t bufSize = 1024): _bufSize(bufSize)
90 {
91 }
92
93 Poco::Net::HTTPRequestHandler* createRequestHandler(const HTTPServerRequest& request)
94 {
95 return new WebSocketRequestHandler(_bufSize);
96 }
97
98 private:
99 std::size_t _bufSize;
100 };
101}
102
103
104WebSocketTest::WebSocketTest(const std::string& name): CppUnit::TestCase(name)
105{
106}
107
108
109WebSocketTest::~WebSocketTest()
110{
111}
112
113
114void WebSocketTest::testWebSocket()
115{
116 Poco::Net::SecureServerSocket ss(0);
117 Poco::Net::HTTPServer server(new WebSocketRequestHandlerFactory, ss, new Poco::Net::HTTPServerParams);
118 server.start();
119
120 Poco::Thread::sleep(200);
121
122 HTTPSClientSession cs("127.0.0.1", ss.address().port());
123 HTTPRequest request(HTTPRequest::HTTP_GET, "/ws");
124 HTTPResponse response;
125 WebSocket ws(cs, request, response);
126
127 std::string payload("x");
128 ws.sendFrame(payload.data(), (int) payload.size());
129 char buffer[1024] = {};
130 int flags;
131 int n = ws.receiveFrame(buffer, sizeof(buffer), flags);
132 assertTrue (n == payload.size());
133 assertTrue (payload.compare(0, payload.size(), buffer, 0, n) == 0);
134 assertTrue (flags == WebSocket::FRAME_TEXT);
135
136 for (int i = 2; i < 20; i++)
137 {
138 payload.assign(i, 'x');
139 ws.sendFrame(payload.data(), (int) payload.size());
140 n = ws.receiveFrame(buffer, sizeof(buffer), flags);
141 assertTrue (n == payload.size());
142 assertTrue (payload.compare(0, payload.size(), buffer, 0, n) == 0);
143 assertTrue (flags == WebSocket::FRAME_TEXT);
144 }
145
146 for (int i = 125; i < 129; i++)
147 {
148 payload.assign(i, 'x');
149 ws.sendFrame(payload.data(), (int) payload.size());
150 n = ws.receiveFrame(buffer, sizeof(buffer), flags);
151 assertTrue (n == payload.size());
152 assertTrue (payload.compare(0, payload.size(), buffer, 0, n) == 0);
153 assertTrue (flags == WebSocket::FRAME_TEXT);
154 }
155
156 payload = "Hello, world!";
157 ws.sendFrame(payload.data(), (int) payload.size());
158 n = ws.receiveFrame(buffer, sizeof(buffer), flags);
159 assertTrue (n == payload.size());
160 assertTrue (payload.compare(0, payload.size(), buffer, 0, n) == 0);
161 assertTrue (flags == WebSocket::FRAME_TEXT);
162
163 payload = "Hello, universe!";
164 ws.sendFrame(payload.data(), (int) payload.size(), WebSocket::FRAME_BINARY);
165 n = ws.receiveFrame(buffer, sizeof(buffer), flags);
166 assertTrue (n == payload.size());
167 assertTrue (payload.compare(0, payload.size(), buffer, 0, n) == 0);
168 assertTrue (flags == WebSocket::FRAME_BINARY);
169
170 ws.shutdown();
171 n = ws.receiveFrame(buffer, sizeof(buffer), flags);
172 assertTrue (n == 2);
173 assertTrue ((flags & WebSocket::FRAME_OP_BITMASK) == WebSocket::FRAME_OP_CLOSE);
174
175 server.stop();
176}
177
178
179void WebSocketTest::testWebSocketLarge()
180{
181 const int msgSize = 64000;
182
183 Poco::Net::SecureServerSocket ss(0);
184 Poco::Net::HTTPServer server(new WebSocketRequestHandlerFactory(msgSize), ss, new Poco::Net::HTTPServerParams);
185 server.start();
186
187 Poco::Thread::sleep(200);
188
189 HTTPSClientSession cs("127.0.0.1", ss.address().port());
190 HTTPRequest request(HTTPRequest::HTTP_GET, "/ws");
191 HTTPResponse response;
192 WebSocket ws(cs, request, response);
193 ws.setSendBufferSize(msgSize);
194 ws.setReceiveBufferSize(msgSize);
195 std::string payload(msgSize, 'x');
196 SocketStream sstr(ws);
197 sstr << payload;
198 sstr.flush();
199
200 char buffer[msgSize + 1] = {};
201 int flags;
202 int n = 0;
203 do
204 {
205 n += ws.receiveFrame(buffer + n, sizeof(buffer) - n, flags);
206 } while (n > 0 && n < msgSize);
207
208 assertTrue (n == payload.size());
209 assertTrue (payload.compare(0, payload.size(), buffer, 0, n) == 0);
210
211 server.stop();
212}
213
214
215void WebSocketTest::setUp()
216{
217}
218
219
220void WebSocketTest::tearDown()
221{
222}
223
224
225CppUnit::Test* WebSocketTest::suite()
226{
227 CppUnit::TestSuite* pSuite = new CppUnit::TestSuite("WebSocketTest");
228
229 CppUnit_addTest(pSuite, WebSocketTest, testWebSocket);
230 CppUnit_addTest(pSuite, WebSocketTest, testWebSocketLarge);
231
232 return pSuite;
233}
234