1// Copyright (c) 2013-2014 Sandstorm Development Group, Inc. and contributors
2// Licensed under the MIT License:
3//
4// Permission is hereby granted, free of charge, to any person obtaining a copy
5// of this software and associated documentation files (the "Software"), to deal
6// in the Software without restriction, including without limitation the rights
7// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8// copies of the Software, and to permit persons to whom the Software is
9// furnished to do so, subject to the following conditions:
10//
11// The above copyright notice and this permission notice shall be included in
12// all copies or substantial portions of the Software.
13//
14// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20// THE SOFTWARE.
21
22#include "any.h"
23
24#include <kj/debug.h>
25
26#if !CAPNP_LITE
27#include "capability.h"
28#endif // !CAPNP_LITE
29
30namespace capnp {
31
32#if !CAPNP_LITE
33
34kj::Own<ClientHook> PipelineHook::getPipelinedCap(kj::Array<PipelineOp>&& ops) {
35 return getPipelinedCap(ops.asPtr());
36}
37
38kj::Own<ClientHook> AnyPointer::Reader::getPipelinedCap(
39 kj::ArrayPtr<const PipelineOp> ops) const {
40 _::PointerReader pointer = reader;
41
42 for (auto& op: ops) {
43 switch (op.type) {
44 case PipelineOp::Type::NOOP:
45 break;
46
47 case PipelineOp::Type::GET_POINTER_FIELD:
48 pointer = pointer.getStruct(nullptr).getPointerField(bounded(op.pointerIndex) * POINTERS);
49 break;
50 }
51 }
52
53 return pointer.getCapability();
54}
55
56AnyPointer::Pipeline AnyPointer::Pipeline::noop() {
57 auto newOps = kj::heapArray<PipelineOp>(ops.size());
58 for (auto i: kj::indices(ops)) {
59 newOps[i] = ops[i];
60 }
61 return Pipeline(hook->addRef(), kj::mv(newOps));
62}
63
64AnyPointer::Pipeline AnyPointer::Pipeline::getPointerField(uint16_t pointerIndex) {
65 auto newOps = kj::heapArray<PipelineOp>(ops.size() + 1);
66 for (auto i: kj::indices(ops)) {
67 newOps[i] = ops[i];
68 }
69 auto& newOp = newOps[ops.size()];
70 newOp.type = PipelineOp::GET_POINTER_FIELD;
71 newOp.pointerIndex = pointerIndex;
72
73 return Pipeline(hook->addRef(), kj::mv(newOps));
74}
75
76kj::Own<ClientHook> AnyPointer::Pipeline::asCap() {
77 return hook->getPipelinedCap(ops);
78}
79
80#endif // !CAPNP_LITE
81
82Equality AnyStruct::Reader::equals(AnyStruct::Reader right) const {
83 auto dataL = getDataSection();
84 size_t dataSizeL = dataL.size();
85 while(dataSizeL > 0 && dataL[dataSizeL - 1] == 0) {
86 -- dataSizeL;
87 }
88
89 auto dataR = right.getDataSection();
90 size_t dataSizeR = dataR.size();
91 while(dataSizeR > 0 && dataR[dataSizeR - 1] == 0) {
92 -- dataSizeR;
93 }
94
95 if(dataSizeL != dataSizeR) {
96 return Equality::NOT_EQUAL;
97 }
98
99 if(0 != memcmp(dataL.begin(), dataR.begin(), dataSizeL)) {
100 return Equality::NOT_EQUAL;
101 }
102
103 auto ptrsL = getPointerSection();
104 size_t ptrsSizeL = ptrsL.size();
105 while (ptrsSizeL > 0 && ptrsL[ptrsSizeL - 1].isNull()) {
106 -- ptrsSizeL;
107 }
108
109 auto ptrsR = right.getPointerSection();
110 size_t ptrsSizeR = ptrsR.size();
111 while (ptrsSizeR > 0 && ptrsR[ptrsSizeR - 1].isNull()) {
112 -- ptrsSizeR;
113 }
114
115 if(ptrsSizeL != ptrsSizeR) {
116 return Equality::NOT_EQUAL;
117 }
118
119 size_t i = 0;
120
121 auto eqResult = Equality::EQUAL;
122 for (; i < ptrsSizeL; i++) {
123 auto l = ptrsL[i];
124 auto r = ptrsR[i];
125 switch(l.equals(r)) {
126 case Equality::EQUAL:
127 break;
128 case Equality::NOT_EQUAL:
129 return Equality::NOT_EQUAL;
130 case Equality::UNKNOWN_CONTAINS_CAPS:
131 eqResult = Equality::UNKNOWN_CONTAINS_CAPS;
132 break;
133 default:
134 KJ_UNREACHABLE;
135 }
136 }
137
138 return eqResult;
139}
140
141kj::StringPtr KJ_STRINGIFY(Equality res) {
142 switch(res) {
143 case Equality::NOT_EQUAL:
144 return "NOT_EQUAL";
145 case Equality::EQUAL:
146 return "EQUAL";
147 case Equality::UNKNOWN_CONTAINS_CAPS:
148 return "UNKNOWN_CONTAINS_CAPS";
149 }
150 KJ_UNREACHABLE;
151}
152
153Equality AnyList::Reader::equals(AnyList::Reader right) const {
154 if(size() != right.size()) {
155 return Equality::NOT_EQUAL;
156 }
157
158 if (getElementSize() != right.getElementSize()) {
159 return Equality::NOT_EQUAL;
160 }
161
162 auto eqResult = Equality::EQUAL;
163 switch(getElementSize()) {
164 case ElementSize::VOID:
165 case ElementSize::BIT:
166 case ElementSize::BYTE:
167 case ElementSize::TWO_BYTES:
168 case ElementSize::FOUR_BYTES:
169 case ElementSize::EIGHT_BYTES: {
170 size_t cmpSize = getRawBytes().size();
171
172 if (getElementSize() == ElementSize::BIT && size() % 8 != 0) {
173 // The list does not end on a byte boundary. We need special handling for the final
174 // byte because we only care about the bits that are actually elements of the list.
175
176 uint8_t mask = (1 << (size() % 8)) - 1; // lowest size() bits set
177 if ((getRawBytes()[cmpSize - 1] & mask) != (right.getRawBytes()[cmpSize - 1] & mask)) {
178 return Equality::NOT_EQUAL;
179 }
180 cmpSize -= 1;
181 }
182
183 if (memcmp(getRawBytes().begin(), right.getRawBytes().begin(), cmpSize) == 0) {
184 return Equality::EQUAL;
185 } else {
186 return Equality::NOT_EQUAL;
187 }
188 }
189 case ElementSize::POINTER:
190 case ElementSize::INLINE_COMPOSITE: {
191 auto llist = as<List<AnyStruct>>();
192 auto rlist = right.as<List<AnyStruct>>();
193 for(size_t i = 0; i < size(); i++) {
194 switch(llist[i].equals(rlist[i])) {
195 case Equality::EQUAL:
196 break;
197 case Equality::NOT_EQUAL:
198 return Equality::NOT_EQUAL;
199 case Equality::UNKNOWN_CONTAINS_CAPS:
200 eqResult = Equality::UNKNOWN_CONTAINS_CAPS;
201 break;
202 default:
203 KJ_UNREACHABLE;
204 }
205 }
206 return eqResult;
207 }
208 }
209 KJ_UNREACHABLE;
210}
211
212Equality AnyPointer::Reader::equals(AnyPointer::Reader right) const {
213 if(getPointerType() != right.getPointerType()) {
214 return Equality::NOT_EQUAL;
215 }
216 switch(getPointerType()) {
217 case PointerType::NULL_:
218 return Equality::EQUAL;
219 case PointerType::STRUCT:
220 return getAs<AnyStruct>().equals(right.getAs<AnyStruct>());
221 case PointerType::LIST:
222 return getAs<AnyList>().equals(right.getAs<AnyList>());
223 case PointerType::CAPABILITY:
224 return Equality::UNKNOWN_CONTAINS_CAPS;
225 }
226 // There aren't currently any other types of pointers
227 KJ_UNREACHABLE;
228}
229
230bool AnyPointer::Reader::operator==(AnyPointer::Reader right) const {
231 switch(equals(right)) {
232 case Equality::EQUAL:
233 return true;
234 case Equality::NOT_EQUAL:
235 return false;
236 case Equality::UNKNOWN_CONTAINS_CAPS:
237 KJ_FAIL_REQUIRE(
238 "operator== cannot determine equality of capabilities; use equals() instead if you need to handle this case");
239 }
240 KJ_UNREACHABLE;
241}
242
243bool AnyStruct::Reader::operator==(AnyStruct::Reader right) const {
244 switch(equals(right)) {
245 case Equality::EQUAL:
246 return true;
247 case Equality::NOT_EQUAL:
248 return false;
249 case Equality::UNKNOWN_CONTAINS_CAPS:
250 KJ_FAIL_REQUIRE(
251 "operator== cannot determine equality of capabilities; use equals() instead if you need to handle this case");
252 }
253 KJ_UNREACHABLE;
254}
255
256bool AnyList::Reader::operator==(AnyList::Reader right) const {
257 switch(equals(right)) {
258 case Equality::EQUAL:
259 return true;
260 case Equality::NOT_EQUAL:
261 return false;
262 case Equality::UNKNOWN_CONTAINS_CAPS:
263 KJ_FAIL_REQUIRE(
264 "operator== cannot determine equality of capabilities; use equals() instead if you need to handle this case");
265 }
266 KJ_UNREACHABLE;
267}
268
269} // namespace capnp
270