1/*
2** C data arithmetic.
3** Copyright (C) 2005-2014 Mike Pall. See Copyright Notice in luajit.h
4*/
5
6#include "lj_obj.h"
7
8#if LJ_HASFFI
9
10#include "lj_gc.h"
11#include "lj_err.h"
12#include "lj_tab.h"
13#include "lj_meta.h"
14#include "lj_ctype.h"
15#include "lj_cconv.h"
16#include "lj_cdata.h"
17#include "lj_carith.h"
18
19/* -- C data arithmetic --------------------------------------------------- */
20
21/* Binary operands of an operator converted to ctypes. */
22typedef struct CDArith {
23 uint8_t *p[2];
24 CType *ct[2];
25} CDArith;
26
27/* Check arguments for arithmetic metamethods. */
28static int carith_checkarg(lua_State *L, CTState *cts, CDArith *ca)
29{
30 TValue *o = L->base;
31 int ok = 1;
32 MSize i;
33 if (o+1 >= L->top)
34 lj_err_argt(L, 1, LUA_TCDATA);
35 for (i = 0; i < 2; i++, o++) {
36 if (tviscdata(o)) {
37 GCcdata *cd = cdataV(o);
38 CTypeID id = (CTypeID)cd->ctypeid;
39 CType *ct = ctype_raw(cts, id);
40 uint8_t *p = (uint8_t *)cdataptr(cd);
41 if (ctype_isptr(ct->info)) {
42 p = (uint8_t *)cdata_getptr(p, ct->size);
43 if (ctype_isref(ct->info)) ct = ctype_rawchild(cts, ct);
44 } else if (ctype_isfunc(ct->info)) {
45 p = (uint8_t *)*(void **)p;
46 ct = ctype_get(cts,
47 lj_ctype_intern(cts, CTINFO(CT_PTR, CTALIGN_PTR|id), CTSIZE_PTR));
48 }
49 if (ctype_isenum(ct->info)) ct = ctype_child(cts, ct);
50 ca->ct[i] = ct;
51 ca->p[i] = p;
52 } else if (tvisint(o)) {
53 ca->ct[i] = ctype_get(cts, CTID_INT32);
54 ca->p[i] = (uint8_t *)&o->i;
55 } else if (tvisnum(o)) {
56 ca->ct[i] = ctype_get(cts, CTID_DOUBLE);
57 ca->p[i] = (uint8_t *)&o->n;
58 } else if (tvisnil(o)) {
59 ca->ct[i] = ctype_get(cts, CTID_P_VOID);
60 ca->p[i] = (uint8_t *)0;
61 } else if (tvisstr(o)) {
62 TValue *o2 = i == 0 ? o+1 : o-1;
63 CType *ct = ctype_raw(cts, cdataV(o2)->ctypeid);
64 ca->ct[i] = NULL;
65 ca->p[i] = NULL;
66 ok = 0;
67 if (ctype_isenum(ct->info)) {
68 CTSize ofs;
69 CType *cct = lj_ctype_getfield(cts, ct, strV(o), &ofs);
70 if (cct && ctype_isconstval(cct->info)) {
71 ca->ct[i] = ctype_child(cts, cct);
72 ca->p[i] = (uint8_t *)&cct->size; /* Assumes ct does not grow. */
73 ok = 1;
74 } else {
75 ca->ct[1-i] = ct; /* Use enum to improve error message. */
76 ca->p[1-i] = NULL;
77 break;
78 }
79 }
80 } else {
81 ca->ct[i] = NULL;
82 ca->p[i] = NULL;
83 ok = 0;
84 }
85 }
86 return ok;
87}
88
89/* Pointer arithmetic. */
90static int carith_ptr(lua_State *L, CTState *cts, CDArith *ca, MMS mm)
91{
92 CType *ctp = ca->ct[0];
93 uint8_t *pp = ca->p[0];
94 ptrdiff_t idx;
95 CTSize sz;
96 CTypeID id;
97 GCcdata *cd;
98 if (ctype_isptr(ctp->info) || ctype_isrefarray(ctp->info)) {
99 if ((mm == MM_sub || mm == MM_eq || mm == MM_lt || mm == MM_le) &&
100 (ctype_isptr(ca->ct[1]->info) || ctype_isrefarray(ca->ct[1]->info))) {
101 uint8_t *pp2 = ca->p[1];
102 if (mm == MM_eq) { /* Pointer equality. Incompatible pointers are ok. */
103 setboolV(L->top-1, (pp == pp2));
104 return 1;
105 }
106 if (!lj_cconv_compatptr(cts, ctp, ca->ct[1], CCF_IGNQUAL))
107 return 0;
108 if (mm == MM_sub) { /* Pointer difference. */
109 intptr_t diff;
110 sz = lj_ctype_size(cts, ctype_cid(ctp->info)); /* Element size. */
111 if (sz == 0 || sz == CTSIZE_INVALID)
112 return 0;
113 diff = ((intptr_t)pp - (intptr_t)pp2) / (int32_t)sz;
114 /* All valid pointer differences on x64 are in (-2^47, +2^47),
115 ** which fits into a double without loss of precision.
116 */
117 setintptrV(L->top-1, (int32_t)diff);
118 return 1;
119 } else if (mm == MM_lt) { /* Pointer comparison (unsigned). */
120 setboolV(L->top-1, ((uintptr_t)pp < (uintptr_t)pp2));
121 return 1;
122 } else {
123 lua_assert(mm == MM_le);
124 setboolV(L->top-1, ((uintptr_t)pp <= (uintptr_t)pp2));
125 return 1;
126 }
127 }
128 if (!((mm == MM_add || mm == MM_sub) && ctype_isnum(ca->ct[1]->info)))
129 return 0;
130 lj_cconv_ct_ct(cts, ctype_get(cts, CTID_INT_PSZ), ca->ct[1],
131 (uint8_t *)&idx, ca->p[1], 0);
132 if (mm == MM_sub) idx = -idx;
133 } else if (mm == MM_add && ctype_isnum(ctp->info) &&
134 (ctype_isptr(ca->ct[1]->info) || ctype_isrefarray(ca->ct[1]->info))) {
135 /* Swap pointer and index. */
136 ctp = ca->ct[1]; pp = ca->p[1];
137 lj_cconv_ct_ct(cts, ctype_get(cts, CTID_INT_PSZ), ca->ct[0],
138 (uint8_t *)&idx, ca->p[0], 0);
139 } else {
140 return 0;
141 }
142 sz = lj_ctype_size(cts, ctype_cid(ctp->info)); /* Element size. */
143 if (sz == CTSIZE_INVALID)
144 return 0;
145 pp += idx*(int32_t)sz; /* Compute pointer + index. */
146 id = lj_ctype_intern(cts, CTINFO(CT_PTR, CTALIGN_PTR|ctype_cid(ctp->info)),
147 CTSIZE_PTR);
148 cd = lj_cdata_new(cts, id, CTSIZE_PTR);
149 *(uint8_t **)cdataptr(cd) = pp;
150 setcdataV(L, L->top-1, cd);
151 lj_gc_check(L);
152 return 1;
153}
154
155/* 64 bit integer arithmetic. */
156static int carith_int64(lua_State *L, CTState *cts, CDArith *ca, MMS mm)
157{
158 if (ctype_isnum(ca->ct[0]->info) && ca->ct[0]->size <= 8 &&
159 ctype_isnum(ca->ct[1]->info) && ca->ct[1]->size <= 8) {
160 CTypeID id = (((ca->ct[0]->info & CTF_UNSIGNED) && ca->ct[0]->size == 8) ||
161 ((ca->ct[1]->info & CTF_UNSIGNED) && ca->ct[1]->size == 8)) ?
162 CTID_UINT64 : CTID_INT64;
163 CType *ct = ctype_get(cts, id);
164 GCcdata *cd;
165 uint64_t u0, u1, *up;
166 lj_cconv_ct_ct(cts, ct, ca->ct[0], (uint8_t *)&u0, ca->p[0], 0);
167 if (mm != MM_unm)
168 lj_cconv_ct_ct(cts, ct, ca->ct[1], (uint8_t *)&u1, ca->p[1], 0);
169 switch (mm) {
170 case MM_eq:
171 setboolV(L->top-1, (u0 == u1));
172 return 1;
173 case MM_lt:
174 setboolV(L->top-1,
175 id == CTID_INT64 ? ((int64_t)u0 < (int64_t)u1) : (u0 < u1));
176 return 1;
177 case MM_le:
178 setboolV(L->top-1,
179 id == CTID_INT64 ? ((int64_t)u0 <= (int64_t)u1) : (u0 <= u1));
180 return 1;
181 default: break;
182 }
183 cd = lj_cdata_new(cts, id, 8);
184 up = (uint64_t *)cdataptr(cd);
185 setcdataV(L, L->top-1, cd);
186 switch (mm) {
187 case MM_add: *up = u0 + u1; break;
188 case MM_sub: *up = u0 - u1; break;
189 case MM_mul: *up = u0 * u1; break;
190 case MM_div:
191 if (id == CTID_INT64)
192 *up = (uint64_t)lj_carith_divi64((int64_t)u0, (int64_t)u1);
193 else
194 *up = lj_carith_divu64(u0, u1);
195 break;
196 case MM_mod:
197 if (id == CTID_INT64)
198 *up = (uint64_t)lj_carith_modi64((int64_t)u0, (int64_t)u1);
199 else
200 *up = lj_carith_modu64(u0, u1);
201 break;
202 case MM_pow:
203 if (id == CTID_INT64)
204 *up = (uint64_t)lj_carith_powi64((int64_t)u0, (int64_t)u1);
205 else
206 *up = lj_carith_powu64(u0, u1);
207 break;
208 case MM_unm: *up = (uint64_t)-(int64_t)u0; break;
209 default: lua_assert(0); break;
210 }
211 lj_gc_check(L);
212 return 1;
213 }
214 return 0;
215}
216
217/* Handle ctype arithmetic metamethods. */
218static int lj_carith_meta(lua_State *L, CTState *cts, CDArith *ca, MMS mm)
219{
220 cTValue *tv = NULL;
221 if (tviscdata(L->base)) {
222 CTypeID id = cdataV(L->base)->ctypeid;
223 CType *ct = ctype_raw(cts, id);
224 if (ctype_isptr(ct->info)) id = ctype_cid(ct->info);
225 tv = lj_ctype_meta(cts, id, mm);
226 }
227 if (!tv && L->base+1 < L->top && tviscdata(L->base+1)) {
228 CTypeID id = cdataV(L->base+1)->ctypeid;
229 CType *ct = ctype_raw(cts, id);
230 if (ctype_isptr(ct->info)) id = ctype_cid(ct->info);
231 tv = lj_ctype_meta(cts, id, mm);
232 }
233 if (!tv) {
234 const char *repr[2];
235 int i, isenum = -1, isstr = -1;
236 if (mm == MM_eq) { /* Equality checks never raise an error. */
237 setboolV(L->top-1, 0);
238 return 1;
239 }
240 for (i = 0; i < 2; i++) {
241 if (ca->ct[i] && tviscdata(L->base+i)) {
242 if (ctype_isenum(ca->ct[i]->info)) isenum = i;
243 repr[i] = strdata(lj_ctype_repr(L, ctype_typeid(cts, ca->ct[i]), NULL));
244 } else {
245 if (tvisstr(&L->base[i])) isstr = i;
246 repr[i] = lj_typename(&L->base[i]);
247 }
248 }
249 if ((isenum ^ isstr) == 1)
250 lj_err_callerv(L, LJ_ERR_FFI_BADCONV, repr[isstr], repr[isenum]);
251 lj_err_callerv(L, mm == MM_len ? LJ_ERR_FFI_BADLEN :
252 mm == MM_concat ? LJ_ERR_FFI_BADCONCAT :
253 mm < MM_add ? LJ_ERR_FFI_BADCOMP : LJ_ERR_FFI_BADARITH,
254 repr[0], repr[1]);
255 }
256 return lj_meta_tailcall(L, tv);
257}
258
259/* Arithmetic operators for cdata. */
260int lj_carith_op(lua_State *L, MMS mm)
261{
262 CTState *cts = ctype_cts(L);
263 CDArith ca;
264 if (carith_checkarg(L, cts, &ca)) {
265 if (carith_int64(L, cts, &ca, mm) || carith_ptr(L, cts, &ca, mm)) {
266 copyTV(L, &G(L)->tmptv2, L->top-1); /* Remember for trace recorder. */
267 return 1;
268 }
269 }
270 return lj_carith_meta(L, cts, &ca, mm);
271}
272
273/* -- 64 bit integer arithmetic helpers ----------------------------------- */
274
275#if LJ_32 && LJ_HASJIT
/* Signed/unsigned 64 bit multiplication.
** The product is formed in uint64_t, where wrap-around is well defined;
** multiplying int64_t values directly is undefined behavior on overflow.
** The low 64 bits of the product are the same for signed and unsigned
** interpretations of the operands.
*/
int64_t lj_carith_mul64(int64_t a, int64_t b)
{
  return (int64_t)((uint64_t)a * (uint64_t)b);
}
281#endif
282
/* Unsigned 64 bit division. Division by zero yields 0x8000000000000000. */
uint64_t lj_carith_divu64(uint64_t a, uint64_t b)
{
  return b != 0 ? a / b : (uint64_t)1 << 63;
}
289
/* Signed 64 bit division, truncating toward zero.
** Division by zero and the overflowing INT64_MIN / -1 both yield INT64_MIN.
*/
int64_t lj_carith_divi64(int64_t a, int64_t b)
{
  const int64_t min64 = (int64_t)((uint64_t)1 << 63);
  if (b != 0 && !(a == min64 && b == -1))
    return a / b;
  return min64;
}
297
/* Unsigned 64 bit modulo. Modulo by zero yields 0x8000000000000000. */
uint64_t lj_carith_modu64(uint64_t a, uint64_t b)
{
  if (b != 0)
    return a % b;
  return (uint64_t)1 << 63;
}
304
/* Signed 64 bit modulo with C semantics (result has the sign of a).
** Modulo by zero yields INT64_MIN.
*/
int64_t lj_carith_modi64(int64_t a, int64_t b)
{
  if (b == 0)
    return (int64_t)((uint64_t)1 << 63);
  /* a % -1 is 0 for every a; answering directly also sidesteps the
  ** overflowing INT64_MIN % -1.
  */
  if (b == -1)
    return 0;
  return a % b;
}
312
/* Unsigned 64 bit x^k, computed modulo 2^64 by square-and-multiply.
** x^0 is 1 for every x (including 0).
*/
uint64_t lj_carith_powu64(uint64_t x, uint64_t k)
{
  uint64_t acc = 1;
  while (k != 0) {
    if (k & 1)
      acc *= x;
    k >>= 1;
    if (k != 0)
      x *= x;  /* Next power of two of the base, only if still needed. */
  }
  return acc;
}
332
/* Signed 64 bit x^k.
** Non-negative exponents use the unsigned routine (wrapping modulo 2^64).
** For k < 0 the exact result 1/x^-k is truncated toward zero, so only
** |x| <= 1 gives a non-zero value; 0^k with k < 0 yields INT64_MAX.
*/
int64_t lj_carith_powi64(int64_t x, int64_t k)
{
  if (k > 0)
    return (int64_t)lj_carith_powu64((uint64_t)x, (uint64_t)k);
  if (k == 0)
    return 1;
  switch (x) {
  case 0:
    return (int64_t)(~(uint64_t)0 >> 1);
  case 1:
    return 1;
  case -1:
    return (k & 1) ? -1 : 1;
  default:
    return 0;
  }
}
350
351#endif
352