| 1 | // © 2022 and later: Unicode, Inc. and others. |
| 2 | // License & terms of use: http://www.unicode.org/copyright.html |
| 3 | |
| 4 | #include "unicode/utypes.h" |
| 5 | |
| 6 | #if !UCONFIG_NO_BREAK_ITERATION |
| 7 | |
| 8 | #include "cmemory.h" |
| 9 | #include "mlbe.h" |
| 10 | #include "uassert.h" |
| 11 | #include "ubrkimpl.h" |
| 12 | #include "unicode/resbund.h" |
| 13 | #include "unicode/udata.h" |
| 14 | #include "unicode/utf16.h" |
| 15 | #include "uresimp.h" |
| 16 | #include "util.h" |
| 17 | #include "uvectr32.h" |
| 18 | |
| 19 | U_NAMESPACE_BEGIN |
| 20 | |
// First slot of each feature family inside fModel. Slots are laid out as six
// single-code-point tables (UW1..UW6), then three two-code-point tables
// (BW1..BW3), then four three-code-point tables (TW1..TW4).
enum class ModelIndex {
    kUWStart = 0,
    kBWStart = kUWStart + 6,
    kTWStart = kBWStart + 3,
};
| 22 | |
| 23 | MlBreakEngine::MlBreakEngine(const UnicodeSet &digitOrOpenPunctuationOrAlphabetSet, |
| 24 | const UnicodeSet &closePunctuationSet, UErrorCode &status) |
| 25 | : fDigitOrOpenPunctuationOrAlphabetSet(digitOrOpenPunctuationOrAlphabetSet), |
| 26 | fClosePunctuationSet(closePunctuationSet), |
| 27 | fNegativeSum(0) { |
| 28 | if (U_FAILURE(status)) { |
| 29 | return; |
| 30 | } |
| 31 | loadMLModel(status); |
| 32 | } |
| 33 | |
| 34 | MlBreakEngine::~MlBreakEngine() {} |
| 35 | |
| 36 | int32_t MlBreakEngine::divideUpRange(UText *inText, int32_t rangeStart, int32_t rangeEnd, |
| 37 | UVector32 &foundBreaks, const UnicodeString &inString, |
| 38 | const LocalPointer<UVector32> &inputMap, |
| 39 | UErrorCode &status) const { |
| 40 | if (U_FAILURE(status)) { |
| 41 | return 0; |
| 42 | } |
| 43 | if (rangeStart >= rangeEnd) { |
| 44 | status = U_ILLEGAL_ARGUMENT_ERROR; |
| 45 | return 0; |
| 46 | } |
| 47 | |
| 48 | UVector32 boundary(inString.countChar32() + 1, status); |
| 49 | if (U_FAILURE(status)) { |
| 50 | return 0; |
| 51 | } |
| 52 | int32_t numBreaks = 0; |
| 53 | int32_t codePointLength = inString.countChar32(); |
| 54 | // The ML algorithm groups six char and evaluates whether the 4th char is a breakpoint. |
| 55 | // In each iteration, it evaluates the 4th char and then moves forward one char like a sliding |
| 56 | // window. Initially, the first six values in the indexList are [-1, -1, 0, 1, 2, 3]. After |
| 57 | // moving forward, finally the last six values in the indexList are |
| 58 | // [length-4, length-3, length-2, length-1, -1, -1]. The "+4" here means four extra "-1". |
| 59 | int32_t indexSize = codePointLength + 4; |
| 60 | int32_t *indexList = (int32_t *)uprv_malloc(indexSize * sizeof(int32_t)); |
| 61 | if (indexList == nullptr) { |
| 62 | status = U_MEMORY_ALLOCATION_ERROR; |
| 63 | return 0; |
| 64 | } |
| 65 | int32_t numCodeUnits = initIndexList(inString, indexList, status); |
| 66 | |
| 67 | // Add a break for the start. |
| 68 | boundary.addElement(0, status); |
| 69 | numBreaks++; |
| 70 | if (U_FAILURE(status)) return 0; |
| 71 | |
| 72 | for (int32_t idx = 0; idx + 1 < codePointLength && U_SUCCESS(status); idx++) { |
| 73 | numBreaks = |
| 74 | evaluateBreakpoint(inString, indexList, idx, numCodeUnits, numBreaks, boundary, status); |
| 75 | if (idx + 4 < codePointLength) { |
| 76 | indexList[idx + 6] = numCodeUnits; |
| 77 | numCodeUnits += U16_LENGTH(inString.char32At(indexList[idx + 6])); |
| 78 | } |
| 79 | } |
| 80 | uprv_free(indexList); |
| 81 | |
| 82 | if (U_FAILURE(status)) return 0; |
| 83 | |
| 84 | // Add a break for the end if there is not one there already. |
| 85 | if (boundary.lastElementi() != inString.countChar32()) { |
| 86 | boundary.addElement(inString.countChar32(), status); |
| 87 | numBreaks++; |
| 88 | } |
| 89 | |
| 90 | int32_t prevCPPos = -1; |
| 91 | int32_t prevUTextPos = -1; |
| 92 | int32_t correctedNumBreaks = 0; |
| 93 | for (int32_t i = 0; i < numBreaks; i++) { |
| 94 | int32_t cpPos = boundary.elementAti(i); |
| 95 | int32_t utextPos = inputMap.isValid() ? inputMap->elementAti(cpPos) : cpPos + rangeStart; |
| 96 | U_ASSERT(cpPos > prevCPPos); |
| 97 | U_ASSERT(utextPos >= prevUTextPos); |
| 98 | |
| 99 | if (utextPos > prevUTextPos) { |
| 100 | if (utextPos != rangeStart || |
| 101 | (utextPos > 0 && |
| 102 | fClosePunctuationSet.contains(utext_char32At(inText, utextPos - 1)))) { |
| 103 | foundBreaks.push(utextPos, status); |
| 104 | correctedNumBreaks++; |
| 105 | } |
| 106 | } else { |
| 107 | // Normalization expanded the input text, the dictionary found a boundary |
| 108 | // within the expansion, giving two boundaries with the same index in the |
| 109 | // original text. Ignore the second. See ticket #12918. |
| 110 | --numBreaks; |
| 111 | } |
| 112 | prevCPPos = cpPos; |
| 113 | prevUTextPos = utextPos; |
| 114 | } |
| 115 | (void)prevCPPos; // suppress compiler warnings about unused variable |
| 116 | |
| 117 | UChar32 nextChar = utext_char32At(inText, rangeEnd); |
| 118 | if (!foundBreaks.isEmpty() && foundBreaks.peeki() == rangeEnd) { |
| 119 | // In phrase breaking, there has to be a breakpoint between Cj character and |
| 120 | // the number/open punctuation. |
| 121 | // E.g. る文字「そうだ、京都」->る▁文字▁「そうだ、▁京都」-> breakpoint between 字 and「 |
| 122 | // E.g. 乗車率90%程度だろうか -> 乗車▁率▁90%▁程度だろうか -> breakpoint between 率 and 9 |
| 123 | // E.g. しかもロゴがUnicode! -> しかも▁ロゴが▁Unicode!-> breakpoint between が and U |
| 124 | if (!fDigitOrOpenPunctuationOrAlphabetSet.contains(nextChar)) { |
| 125 | foundBreaks.popi(); |
| 126 | correctedNumBreaks--; |
| 127 | } |
| 128 | } |
| 129 | |
| 130 | return correctedNumBreaks; |
| 131 | } |
| 132 | |
| 133 | int32_t MlBreakEngine::evaluateBreakpoint(const UnicodeString &inString, int32_t *indexList, |
| 134 | int32_t startIdx, int32_t numCodeUnits, int32_t numBreaks, |
| 135 | UVector32 &boundary, UErrorCode &status) const { |
| 136 | if (U_FAILURE(status)) { |
| 137 | return numBreaks; |
| 138 | } |
| 139 | int32_t start = 0, end = 0; |
| 140 | int32_t score = fNegativeSum; |
| 141 | |
| 142 | for (int i = 0; i < 6; i++) { |
| 143 | // UW1 ~ UW6 |
| 144 | start = startIdx + i; |
| 145 | if (indexList[start] != -1) { |
| 146 | end = (indexList[start + 1] != -1) ? indexList[start + 1] : numCodeUnits; |
| 147 | score += fModel[static_cast<int32_t>(ModelIndex::kUWStart) + i].geti( |
| 148 | inString.tempSubString(indexList[start], end - indexList[start])); |
| 149 | } |
| 150 | } |
| 151 | for (int i = 0; i < 3; i++) { |
| 152 | // BW1 ~ BW3 |
| 153 | start = startIdx + i + 1; |
| 154 | if (indexList[start] != -1 && indexList[start + 1] != -1) { |
| 155 | end = (indexList[start + 2] != -1) ? indexList[start + 2] : numCodeUnits; |
| 156 | score += fModel[static_cast<int32_t>(ModelIndex::kBWStart) + i].geti( |
| 157 | inString.tempSubString(indexList[start], end - indexList[start])); |
| 158 | } |
| 159 | } |
| 160 | for (int i = 0; i < 4; i++) { |
| 161 | // TW1 ~ TW4 |
| 162 | start = startIdx + i; |
| 163 | if (indexList[start] != -1 && indexList[start + 1] != -1 && indexList[start + 2] != -1) { |
| 164 | end = (indexList[start + 3] != -1) ? indexList[start + 3] : numCodeUnits; |
| 165 | score += fModel[static_cast<int32_t>(ModelIndex::kTWStart) + i].geti( |
| 166 | inString.tempSubString(indexList[start], end - indexList[start])); |
| 167 | } |
| 168 | } |
| 169 | |
| 170 | if (score > 0) { |
| 171 | boundary.addElement(startIdx + 1, status); |
| 172 | numBreaks++; |
| 173 | } |
| 174 | return numBreaks; |
| 175 | } |
| 176 | |
| 177 | int32_t MlBreakEngine::initIndexList(const UnicodeString &inString, int32_t *indexList, |
| 178 | UErrorCode &status) const { |
| 179 | if (U_FAILURE(status)) { |
| 180 | return 0; |
| 181 | } |
| 182 | int32_t index = 0; |
| 183 | int32_t length = inString.countChar32(); |
| 184 | // Set all (lenght+4) items inside indexLength to -1 presuming -1 is 4 bytes of 0xff. |
| 185 | uprv_memset(indexList, 0xff, (length + 4) * sizeof(int32_t)); |
| 186 | if (length > 0) { |
| 187 | indexList[2] = 0; |
| 188 | index = U16_LENGTH(inString.char32At(0)); |
| 189 | if (length > 1) { |
| 190 | indexList[3] = index; |
| 191 | index += U16_LENGTH(inString.char32At(index)); |
| 192 | if (length > 2) { |
| 193 | indexList[4] = index; |
| 194 | index += U16_LENGTH(inString.char32At(index)); |
| 195 | if (length > 3) { |
| 196 | indexList[5] = index; |
| 197 | index += U16_LENGTH(inString.char32At(index)); |
| 198 | } |
| 199 | } |
| 200 | } |
| 201 | } |
| 202 | return index; |
| 203 | } |
| 204 | |
| 205 | void MlBreakEngine::loadMLModel(UErrorCode &error) { |
| 206 | // BudouX's model consists of thirteen categories, each of which is make up of pairs of the |
| 207 | // feature and its score. As integrating it into jaml.txt, we define thirteen kinds of key and |
| 208 | // value to represent the feature and the corresponding score respectively. |
| 209 | |
| 210 | if (U_FAILURE(error)) return; |
| 211 | |
| 212 | UnicodeString key; |
| 213 | StackUResourceBundle stackTempBundle; |
| 214 | ResourceDataValue modelKey; |
| 215 | |
| 216 | LocalUResourceBundlePointer rbp(ures_openDirect(U_ICUDATA_BRKITR, "jaml" , &error)); |
| 217 | UResourceBundle *rb = rbp.getAlias(); |
| 218 | if (U_FAILURE(error)) return; |
| 219 | |
| 220 | int32_t index = 0; |
| 221 | initKeyValue(rb, "UW1Keys" , "UW1Values" , fModel[index++], error); |
| 222 | initKeyValue(rb, "UW2Keys" , "UW2Values" , fModel[index++], error); |
| 223 | initKeyValue(rb, "UW3Keys" , "UW3Values" , fModel[index++], error); |
| 224 | initKeyValue(rb, "UW4Keys" , "UW4Values" , fModel[index++], error); |
| 225 | initKeyValue(rb, "UW5Keys" , "UW5Values" , fModel[index++], error); |
| 226 | initKeyValue(rb, "UW6Keys" , "UW6Values" , fModel[index++], error); |
| 227 | initKeyValue(rb, "BW1Keys" , "BW1Values" , fModel[index++], error); |
| 228 | initKeyValue(rb, "BW2Keys" , "BW2Values" , fModel[index++], error); |
| 229 | initKeyValue(rb, "BW3Keys" , "BW3Values" , fModel[index++], error); |
| 230 | initKeyValue(rb, "TW1Keys" , "TW1Values" , fModel[index++], error); |
| 231 | initKeyValue(rb, "TW2Keys" , "TW2Values" , fModel[index++], error); |
| 232 | initKeyValue(rb, "TW3Keys" , "TW3Values" , fModel[index++], error); |
| 233 | initKeyValue(rb, "TW4Keys" , "TW4Values" , fModel[index++], error); |
| 234 | fNegativeSum /= 2; |
| 235 | } |
| 236 | |
| 237 | void MlBreakEngine::initKeyValue(UResourceBundle *rb, const char *keyName, const char *valueName, |
| 238 | Hashtable &model, UErrorCode &error) { |
| 239 | int32_t keySize = 0; |
| 240 | int32_t valueSize = 0; |
| 241 | int32_t stringLength = 0; |
| 242 | UnicodeString key; |
| 243 | StackUResourceBundle stackTempBundle; |
| 244 | ResourceDataValue modelKey; |
| 245 | |
| 246 | // get modelValues |
| 247 | LocalUResourceBundlePointer modelValue(ures_getByKey(rb, valueName, nullptr, &error)); |
| 248 | const int32_t *value = ures_getIntVector(modelValue.getAlias(), &valueSize, &error); |
| 249 | if (U_FAILURE(error)) return; |
| 250 | |
| 251 | // get modelKeys |
| 252 | ures_getValueWithFallback(rb, keyName, stackTempBundle.getAlias(), modelKey, error); |
| 253 | ResourceArray stringArray = modelKey.getArray(error); |
| 254 | keySize = stringArray.getSize(); |
| 255 | if (U_FAILURE(error)) return; |
| 256 | |
| 257 | for (int32_t idx = 0; idx < keySize; idx++) { |
| 258 | stringArray.getValue(idx, modelKey); |
| 259 | key = UnicodeString(modelKey.getString(stringLength, error)); |
| 260 | if (U_SUCCESS(error)) { |
| 261 | U_ASSERT(idx < valueSize); |
| 262 | fNegativeSum -= value[idx]; |
| 263 | model.puti(key, value[idx], error); |
| 264 | } |
| 265 | } |
| 266 | } |
| 267 | |
| 268 | U_NAMESPACE_END |
| 269 | |
| 270 | #endif /* #if !UCONFIG_NO_BREAK_ITERATION */ |
| 271 | |