1// Copyright 2016 The SwiftShader Authors. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include "ValidateLimitations.h"
16#include "InfoSink.h"
17#include "InitializeParseContext.h"
18#include "ParseHelper.h"
19
20namespace {
21bool IsLoopIndex(const TIntermSymbol* symbol, const TLoopStack& stack) {
22 for (TLoopStack::const_iterator i = stack.begin(); i != stack.end(); ++i) {
23 if (i->index.id == symbol->getId())
24 return true;
25 }
26 return false;
27}
28
29void MarkLoopForUnroll(const TIntermSymbol* symbol, TLoopStack& stack) {
30 for (TLoopStack::iterator i = stack.begin(); i != stack.end(); ++i) {
31 if (i->index.id == symbol->getId()) {
32 ASSERT(i->loop);
33 i->loop->setUnrollFlag(true);
34 return;
35 }
36 }
37 UNREACHABLE(0);
38}
39
40// Traverses a node to check if it represents a constant index expression.
41// Definition:
42// constant-index-expressions are a superset of constant-expressions.
43// Constant-index-expressions can include loop indices as defined in
44// GLSL ES 1.0 spec, Appendix A, section 4.
45// The following are constant-index-expressions:
46// - Constant expressions
47// - Loop indices as defined in section 4
48// - Expressions composed of both of the above
49class ValidateConstIndexExpr : public TIntermTraverser {
50public:
51 ValidateConstIndexExpr(const TLoopStack& stack)
52 : mValid(true), mLoopStack(stack) {}
53
54 // Returns true if the parsed node represents a constant index expression.
55 bool isValid() const { return mValid; }
56
57 virtual void visitSymbol(TIntermSymbol* symbol) {
58 // Only constants and loop indices are allowed in a
59 // constant index expression.
60 if (mValid) {
61 mValid = (symbol->getQualifier() == EvqConstExpr) ||
62 IsLoopIndex(symbol, mLoopStack);
63 }
64 }
65
66private:
67 bool mValid;
68 const TLoopStack& mLoopStack;
69};
70
71// Traverses a node to check if it uses a loop index.
72// If an int loop index is used in its body as a sampler array index,
73// mark the loop for unroll.
74class ValidateLoopIndexExpr : public TIntermTraverser {
75public:
76 ValidateLoopIndexExpr(TLoopStack& stack)
77 : mUsesFloatLoopIndex(false),
78 mUsesIntLoopIndex(false),
79 mLoopStack(stack) {}
80
81 bool usesFloatLoopIndex() const { return mUsesFloatLoopIndex; }
82 bool usesIntLoopIndex() const { return mUsesIntLoopIndex; }
83
84 virtual void visitSymbol(TIntermSymbol* symbol) {
85 if (IsLoopIndex(symbol, mLoopStack)) {
86 switch (symbol->getBasicType()) {
87 case EbtFloat:
88 mUsesFloatLoopIndex = true;
89 break;
90 case EbtUInt:
91 mUsesIntLoopIndex = true;
92 MarkLoopForUnroll(symbol, mLoopStack);
93 break;
94 case EbtInt:
95 mUsesIntLoopIndex = true;
96 MarkLoopForUnroll(symbol, mLoopStack);
97 break;
98 default:
99 UNREACHABLE(symbol->getBasicType());
100 }
101 }
102 }
103
104private:
105 bool mUsesFloatLoopIndex;
106 bool mUsesIntLoopIndex;
107 TLoopStack& mLoopStack;
108};
109} // namespace
110
111ValidateLimitations::ValidateLimitations(GLenum shaderType,
112 TInfoSinkBase& sink)
113 : mShaderType(shaderType),
114 mSink(sink),
115 mNumErrors(0)
116{
117}
118
119bool ValidateLimitations::visitBinary(Visit, TIntermBinary* node)
120{
121 // Check if loop index is modified in the loop body.
122 validateOperation(node, node->getLeft());
123
124 // Check indexing.
125 switch (node->getOp()) {
126 case EOpIndexDirect:
127 validateIndexing(node);
128 break;
129 case EOpIndexIndirect:
130 validateIndexing(node);
131 break;
132 default: break;
133 }
134 return true;
135}
136
137bool ValidateLimitations::visitUnary(Visit, TIntermUnary* node)
138{
139 // Check if loop index is modified in the loop body.
140 validateOperation(node, node->getOperand());
141
142 return true;
143}
144
145bool ValidateLimitations::visitAggregate(Visit, TIntermAggregate* node)
146{
147 switch (node->getOp()) {
148 case EOpFunctionCall:
149 validateFunctionCall(node);
150 break;
151 default:
152 break;
153 }
154 return true;
155}
156
157bool ValidateLimitations::visitLoop(Visit, TIntermLoop* node)
158{
159 if (!validateLoopType(node))
160 return false;
161
162 TLoopInfo info;
163 memset(&info, 0, sizeof(TLoopInfo));
164 info.loop = node;
165 if (!validateForLoopHeader(node, &info))
166 return false;
167
168 TIntermNode* body = node->getBody();
169 if (body) {
170 mLoopStack.push_back(info);
171 body->traverse(this);
172 mLoopStack.pop_back();
173 }
174
175 // The loop is fully processed - no need to visit children.
176 return false;
177}
178
179void ValidateLimitations::error(TSourceLoc loc,
180 const char *reason, const char* token)
181{
182 mSink.prefix(EPrefixError);
183 mSink.location(loc);
184 mSink << "'" << token << "' : " << reason << "\n";
185 ++mNumErrors;
186}
187
188bool ValidateLimitations::withinLoopBody() const
189{
190 return !mLoopStack.empty();
191}
192
193bool ValidateLimitations::isLoopIndex(const TIntermSymbol* symbol) const
194{
195 return IsLoopIndex(symbol, mLoopStack);
196}
197
198bool ValidateLimitations::validateLoopType(TIntermLoop* node) {
199 TLoopType type = node->getType();
200 if (type == ELoopFor)
201 return true;
202
203 // Reject while and do-while loops.
204 error(node->getLine(),
205 "This type of loop is not allowed",
206 type == ELoopWhile ? "while" : "do");
207 return false;
208}
209
210bool ValidateLimitations::validateForLoopHeader(TIntermLoop* node,
211 TLoopInfo* info)
212{
213 ASSERT(node->getType() == ELoopFor);
214
215 //
216 // The for statement has the form:
217 // for ( init-declaration ; condition ; expression ) statement
218 //
219 if (!validateForLoopInit(node, info))
220 return false;
221 if (!validateForLoopCond(node, info))
222 return false;
223 if (!validateForLoopExpr(node, info))
224 return false;
225
226 return true;
227}
228
229bool ValidateLimitations::validateForLoopInit(TIntermLoop* node,
230 TLoopInfo* info)
231{
232 TIntermNode* init = node->getInit();
233 if (!init) {
234 error(node->getLine(), "Missing init declaration", "for");
235 return false;
236 }
237
238 //
239 // init-declaration has the form:
240 // type-specifier identifier = constant-expression
241 //
242 TIntermAggregate* decl = init->getAsAggregate();
243 if (!decl || (decl->getOp() != EOpDeclaration)) {
244 error(init->getLine(), "Invalid init declaration", "for");
245 return false;
246 }
247 // To keep things simple do not allow declaration list.
248 TIntermSequence& declSeq = decl->getSequence();
249 if (declSeq.size() != 1) {
250 error(decl->getLine(), "Invalid init declaration", "for");
251 return false;
252 }
253 TIntermBinary* declInit = declSeq[0]->getAsBinaryNode();
254 if (!declInit || (declInit->getOp() != EOpInitialize)) {
255 error(decl->getLine(), "Invalid init declaration", "for");
256 return false;
257 }
258 TIntermSymbol* symbol = declInit->getLeft()->getAsSymbolNode();
259 if (!symbol) {
260 error(declInit->getLine(), "Invalid init declaration", "for");
261 return false;
262 }
263 // The loop index has type int or float.
264 TBasicType type = symbol->getBasicType();
265 if (!IsInteger(type) && (type != EbtFloat)) {
266 error(symbol->getLine(),
267 "Invalid type for loop index", getBasicString(type));
268 return false;
269 }
270 // The loop index is initialized with constant expression.
271 if (!isConstExpr(declInit->getRight())) {
272 error(declInit->getLine(),
273 "Loop index cannot be initialized with non-constant expression",
274 symbol->getSymbol().c_str());
275 return false;
276 }
277
278 info->index.id = symbol->getId();
279 return true;
280}
281
282bool ValidateLimitations::validateForLoopCond(TIntermLoop* node,
283 TLoopInfo* info)
284{
285 TIntermNode* cond = node->getCondition();
286 if (!cond) {
287 error(node->getLine(), "Missing condition", "for");
288 return false;
289 }
290 //
291 // condition has the form:
292 // loop_index relational_operator constant_expression
293 //
294 TIntermBinary* binOp = cond->getAsBinaryNode();
295 if (!binOp) {
296 error(node->getLine(), "Invalid condition", "for");
297 return false;
298 }
299 // Loop index should be to the left of relational operator.
300 TIntermSymbol* symbol = binOp->getLeft()->getAsSymbolNode();
301 if (!symbol) {
302 error(binOp->getLine(), "Invalid condition", "for");
303 return false;
304 }
305 if (symbol->getId() != info->index.id) {
306 error(symbol->getLine(),
307 "Expected loop index", symbol->getSymbol().c_str());
308 return false;
309 }
310 // Relational operator is one of: > >= < <= == or !=.
311 switch (binOp->getOp()) {
312 case EOpEqual:
313 case EOpNotEqual:
314 case EOpLessThan:
315 case EOpGreaterThan:
316 case EOpLessThanEqual:
317 case EOpGreaterThanEqual:
318 break;
319 default:
320 error(binOp->getLine(),
321 "Invalid relational operator",
322 getOperatorString(binOp->getOp()));
323 break;
324 }
325 // Loop index must be compared with a constant.
326 if (!isConstExpr(binOp->getRight())) {
327 error(binOp->getLine(),
328 "Loop index cannot be compared with non-constant expression",
329 symbol->getSymbol().c_str());
330 return false;
331 }
332
333 return true;
334}
335
336bool ValidateLimitations::validateForLoopExpr(TIntermLoop* node,
337 TLoopInfo* info)
338{
339 TIntermNode* expr = node->getExpression();
340 if (!expr) {
341 error(node->getLine(), "Missing expression", "for");
342 return false;
343 }
344
345 // for expression has one of the following forms:
346 // loop_index++
347 // loop_index--
348 // loop_index += constant_expression
349 // loop_index -= constant_expression
350 // ++loop_index
351 // --loop_index
352 // The last two forms are not specified in the spec, but I am assuming
353 // its an oversight.
354 TIntermUnary* unOp = expr->getAsUnaryNode();
355 TIntermBinary* binOp = unOp ? nullptr : expr->getAsBinaryNode();
356
357 TOperator op = EOpNull;
358 TIntermSymbol* symbol = nullptr;
359 if (unOp) {
360 op = unOp->getOp();
361 symbol = unOp->getOperand()->getAsSymbolNode();
362 } else if (binOp) {
363 op = binOp->getOp();
364 symbol = binOp->getLeft()->getAsSymbolNode();
365 }
366
367 // The operand must be loop index.
368 if (!symbol) {
369 error(expr->getLine(), "Invalid expression", "for");
370 return false;
371 }
372 if (symbol->getId() != info->index.id) {
373 error(symbol->getLine(),
374 "Expected loop index", symbol->getSymbol().c_str());
375 return false;
376 }
377
378 // The operator is one of: ++ -- += -=.
379 switch (op) {
380 case EOpPostIncrement:
381 case EOpPostDecrement:
382 case EOpPreIncrement:
383 case EOpPreDecrement:
384 ASSERT((unOp != NULL) && (binOp == NULL));
385 break;
386 case EOpAddAssign:
387 case EOpSubAssign:
388 ASSERT((unOp == NULL) && (binOp != NULL));
389 break;
390 default:
391 error(expr->getLine(), "Invalid operator", getOperatorString(op));
392 return false;
393 }
394
395 // Loop index must be incremented/decremented with a constant.
396 if (binOp != NULL) {
397 if (!isConstExpr(binOp->getRight())) {
398 error(binOp->getLine(),
399 "Loop index cannot be modified by non-constant expression",
400 symbol->getSymbol().c_str());
401 return false;
402 }
403 }
404
405 return true;
406}
407
408bool ValidateLimitations::validateFunctionCall(TIntermAggregate* node)
409{
410 ASSERT(node->getOp() == EOpFunctionCall);
411
412 // If not within loop body, there is nothing to check.
413 if (!withinLoopBody())
414 return true;
415
416 // List of param indices for which loop indices are used as argument.
417 typedef std::vector<int> ParamIndex;
418 ParamIndex pIndex;
419 TIntermSequence& params = node->getSequence();
420 for (TIntermSequence::size_type i = 0; i < params.size(); ++i) {
421 TIntermSymbol* symbol = params[i]->getAsSymbolNode();
422 if (symbol && isLoopIndex(symbol))
423 pIndex.push_back(i);
424 }
425 // If none of the loop indices are used as arguments,
426 // there is nothing to check.
427 if (pIndex.empty())
428 return true;
429
430 bool valid = true;
431 TSymbolTable& symbolTable = GetGlobalParseContext()->symbolTable;
432 TSymbol* symbol = symbolTable.find(node->getName(), GetGlobalParseContext()->getShaderVersion());
433 ASSERT(symbol && symbol->isFunction());
434 TFunction* function = static_cast<TFunction*>(symbol);
435 for (ParamIndex::const_iterator i = pIndex.begin();
436 i != pIndex.end(); ++i) {
437 const TParameter& param = function->getParam(*i);
438 TQualifier qual = param.type->getQualifier();
439 if ((qual == EvqOut) || (qual == EvqInOut)) {
440 error(params[*i]->getLine(),
441 "Loop index cannot be used as argument to a function out or inout parameter",
442 params[*i]->getAsSymbolNode()->getSymbol().c_str());
443 valid = false;
444 }
445 }
446
447 return valid;
448}
449
450bool ValidateLimitations::validateOperation(TIntermOperator* node,
451 TIntermNode* operand) {
452 // Check if loop index is modified in the loop body.
453 if (!withinLoopBody() || !node->modifiesState())
454 return true;
455
456 const TIntermSymbol* symbol = operand->getAsSymbolNode();
457 if (symbol && isLoopIndex(symbol)) {
458 error(node->getLine(),
459 "Loop index cannot be statically assigned to within the body of the loop",
460 symbol->getSymbol().c_str());
461 }
462 return true;
463}
464
465bool ValidateLimitations::isConstExpr(TIntermNode* node)
466{
467 ASSERT(node);
468 return node->getAsConstantUnion() != nullptr;
469}
470
471bool ValidateLimitations::isConstIndexExpr(TIntermNode* node)
472{
473 ASSERT(node);
474
475 ValidateConstIndexExpr validate(mLoopStack);
476 node->traverse(&validate);
477 return validate.isValid();
478}
479
480bool ValidateLimitations::validateIndexing(TIntermBinary* node)
481{
482 ASSERT((node->getOp() == EOpIndexDirect) ||
483 (node->getOp() == EOpIndexIndirect));
484
485 bool valid = true;
486 TIntermTyped* index = node->getRight();
487 // The index expression must have integral type.
488 if (!index->isScalarInt()) {
489 error(index->getLine(),
490 "Index expression must have integral type",
491 index->getCompleteString().c_str());
492 valid = false;
493 }
494 // The index expession must be a constant-index-expression unless
495 // the operand is a uniform in a vertex shader.
496 TIntermTyped* operand = node->getLeft();
497 bool skip = (mShaderType == GL_VERTEX_SHADER) &&
498 (operand->getQualifier() == EvqUniform);
499 if (!skip && !isConstIndexExpr(index)) {
500 error(index->getLine(), "Index expression must be constant", "[]");
501 valid = false;
502 }
503 return valid;
504}
505
506