1 | // © 2022 and later: Unicode, Inc. and others. |
2 | // License & terms of use: http://www.unicode.org/copyright.html |
3 | |
4 | #include "unicode/utypes.h" |
5 | |
6 | #if !UCONFIG_NO_BREAK_ITERATION |
7 | |
8 | #include "cmemory.h" |
9 | #include "mlbe.h" |
10 | #include "uassert.h" |
11 | #include "ubrkimpl.h" |
12 | #include "unicode/resbund.h" |
13 | #include "unicode/udata.h" |
14 | #include "unicode/utf16.h" |
15 | #include "uresimp.h" |
16 | #include "util.h" |
17 | #include "uvectr32.h" |
18 | |
19 | U_NAMESPACE_BEGIN |
20 | |
// Starting offsets of each feature family inside fModel[] (the order in which loadMLModel
// fills it): six unigram tables (UW1-UW6), then three bigram tables (BW1-BW3), then four
// trigram tables (TW1-TW4).
enum class ModelIndex : int {
    kUWStart = 0,
    kBWStart = kUWStart + 6,  // after UW1..UW6
    kTWStart = kBWStart + 3   // after BW1..BW3
};
22 | |
23 | MlBreakEngine::MlBreakEngine(const UnicodeSet &digitOrOpenPunctuationOrAlphabetSet, |
24 | const UnicodeSet &closePunctuationSet, UErrorCode &status) |
25 | : fDigitOrOpenPunctuationOrAlphabetSet(digitOrOpenPunctuationOrAlphabetSet), |
26 | fClosePunctuationSet(closePunctuationSet), |
27 | fNegativeSum(0) { |
28 | if (U_FAILURE(status)) { |
29 | return; |
30 | } |
31 | loadMLModel(status); |
32 | } |
33 | |
34 | MlBreakEngine::~MlBreakEngine() {} |
35 | |
36 | int32_t MlBreakEngine::divideUpRange(UText *inText, int32_t rangeStart, int32_t rangeEnd, |
37 | UVector32 &foundBreaks, const UnicodeString &inString, |
38 | const LocalPointer<UVector32> &inputMap, |
39 | UErrorCode &status) const { |
40 | if (U_FAILURE(status)) { |
41 | return 0; |
42 | } |
43 | if (rangeStart >= rangeEnd) { |
44 | status = U_ILLEGAL_ARGUMENT_ERROR; |
45 | return 0; |
46 | } |
47 | |
48 | UVector32 boundary(inString.countChar32() + 1, status); |
49 | if (U_FAILURE(status)) { |
50 | return 0; |
51 | } |
52 | int32_t numBreaks = 0; |
53 | int32_t codePointLength = inString.countChar32(); |
54 | // The ML algorithm groups six char and evaluates whether the 4th char is a breakpoint. |
55 | // In each iteration, it evaluates the 4th char and then moves forward one char like a sliding |
56 | // window. Initially, the first six values in the indexList are [-1, -1, 0, 1, 2, 3]. After |
57 | // moving forward, finally the last six values in the indexList are |
58 | // [length-4, length-3, length-2, length-1, -1, -1]. The "+4" here means four extra "-1". |
59 | int32_t indexSize = codePointLength + 4; |
60 | int32_t *indexList = (int32_t *)uprv_malloc(indexSize * sizeof(int32_t)); |
61 | if (indexList == nullptr) { |
62 | status = U_MEMORY_ALLOCATION_ERROR; |
63 | return 0; |
64 | } |
65 | int32_t numCodeUnits = initIndexList(inString, indexList, status); |
66 | |
67 | // Add a break for the start. |
68 | boundary.addElement(0, status); |
69 | numBreaks++; |
70 | if (U_FAILURE(status)) return 0; |
71 | |
72 | for (int32_t idx = 0; idx + 1 < codePointLength && U_SUCCESS(status); idx++) { |
73 | numBreaks = |
74 | evaluateBreakpoint(inString, indexList, idx, numCodeUnits, numBreaks, boundary, status); |
75 | if (idx + 4 < codePointLength) { |
76 | indexList[idx + 6] = numCodeUnits; |
77 | numCodeUnits += U16_LENGTH(inString.char32At(indexList[idx + 6])); |
78 | } |
79 | } |
80 | uprv_free(indexList); |
81 | |
82 | if (U_FAILURE(status)) return 0; |
83 | |
84 | // Add a break for the end if there is not one there already. |
85 | if (boundary.lastElementi() != inString.countChar32()) { |
86 | boundary.addElement(inString.countChar32(), status); |
87 | numBreaks++; |
88 | } |
89 | |
90 | int32_t prevCPPos = -1; |
91 | int32_t prevUTextPos = -1; |
92 | int32_t correctedNumBreaks = 0; |
93 | for (int32_t i = 0; i < numBreaks; i++) { |
94 | int32_t cpPos = boundary.elementAti(i); |
95 | int32_t utextPos = inputMap.isValid() ? inputMap->elementAti(cpPos) : cpPos + rangeStart; |
96 | U_ASSERT(cpPos > prevCPPos); |
97 | U_ASSERT(utextPos >= prevUTextPos); |
98 | |
99 | if (utextPos > prevUTextPos) { |
100 | if (utextPos != rangeStart || |
101 | (utextPos > 0 && |
102 | fClosePunctuationSet.contains(utext_char32At(inText, utextPos - 1)))) { |
103 | foundBreaks.push(utextPos, status); |
104 | correctedNumBreaks++; |
105 | } |
106 | } else { |
107 | // Normalization expanded the input text, the dictionary found a boundary |
108 | // within the expansion, giving two boundaries with the same index in the |
109 | // original text. Ignore the second. See ticket #12918. |
110 | --numBreaks; |
111 | } |
112 | prevCPPos = cpPos; |
113 | prevUTextPos = utextPos; |
114 | } |
115 | (void)prevCPPos; // suppress compiler warnings about unused variable |
116 | |
117 | UChar32 nextChar = utext_char32At(inText, rangeEnd); |
118 | if (!foundBreaks.isEmpty() && foundBreaks.peeki() == rangeEnd) { |
119 | // In phrase breaking, there has to be a breakpoint between Cj character and |
120 | // the number/open punctuation. |
121 | // E.g. る文字「そうだ、京都」->る▁文字▁「そうだ、▁京都」-> breakpoint between 字 and「 |
122 | // E.g. 乗車率90%程度だろうか -> 乗車▁率▁90%▁程度だろうか -> breakpoint between 率 and 9 |
123 | // E.g. しかもロゴがUnicode! -> しかも▁ロゴが▁Unicode!-> breakpoint between が and U |
124 | if (!fDigitOrOpenPunctuationOrAlphabetSet.contains(nextChar)) { |
125 | foundBreaks.popi(); |
126 | correctedNumBreaks--; |
127 | } |
128 | } |
129 | |
130 | return correctedNumBreaks; |
131 | } |
132 | |
133 | int32_t MlBreakEngine::evaluateBreakpoint(const UnicodeString &inString, int32_t *indexList, |
134 | int32_t startIdx, int32_t numCodeUnits, int32_t numBreaks, |
135 | UVector32 &boundary, UErrorCode &status) const { |
136 | if (U_FAILURE(status)) { |
137 | return numBreaks; |
138 | } |
139 | int32_t start = 0, end = 0; |
140 | int32_t score = fNegativeSum; |
141 | |
142 | for (int i = 0; i < 6; i++) { |
143 | // UW1 ~ UW6 |
144 | start = startIdx + i; |
145 | if (indexList[start] != -1) { |
146 | end = (indexList[start + 1] != -1) ? indexList[start + 1] : numCodeUnits; |
147 | score += fModel[static_cast<int32_t>(ModelIndex::kUWStart) + i].geti( |
148 | inString.tempSubString(indexList[start], end - indexList[start])); |
149 | } |
150 | } |
151 | for (int i = 0; i < 3; i++) { |
152 | // BW1 ~ BW3 |
153 | start = startIdx + i + 1; |
154 | if (indexList[start] != -1 && indexList[start + 1] != -1) { |
155 | end = (indexList[start + 2] != -1) ? indexList[start + 2] : numCodeUnits; |
156 | score += fModel[static_cast<int32_t>(ModelIndex::kBWStart) + i].geti( |
157 | inString.tempSubString(indexList[start], end - indexList[start])); |
158 | } |
159 | } |
160 | for (int i = 0; i < 4; i++) { |
161 | // TW1 ~ TW4 |
162 | start = startIdx + i; |
163 | if (indexList[start] != -1 && indexList[start + 1] != -1 && indexList[start + 2] != -1) { |
164 | end = (indexList[start + 3] != -1) ? indexList[start + 3] : numCodeUnits; |
165 | score += fModel[static_cast<int32_t>(ModelIndex::kTWStart) + i].geti( |
166 | inString.tempSubString(indexList[start], end - indexList[start])); |
167 | } |
168 | } |
169 | |
170 | if (score > 0) { |
171 | boundary.addElement(startIdx + 1, status); |
172 | numBreaks++; |
173 | } |
174 | return numBreaks; |
175 | } |
176 | |
177 | int32_t MlBreakEngine::initIndexList(const UnicodeString &inString, int32_t *indexList, |
178 | UErrorCode &status) const { |
179 | if (U_FAILURE(status)) { |
180 | return 0; |
181 | } |
182 | int32_t index = 0; |
183 | int32_t length = inString.countChar32(); |
184 | // Set all (lenght+4) items inside indexLength to -1 presuming -1 is 4 bytes of 0xff. |
185 | uprv_memset(indexList, 0xff, (length + 4) * sizeof(int32_t)); |
186 | if (length > 0) { |
187 | indexList[2] = 0; |
188 | index = U16_LENGTH(inString.char32At(0)); |
189 | if (length > 1) { |
190 | indexList[3] = index; |
191 | index += U16_LENGTH(inString.char32At(index)); |
192 | if (length > 2) { |
193 | indexList[4] = index; |
194 | index += U16_LENGTH(inString.char32At(index)); |
195 | if (length > 3) { |
196 | indexList[5] = index; |
197 | index += U16_LENGTH(inString.char32At(index)); |
198 | } |
199 | } |
200 | } |
201 | } |
202 | return index; |
203 | } |
204 | |
205 | void MlBreakEngine::loadMLModel(UErrorCode &error) { |
206 | // BudouX's model consists of thirteen categories, each of which is make up of pairs of the |
207 | // feature and its score. As integrating it into jaml.txt, we define thirteen kinds of key and |
208 | // value to represent the feature and the corresponding score respectively. |
209 | |
210 | if (U_FAILURE(error)) return; |
211 | |
212 | UnicodeString key; |
213 | StackUResourceBundle stackTempBundle; |
214 | ResourceDataValue modelKey; |
215 | |
216 | LocalUResourceBundlePointer rbp(ures_openDirect(U_ICUDATA_BRKITR, "jaml" , &error)); |
217 | UResourceBundle *rb = rbp.getAlias(); |
218 | if (U_FAILURE(error)) return; |
219 | |
220 | int32_t index = 0; |
221 | initKeyValue(rb, "UW1Keys" , "UW1Values" , fModel[index++], error); |
222 | initKeyValue(rb, "UW2Keys" , "UW2Values" , fModel[index++], error); |
223 | initKeyValue(rb, "UW3Keys" , "UW3Values" , fModel[index++], error); |
224 | initKeyValue(rb, "UW4Keys" , "UW4Values" , fModel[index++], error); |
225 | initKeyValue(rb, "UW5Keys" , "UW5Values" , fModel[index++], error); |
226 | initKeyValue(rb, "UW6Keys" , "UW6Values" , fModel[index++], error); |
227 | initKeyValue(rb, "BW1Keys" , "BW1Values" , fModel[index++], error); |
228 | initKeyValue(rb, "BW2Keys" , "BW2Values" , fModel[index++], error); |
229 | initKeyValue(rb, "BW3Keys" , "BW3Values" , fModel[index++], error); |
230 | initKeyValue(rb, "TW1Keys" , "TW1Values" , fModel[index++], error); |
231 | initKeyValue(rb, "TW2Keys" , "TW2Values" , fModel[index++], error); |
232 | initKeyValue(rb, "TW3Keys" , "TW3Values" , fModel[index++], error); |
233 | initKeyValue(rb, "TW4Keys" , "TW4Values" , fModel[index++], error); |
234 | fNegativeSum /= 2; |
235 | } |
236 | |
237 | void MlBreakEngine::initKeyValue(UResourceBundle *rb, const char *keyName, const char *valueName, |
238 | Hashtable &model, UErrorCode &error) { |
239 | int32_t keySize = 0; |
240 | int32_t valueSize = 0; |
241 | int32_t stringLength = 0; |
242 | UnicodeString key; |
243 | StackUResourceBundle stackTempBundle; |
244 | ResourceDataValue modelKey; |
245 | |
246 | // get modelValues |
247 | LocalUResourceBundlePointer modelValue(ures_getByKey(rb, valueName, nullptr, &error)); |
248 | const int32_t *value = ures_getIntVector(modelValue.getAlias(), &valueSize, &error); |
249 | if (U_FAILURE(error)) return; |
250 | |
251 | // get modelKeys |
252 | ures_getValueWithFallback(rb, keyName, stackTempBundle.getAlias(), modelKey, error); |
253 | ResourceArray stringArray = modelKey.getArray(error); |
254 | keySize = stringArray.getSize(); |
255 | if (U_FAILURE(error)) return; |
256 | |
257 | for (int32_t idx = 0; idx < keySize; idx++) { |
258 | stringArray.getValue(idx, modelKey); |
259 | key = UnicodeString(modelKey.getString(stringLength, error)); |
260 | if (U_SUCCESS(error)) { |
261 | U_ASSERT(idx < valueSize); |
262 | fNegativeSum -= value[idx]; |
263 | model.puti(key, value[idx], error); |
264 | } |
265 | } |
266 | } |
267 | |
268 | U_NAMESPACE_END |
269 | |
270 | #endif /* #if !UCONFIG_NO_BREAK_ITERATION */ |
271 | |