1#include "all.h"
2
/* MASK(w): mask with the low 8*w bits set, i.e. covering a
 * w-byte access.  It is written as 2*BIT(8w-1)-1 instead of
 * BIT(8w)-1 so that no shift by the full word width is needed;
 * for w==8 the multiplication wraps around to 0 and the -1
 * yields all ones (this assumes BIT(n), defined elsewhere, is
 * an unsigned value with only bit n set) */
#define MASK(w) (BIT(8*(w)-1)*2-1) /* must work when w==8 */
4
typedef struct Loc Loc;
typedef struct Slice Slice;
typedef struct Insert Insert;


/* insertion point for instructions created by the pass;
 * NOTE: field order matters, loadopt() builds Locs with a
 * positional compound literal */
struct Loc {
	enum {
		LRoot, /* right above the original load */
		LLoad, /* inserting a load is allowed */
		LNoLoad, /* only scalar operations allowed */
	} type;
	uint off;  /* index of the original instruction of blk
	            * before which new instructions are placed */
	Blk *blk;  /* block receiving the new instructions */
};

/* a memory region read by a load; NOTE: field order matters,
 * loadopt() builds Slices with a positional compound literal */
struct Slice {
	Ref ref;   /* address of the region */
	short sz;  /* width in bytes (as returned by loadsz()) */
	short cls; /* load class */
};

/* one entry of the insertion log: an instruction or a phi
 * created during the search, merged into its block by
 * loadopt() once all loads have been processed */
struct Insert {
	uint isphi:1;  /* entry holds new.phi instead of new.ins */
	uint num:31;   /* creation number (from inum), orders
	                * instructions logged at the same offset */
	uint bid;      /* id of the destination block */
	uint off;      /* insertion offset within that block
	                * (not set for phis) */
	union {
		Ins ins;
		struct {
			Slice m;  /* slice the phi stands for */
			Phi *p;   /* the phi itself */
		} phi;
	} new;
};
39
static Fn *curf; /* function being processed by loadopt() */
static uint inum; /* current insertion number */
static Insert *ilog; /* global insertion log */
static uint nlog; /* number of entries in the log */
44
45int
46loadsz(Ins *l)
47{
48 switch (l->op) {
49 case Oloadsb: case Oloadub: return 1;
50 case Oloadsh: case Oloaduh: return 2;
51 case Oloadsw: case Oloaduw: return 4;
52 case Oload: return KWIDE(l->cls) ? 8 : 4;
53 }
54 die("unreachable");
55}
56
57int
58storesz(Ins *s)
59{
60 switch (s->op) {
61 case Ostoreb: return 1;
62 case Ostoreh: return 2;
63 case Ostorew: case Ostores: return 4;
64 case Ostorel: case Ostored: return 8;
65 }
66 die("unreachable");
67}
68
69static Ref
70iins(int cls, int op, Ref a0, Ref a1, Loc *l)
71{
72 Insert *ist;
73
74 vgrow(&ilog, ++nlog);
75 ist = &ilog[nlog-1];
76 ist->isphi = 0;
77 ist->num = inum++;
78 ist->bid = l->blk->id;
79 ist->off = l->off;
80 ist->new.ins = (Ins){op, cls, R, {a0, a1}};
81 return ist->new.ins.to = newtmp("ld", cls, curf);
82}
83
84static void
85cast(Ref *r, int cls, Loc *l)
86{
87 int cls0;
88
89 if (rtype(*r) == RCon)
90 return;
91 assert(rtype(*r) == RTmp);
92 cls0 = curf->tmp[r->val].cls;
93 if (cls0 == cls || (cls == Kw && cls0 == Kl))
94 return;
95 if (KWIDE(cls0) < KWIDE(cls)) {
96 if (cls0 == Ks)
97 *r = iins(Kw, Ocast, *r, R, l);
98 *r = iins(Kl, Oextuw, *r, R, l);
99 if (cls == Kd)
100 *r = iins(Kd, Ocast, *r, R, l);
101 } else {
102 if (cls0 == Kd && cls != Kl)
103 *r = iins(Kl, Ocast, *r, R, l);
104 if (cls0 != Kd || cls != Kw)
105 *r = iins(cls, Ocast, *r, R, l);
106 }
107}
108
109static inline void
110mask(int cls, Ref *r, bits msk, Loc *l)
111{
112 cast(r, cls, l);
113 *r = iins(cls, Oand, *r, getcon(msk, curf), l);
114}
115
116static Ref
117load(Slice sl, bits msk, Loc *l)
118{
119 Alias *a;
120 Ref r, r1;
121 int ld, cls, all;
122 Con c;
123
124 ld = (int[]){
125 [1] = Oloadub,
126 [2] = Oloaduh,
127 [4] = Oloaduw,
128 [8] = Oload
129 }[sl.sz];
130 all = msk == MASK(sl.sz);
131 if (all)
132 cls = sl.cls;
133 else
134 cls = sl.sz > 4 ? Kl : Kw;
135 r = sl.ref;
136 /* sl.ref might not be live here,
137 * but its alias base ref will be
138 * (see killsl() below) */
139 if (rtype(r) == RTmp) {
140 a = &curf->tmp[r.val].alias;
141 switch (a->type) {
142 default:
143 die("unreachable");
144 case ALoc:
145 case AEsc:
146 case AUnk:
147 r = a->base;
148 if (!a->offset)
149 break;
150 r1 = getcon(a->offset, curf);
151 r = iins(Kl, Oadd, r, r1, l);
152 break;
153 case ACon:
154 case ASym:
155 c.type = CAddr;
156 c.label = a->label;
157 c.bits.i = a->offset;
158 c.local = 0;
159 r = newcon(&c, curf);
160 break;
161 }
162 }
163 r = iins(cls, ld, r, R, l);
164 if (!all)
165 mask(cls, &r, msk, l);
166 return r;
167}
168
169static int
170killsl(Ref r, Slice sl)
171{
172 Alias *a;
173
174 if (rtype(sl.ref) != RTmp)
175 return 0;
176 a = &curf->tmp[sl.ref.val].alias;
177 switch (a->type) {
178 default: die("unreachable");
179 case ALoc:
180 case AEsc:
181 case AUnk: return req(a->base, r);
182 case ACon:
183 case ASym: return 0;
184 }
185}
186
/* returns a ref containing the contents of the slice
 * passed as argument, all the bits set to 0 in the
 * mask argument are zeroed in the result;
 * the returned ref has an integer class when the
 * mask does not cover all the bits of the slice,
 * otherwise, it has class sl.cls
 * the procedure returns R when it fails
 *
 * b is the block scanned and i the instruction the
 * backward scan starts above (the end of b when i is
 * null); il is where new instructions are inserted and
 * whether a load may be emitted there as a fallback */
static Ref
def(Slice sl, bits msk, Blk *b, Ins *i, Loc *il)
{
	Blk *bp;
	bits msk1, msks;
	int off, cls, cls1, op, sz, ld;
	uint np, oldl, oldt;
	Ref r, r1;
	Phi *p;
	Insert *ist;
	Loc l;

	/* invariants:
	 * -1- b dominates il->blk; so we can use
	 *     temporaries of b in il->blk
	 * -2- if il->type != LNoLoad, then il->blk
	 *     postdominates the original load; so it
	 *     is safe to load in il->blk
	 * -3- if il->type != LNoLoad, then b
	 *     postdominates il->blk (and by 2, the
	 *     original load)
	 */
	assert(dom(b, il->blk));
	/* remember the log size and temporary count so a
	 * failed search can drop everything it emitted */
	oldl = nlog;
	oldt = curf->ntmp;
	if (0) {
	Load:
		/* failure exit, only reached through goto: roll
		 * back, then give up (R) unless a load may be
		 * inserted at il */
		curf->ntmp = oldt;
		nlog = oldl;
		if (il->type != LLoad)
			return R;
		return load(sl, msk, il);
	}

	if (!i)
		i = &b->ins[b->nins];
	/* integer class wide enough for the slice */
	cls = sl.sz > 4 ? Kl : Kw;
	/* mask covering every bit of the slice */
	msks = MASK(sl.sz);

	/* scan b backwards from just above i */
	while (i > b->ins) {
		--i;
		/* redefining the slice's alias base, or a call
		 * while escapes() reports the slice address as
		 * escaping, stops the search */
		if (killsl(i->to, sl)
		|| (i->op == Ocall && escapes(sl.ref, curf)))
			goto Load;
		ld = isload(i->op);
		if (ld) {
			/* a load provides the value it read */
			sz = loadsz(i);
			r1 = i->arg[0];
			r = i->to;
		} else if (isstore(i->op)) {
			/* a store provides the value it wrote */
			sz = storesz(i);
			r1 = i->arg[1];
			r = i->arg[0];
		} else
			continue;
		switch (alias(sl.ref, sl.sz, r1, sz, &off, curf)) {
		case MustAlias:
			/* msk1: bits of the slice covered by this
			 * access; r gets shifted by 8*|off| bits
			 * (direction given by the sign of off) so
			 * its bytes line up with the slice */
			if (off < 0) {
				off = -off;
				msk1 = (MASK(sz) << 8*off) & msks;
				op = Oshl;
			} else {
				msk1 = (MASK(sz) >> 8*off) & msks;
				op = Oshr;
			}
			/* nothing requested comes from this access,
			 * keep scanning */
			if ((msk1 & msk) == 0)
				break;
			if (off) {
				cls1 = cls;
				if (op == Oshr && off + sl.sz > 4)
					cls1 = Kl;
				cast(&r, cls1, il);
				r1 = getcon(8*off, curf);
				r = iins(cls1, op, r, r1, il);
			}
			/* clear bits that are not requested or
			 * not provided by this access */
			if ((msk1 & msk) != msk1 || off + sz < sl.sz)
				mask(cls, &r, msk1 & msk, il);
			/* requested bits this access does not cover
			 * are searched above i and or-ed in */
			if ((msk & ~msk1) != 0) {
				r1 = def(sl, msk & ~msk1, b, i, il);
				if (req(r1, R))
					goto Load;
				r = iins(cls, Oor, r, r1, il);
			}
			if (msk == msks)
				cast(&r, sl.cls, il);
			return r;
		case MayAlias:
			/* a load does not write memory, keep going;
			 * a store that may alias forces a give-up */
			if (ld)
				break;
			else
				goto Load;
		case NoAlias:
			break;
		default:
			die("unreachable");
		}
	}

	/* reuse a phi already logged for this slice in b */
	for (ist=ilog; ist<&ilog[nlog]; ++ist)
		if (ist->isphi && ist->bid == b->id)
		if (req(ist->new.phi.m.ref, sl.ref))
		if (ist->new.phi.m.sz == sl.sz) {
			r = ist->new.phi.p->to;
			if (msk != msks)
				mask(cls, &r, msk, il);
			else
				cast(&r, sl.cls, il);
			return r;
		}

	for (p=b->phi; p; p=p->link)
		if (killsl(p->to, sl))
			/* scanning predecessors in that
			 * case would be unsafe */
			goto Load;

	/* no predecessor to continue in */
	if (b->npred == 0)
		goto Load;
	/* single predecessor: continue the search there;
	 * loads stay allowed only when bp has no second
	 * successor */
	if (b->npred == 1) {
		bp = b->pred[0];
		assert(bp->loop >= il->blk->loop);
		l = *il;
		if (bp->s2)
			l.type = LNoLoad;
		r1 = def(sl, msk, bp, 0, &l);
		if (req(r1, R))
			goto Load;
		return r1;
	}

	/* several predecessors: create a phi in b whose
	 * arguments hold the whole slice (msks) as found at
	 * the end of each predecessor; the phi is logged
	 * before recursing so that recursive searches reaching
	 * b again reuse it through the loop above */
	r = newtmp("ld", sl.cls, curf);
	p = alloc(sizeof *p);
	vgrow(&ilog, ++nlog);
	ist = &ilog[nlog-1];
	ist->isphi = 1;
	ist->bid = b->id;
	ist->new.phi.m = sl;
	ist->new.phi.p = p;
	p->to = r;
	p->cls = sl.cls;
	p->narg = b->npred;
	p->arg = vnew(p->narg, sizeof p->arg[0], Pfn);
	p->blk = vnew(p->narg, sizeof p->blk[0], Pfn);
	for (np=0; np<b->npred; ++np) {
		bp = b->pred[np];
		/* a load may be inserted at the end of bp only
		 * if bp has a single successor, il allows loads,
		 * and bp->loop < il->blk->loop (loop is
		 * presumably a loop-nesting measure, confirm) */
		if (!bp->s2
		&& il->type != LNoLoad
		&& bp->loop < il->blk->loop)
			l.type = LLoad;
		else
			l.type = LNoLoad;
		l.blk = bp;
		l.off = bp->nins;
		r1 = def(sl, msks, bp, 0, &l);
		if (req(r1, R))
			goto Load;
		p->arg[np] = r1;
		p->blk[np] = bp;
	}
	if (msk != msks)
		mask(cls, &r, msk, il);
	return r;
}
357
358static int
359icmp(const void *pa, const void *pb)
360{
361 Insert *a, *b;
362 int c;
363
364 a = (Insert *)pa;
365 b = (Insert *)pb;
366 if ((c = a->bid - b->bid))
367 return c;
368 if (a->isphi && b->isphi)
369 return 0;
370 if (a->isphi)
371 return -1;
372 if (b->isphi)
373 return +1;
374 if ((c = a->off - b->off))
375 return c;
376 return a->num - b->num;
377}
378
/* require rpo ssa alias */
/* load elimination: for every load of fn, def() looks for an
 * equivalent value computed from earlier loads and stores; the
 * instructions and phis it creates are recorded in ilog, then
 * merged into the blocks, and every load that obtained a value
 * is rewritten into an extension or a copy of that value.
 * the "require" line above lists what the pass depends on
 * (presumably rpo numbering, ssa form and alias information;
 * confirm against the pass driver) */
void
loadopt(Fn *fn)
{
	Ins *i, *ib;
	Blk *b;
	int sz;
	uint n, ni, ext, nt;
	Insert *ist;
	Slice sl;
	Loc l;

	curf = fn;
	ilog = vnew(0, sizeof ilog[0], Pheap);
	nlog = 0;
	inum = 0;
	/* first pass: look for a replacement value for each load;
	 * the result (R when none) is kept in the load's second
	 * argument */
	for (b=fn->start; b; b=b->link)
		for (i=b->ins; i<&b->ins[b->nins]; ++i) {
			if (!isload(i->op))
				continue;
			sz = loadsz(i);
			sl = (Slice){i->arg[0], sz, i->cls};
			l = (Loc){LRoot, i-b->ins, b};
			i->arg[1] = def(sl, MASK(sz), b, i, &l);
		}
	/* group log entries by block, phis first, then by
	 * insertion offset and creation order (see icmp()) */
	qsort(ilog, nlog, sizeof ilog[0], icmp);
	vgrow(&ilog, nlog+1);
	/* sentinel: bid == fn->nblk matches no block below, so
	 * the merge loops stop on it; only bid is set, the other
	 * fields are never read because of the && short-circuits */
	ilog[nlog].bid = fn->nblk; /* add a sentinel */
	ib = vnew(0, sizeof(Ins), Pheap);
	/* second pass: rebuild every block; log entries are matched
	 * on bid == n while walking fn->rpo[n], which assumes that
	 * rpo[n]->id == n */
	for (ist=ilog, n=0; n<fn->nblk; ++n) {
		b = fn->rpo[n];
		/* prepend the new phis of b to its phi list */
		for (; ist->bid == n && ist->isphi; ++ist) {
			ist->new.phi.p->link = b->phi;
			b->phi = ist->new.phi.p;
		}
		ni = 0;
		nt = 0;
		/* logged instructions with off == ni are emitted before
		 * original instruction ni; the log is checked before the
		 * end-of-block test so entries with off == b->nins are
		 * appended at the end */
		for (;;) {
			if (ist->bid == n && ist->off == ni)
				i = &ist++->new.ins;
			else {
				if (ni == b->nins)
					break;
				i = &b->ins[ni++];
				if (isload(i->op)
				&& !req(i->arg[1], R)) {
					/* turn the load into an extension or a
					 * copy of the value def() found; this
					 * relies on the Oext* opcodes following
					 * the same order as the Oload* opcodes */
					ext = Oextsb + i->op - Oloadsb;
					switch (i->op) {
					default:
						die("unreachable");
					case Oloadsb:
					case Oloadub:
					case Oloadsh:
					case Oloaduh:
						i->op = ext;
						break;
					case Oloadsw:
					case Oloaduw:
						if (i->cls == Kl) {
							i->op = ext;
							break;
						}
						/* fall through */
					case Oload:
						i->op = Ocopy;
						break;
					}
					i->arg[0] = i->arg[1];
					i->arg[1] = R;
				}
			}
			vgrow(&ib, ++nt);
			ib[nt-1] = *i;
		}
		b->nins = nt;
		idup(&b->ins, ib, nt);
	}
	vfree(ib);
	vfree(ilog);
	if (debug['M']) {
		fprintf(stderr, "\n> After load elimination:\n");
		printfn(fn, stderr);
	}
}
463