Source added

This commit is contained in:
Fr4nz D13trich 2025-11-20 09:26:33 +01:00
parent b2864b500e
commit ba28ca859e
8352 changed files with 1487182 additions and 1 deletions

1
core-util/.gitignore vendored Normal file
View file

@ -0,0 +1 @@
/build

View file

@ -0,0 +1,30 @@
plugins {
id("signal-library")
id("com.squareup.wire")
}
android {
namespace = "org.signal.core.util"
}
dependencies {
api(project(":core-util-jvm"))
implementation(libs.androidx.sqlite)
implementation(libs.androidx.documentfile)
testImplementation(libs.androidx.sqlite.framework)
testImplementation(testLibs.junit.junit)
testImplementation(testLibs.assertk)
testImplementation(testLibs.robolectric.robolectric)
}
wire {
kotlin {
javaInterop = true
}
sourcePath {
srcDir("src/main/protowire")
}
}

View file

21
core-util/proguard-rules.pro vendored Normal file
View file

@ -0,0 +1,21 @@
# Add project specific ProGuard rules here.
# You can control the set of applied configuration files using the
# proguardFiles setting in build.gradle.
#
# For more details, see
# http://developer.android.com/guide/developing/tools/proguard.html
# If your project uses WebView with JS, uncomment the following
# and specify the fully qualified class name to the JavaScript interface
# class:
#-keepclassmembers class fqcn.of.javascript.interface.for.webview {
# public *;
#}
# Uncomment this to preserve the line number information for
# debugging stack traces.
#-keepattributes SourceFile,LineNumberTable
# If you keep the line number information, uncomment this to
# hide the original source file name.
#-renamesourcefileattribute SourceFile

View file

@ -0,0 +1,5 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest
xmlns:android="http://schemas.android.com/apk/res/android">
</manifest>

View file

@ -0,0 +1,10 @@
package androidx.documentfile.provider
/**
* Located in androidx package as [TreeDocumentFile] is package protected.
*
* @return true if can be used like a tree document file (e.g., use content resolver queries)
*/
fun DocumentFile.isTreeDocumentFile(): Boolean {
return this is TreeDocumentFile
}

View file

@ -0,0 +1,22 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util
import android.app.Activity
import android.os.Build
import androidx.annotation.AnimRes
val Activity.OVERRIDE_TRANSITION_OPEN_COMPAT: Int get() = 0
val Activity.OVERRIDE_TRANSITION_CLOSE_COMPAT: Int get() = 1
fun Activity.overrideActivityTransitionCompat(overrideType: Int, @AnimRes enterAnim: Int, @AnimRes exitAnim: Int) {
if (Build.VERSION.SDK_INT >= 34) {
overrideActivityTransition(overrideType, enterAnim, exitAnim)
} else {
@Suppress("DEPRECATION")
overridePendingTransition(enterAnim, exitAnim)
}
}

View file

@ -0,0 +1,26 @@
package org.signal.core.util;
import android.app.ActivityManager;
import android.content.Context;
import android.content.Intent;
import androidx.annotation.NonNull;
import androidx.core.content.ContextCompat;
public final class AppUtil {
private AppUtil() {}
/**
* Restarts the application. Should generally only be used for internal tools.
*/
public static void restart(@NonNull Context context) {
String packageName = context.getPackageName();
Intent defaultIntent = context.getPackageManager().getLaunchIntentForPackage(packageName);
defaultIntent.addFlags(Intent.FLAG_ACTIVITY_NEW_TASK);
context.startActivity(defaultIntent);
Runtime.getRuntime().exit(0);
}
}

View file

@ -0,0 +1,88 @@
package org.signal.core.util
import android.database.Cursor
import kotlin.math.max
class AsciiArt {
private class Table(
private val columns: List<String>,
private val rows: List<List<String>>
) {
override fun toString(): String {
val columnWidths = columns.map { column -> column.length }.toIntArray()
rows.forEach { row: List<String> ->
columnWidths.forEachIndexed { index, currentMax ->
columnWidths[index] = max(row[index].length, currentMax)
}
}
val builder = StringBuilder()
columns.forEachIndexed { index, column ->
builder.append(COLUMN_DIVIDER).append(" ").append(rightPad(column, columnWidths[index])).append(" ")
}
builder.append(COLUMN_DIVIDER)
builder.append("\n")
columnWidths.forEach { width ->
builder.append(COLUMN_DIVIDER)
builder.append(ROW_DIVIDER.repeat(width + 2))
}
builder.append(COLUMN_DIVIDER)
builder.append("\n")
rows.forEach { row ->
row.forEachIndexed { index, column ->
builder.append(COLUMN_DIVIDER).append(" ").append(rightPad(column, columnWidths[index])).append(" ")
}
builder.append(COLUMN_DIVIDER)
builder.append("\n")
}
return builder.toString()
}
}
companion object {
private const val COLUMN_DIVIDER = "|"
private const val ROW_DIVIDER = "-"
/**
* Will return a string representing a table of the provided cursor. The caller is responsible for the lifecycle of the cursor.
*/
@JvmStatic
fun tableFor(cursor: Cursor): String {
val columns: MutableList<String> = mutableListOf()
val rows: MutableList<List<String>> = mutableListOf()
columns.addAll(cursor.columnNames)
while (cursor.moveToNext()) {
val row: MutableList<String> = mutableListOf()
for (i in 0 until columns.size) {
row += cursor.getString(i)
}
rows += row
}
return Table(columns, rows).toString()
}
private fun rightPad(value: String, length: Int): String {
if (value.length >= length) {
return value
}
val out = java.lang.StringBuilder(value)
while (out.length < length) {
out.append(" ")
}
return out.toString()
}
}
}

View file

@ -0,0 +1,146 @@
package org.signal.core.util;
import android.os.Build;
import androidx.annotation.NonNull;
import androidx.annotation.RequiresApi;
import java.util.Iterator;
public abstract class BreakIteratorCompat implements Iterable<CharSequence> {
public static final int DONE = -1;
private CharSequence charSequence;
public abstract int first();
public abstract int next();
public void setText(CharSequence charSequence) {
this.charSequence = charSequence;
}
public static BreakIteratorCompat getInstance() {
if (Build.VERSION.SDK_INT >= 24) {
return new AndroidIcuBreakIterator();
} else {
return new FallbackBreakIterator();
}
}
public int countBreaks() {
int breakCount = 0;
first();
while (next() != DONE) {
breakCount++;
}
return breakCount;
}
@Override
public @NonNull Iterator<CharSequence> iterator() {
return new Iterator<CharSequence>() {
int index1 = BreakIteratorCompat.this.first();
int index2 = BreakIteratorCompat.this.next();
@Override
public boolean hasNext() {
return index2 != DONE;
}
@Override
public CharSequence next() {
CharSequence c = index2 != DONE ? charSequence.subSequence(index1, index2) : "";
index1 = index2;
index2 = BreakIteratorCompat.this.next();
return c;
}
};
}
/**
* Take {@param atMost} graphemes from the start of string.
*/
public final CharSequence take(int atMost) {
if (atMost <= 0) return "";
StringBuilder stringBuilder = new StringBuilder(charSequence.length());
int count = 0;
for (CharSequence grapheme : this) {
stringBuilder.append(grapheme);
count++;
if (count >= atMost) break;
}
return stringBuilder.toString();
}
/**
* An BreakIteratorCompat implementation that delegates calls to `android.icu.text.BreakIterator`.
* This class handles grapheme clusters fine but requires Android API >= 24.
*/
@RequiresApi(24)
private static class AndroidIcuBreakIterator extends BreakIteratorCompat {
private final android.icu.text.BreakIterator breakIterator = android.icu.text.BreakIterator.getCharacterInstance();
@Override
public int first() {
return breakIterator.first();
}
@Override
public int next() {
return breakIterator.next();
}
@Override
public void setText(CharSequence charSequence) {
super.setText(charSequence);
if (Build.VERSION.SDK_INT >= 29) {
breakIterator.setText(charSequence);
} else {
breakIterator.setText(charSequence.toString());
}
}
}
/**
* An BreakIteratorCompat implementation that delegates calls to `java.text.BreakIterator`.
* This class may or may not handle grapheme clusters well depending on the underlying implementation.
* In the emulator, API 23 implements ICU version of the BreakIterator so that it handles grapheme
* clusters fine. But API 21 implements RuleBasedIterator which does not handle grapheme clusters.
* <p>
* If it doesn't handle grapheme clusters correctly, in most cases the combined characters are
* broken up into pieces when the code tries to trim a string. For example, an emoji that is
* a combination of a person, gender and skin tone, trimming the character using this class may result
* in trimming the parts of the character, e.g. a dark skin frowning woman emoji may result in
* a neutral skin frowning woman emoji.
*/
private static class FallbackBreakIterator extends BreakIteratorCompat {
private final java.text.BreakIterator breakIterator = java.text.BreakIterator.getCharacterInstance();
@Override
public int first() {
return breakIterator.first();
}
@Override
public int next() {
return breakIterator.next();
}
@Override
public void setText(CharSequence charSequence) {
super.setText(charSequence);
breakIterator.setText(charSequence.toString());
}
}
}

View file

@ -0,0 +1,44 @@
@file:JvmName("BundleExtensions")
package org.signal.core.util
import android.os.Build
import android.os.Bundle
import android.os.Parcelable
import java.io.Serializable
fun <T : Serializable> Bundle.getSerializableCompat(key: String, clazz: Class<T>): T? {
return if (Build.VERSION.SDK_INT >= 33) {
this.getSerializable(key, clazz)
} else {
@Suppress("DEPRECATION", "UNCHECKED_CAST")
this.getSerializable(key) as T?
}
}
fun <T : Parcelable> Bundle.getParcelableCompat(key: String, clazz: Class<T>): T? {
return if (Build.VERSION.SDK_INT >= 33) {
this.getParcelable(key, clazz)
} else {
@Suppress("DEPRECATION")
this.getParcelable(key)
}
}
fun <T : Parcelable> Bundle.requireParcelableCompat(key: String, clazz: Class<T>): T {
return if (Build.VERSION.SDK_INT >= 33) {
this.getParcelable(key, clazz)!!
} else {
@Suppress("DEPRECATION")
this.getParcelable(key)!!
}
}
fun <T : Parcelable> Bundle.getParcelableArrayListCompat(key: String, clazz: Class<T>): ArrayList<T>? {
return if (Build.VERSION.SDK_INT >= 33) {
this.getParcelableArrayList(key, clazz)
} else {
@Suppress("DEPRECATION")
this.getParcelableArrayList(key)
}
}

View file

@ -0,0 +1,76 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util
import android.text.InputFilter
import android.text.Spanned
/**
* An [InputFilter] that prevents the target text from growing beyond [byteLimit] bytes when using UTF-8 encoding.
*/
class ByteLimitInputFilter(private val byteLimit: Int) : InputFilter {
override fun filter(source: CharSequence?, start: Int, end: Int, dest: Spanned?, dstart: Int, dend: Int): CharSequence? {
if (source == null || dest == null) {
return null
}
val insertText = source.subSequence(start, end)
val beforeText = dest.subSequence(0, dstart)
val afterText = dest.subSequence(dend, dest.length)
val insertByteLength = insertText.utf8Size()
val beforeByteLength = beforeText.utf8Size()
val afterByteLength = afterText.utf8Size()
val resultByteSize = beforeByteLength + insertByteLength + afterByteLength
if (resultByteSize <= byteLimit) {
return null
}
val availableBytes = byteLimit - beforeByteLength - afterByteLength
if (availableBytes <= 0) {
return ""
}
return truncateToByteLimit(insertText, availableBytes)
}
private fun truncateToByteLimit(text: CharSequence, maxBytes: Int): CharSequence {
var byteCount = 0
var charIndex = 0
while (charIndex < text.length) {
val char = text[charIndex]
val charBytes = when {
char.code < 0x80 -> 1
char.code < 0x800 -> 2
char.isHighSurrogate() -> {
if (charIndex + 1 < text.length && text[charIndex + 1].isLowSurrogate()) {
4
} else {
3
}
}
char.isLowSurrogate() -> 3 // Treat orphaned low surrogate as 3 bytes
else -> 3
}
if (byteCount + charBytes > maxBytes) {
break
}
byteCount += charBytes
charIndex++
if (char.isHighSurrogate() && charIndex < text.length && text[charIndex].isLowSurrogate()) {
charIndex++
}
}
return text.subSequence(0, charIndex)
}
}

View file

@ -0,0 +1,44 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util
/**
* A copy of [okio.utf8Size] that works on [CharSequence].
*/
fun CharSequence.utf8Size(): Int {
var result = 0
var i = 0
while (i < this.length) {
val c = this[i].code
if (c < 0x80) {
// A 7-bit character with 1 byte.
result++
i++
} else if (c < 0x800) {
// An 11-bit character with 2 bytes.
result += 2
i++
} else if (c < 0xd800 || c > 0xdfff) {
// A 16-bit character with 3 bytes.
result += 3
i++
} else {
val low = if (i + 1 < this.length) this[i + 1].code else 0
if (c > 0xdbff || low < 0xdc00 || low > 0xdfff) {
// A malformed surrogate, which yields '?'.
result++
i++
} else {
// A 21-bit character with 4 bytes.
result += 4
i += 2
}
}
}
return result
}

View file

@ -0,0 +1,124 @@
package org.signal.core.util;
import android.os.Build;
import androidx.annotation.NonNull;
import androidx.annotation.RequiresApi;
import java.util.Iterator;
/**
* Iterates over a string treating a surrogate pair and a grapheme cluster a single character.
*/
public final class CharacterIterable implements Iterable<String> {
private final String string;
public CharacterIterable(@NonNull String string) {
this.string = string;
}
@Override
public @NonNull Iterator<String> iterator() {
return new CharacterIterator();
}
private class CharacterIterator implements Iterator<String> {
private static final int UNINITIALIZED = -2;
private final BreakIteratorCompat breakIterator;
private int lastIndex = UNINITIALIZED;
CharacterIterator() {
this.breakIterator = Build.VERSION.SDK_INT >= 24 ? new AndroidIcuBreakIterator(string)
: new FallbackBreakIterator(string);
}
@Override
public boolean hasNext() {
if (lastIndex == UNINITIALIZED) {
lastIndex = breakIterator.first();
}
return !breakIterator.isDone(lastIndex);
}
@Override
public String next() {
int firstIndex = lastIndex;
lastIndex = breakIterator.next();
return string.substring(firstIndex, lastIndex);
}
}
private interface BreakIteratorCompat {
int first();
int next();
boolean isDone(int index);
}
/**
* An BreakIteratorCompat implementation that delegates calls to `android.icu.text.BreakIterator`.
* This class handles grapheme clusters fine but requires Android API >= 24.
*/
@RequiresApi(24)
private static class AndroidIcuBreakIterator implements BreakIteratorCompat {
private final android.icu.text.BreakIterator breakIterator = android.icu.text.BreakIterator.getCharacterInstance();
public AndroidIcuBreakIterator(@NonNull String string) {
breakIterator.setText(string);
}
@Override
public int first() {
return breakIterator.first();
}
@Override
public int next() {
return breakIterator.next();
}
@Override
public boolean isDone(int index) {
return index == android.icu.text.BreakIterator.DONE;
}
}
/**
* An BreakIteratorCompat implementation that delegates calls to `java.text.BreakIterator`.
* This class may or may not handle grapheme clusters well depending on the underlying implementation.
* In the emulator, API 23 implements ICU version of the BreakIterator so that it handles grapheme
* clusters fine. But API 21 implements RuleBasedIterator which does not handle grapheme clusters.
* <p>
* If it doesn't handle grapheme clusters correctly, in most cases the combined characters are
* broken up into pieces when the code tries to trim a string. For example, an emoji that is
* a combination of a person, gender and skin tone, trimming the character using this class may result
* in trimming the parts of the character, e.g. a dark skin frowning woman emoji may result in
* a neutral skin frowning woman emoji.
*/
private static class FallbackBreakIterator implements BreakIteratorCompat {
private final java.text.BreakIterator breakIterator = java.text.BreakIterator.getCharacterInstance();
public FallbackBreakIterator(@NonNull String string) {
breakIterator.setText(string);
}
@Override
public int first() {
return breakIterator.first();
}
@Override
public int next() {
return breakIterator.next();
}
@Override
public boolean isDone(int index) {
return index == java.text.BreakIterator.DONE;
}
}
}

View file

@ -0,0 +1,31 @@
package org.signal.core.util
import java.util.Collections
/**
* Flattens a List of Map<K, V> into a Map<K, V> using the + operator.
*
* @return A Map containing all of the K, V pairings of the maps contained in the original list.
*/
fun <K, V> List<Map<K, V>>.flatten(): Map<K, V> = foldRight(emptyMap()) { a, b -> a + b }
/**
* Swaps the elements at the specified positions and returns the result in a new immutable list.
*
* @param i the index of one element to be swapped.
* @param j the index of the other element to be swapped.
*
* @throws IndexOutOfBoundsException if either i or j is out of range.
*/
fun <E> List<E>.swap(i: Int, j: Int): List<E> {
val mutableCopy = this.toMutableList()
Collections.swap(mutableCopy, i, j)
return mutableCopy.toList()
}
/**
* Returns the item wrapped in a list, or an empty list of the item is null.
*/
fun <E> E?.asList(): List<E> {
return if (this == null) emptyList() else listOf(this)
}

View file

@ -0,0 +1,22 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util
import android.content.ContentResolver
import android.net.Uri
import android.provider.OpenableColumns
import okio.IOException
@Throws(IOException::class)
fun ContentResolver.getLength(uri: Uri): Long? {
return this.query(uri, arrayOf(OpenableColumns.SIZE), null, null, null)?.use { cursor ->
if (cursor.moveToFirst()) {
cursor.requireLongOrNull(OpenableColumns.SIZE)
} else {
null
}
} ?: openInputStream(uri)?.use { it.readLength() }
}

View file

@ -0,0 +1,13 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util
import android.app.DownloadManager
import android.content.Context
fun Context.getDownloadManager(): DownloadManager {
return this.getSystemService(Context.DOWNLOAD_SERVICE) as DownloadManager
}

View file

@ -0,0 +1,187 @@
/**
* Copyright (C) 2014 Open Whisper Systems
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.signal.core.util;
public class Conversions {
public static byte intsToByteHighAndLow(int highValue, int lowValue) {
return (byte)((highValue << 4 | lowValue) & 0xFF);
}
public static int highBitsToInt(byte value) {
return (value & 0xFF) >> 4;
}
public static int lowBitsToInt(byte value) {
return (value & 0xF);
}
public static int highBitsToMedium(int value) {
return (value >> 12);
}
public static int lowBitsToMedium(int value) {
return (value & 0xFFF);
}
public static byte[] shortToByteArray(int value) {
byte[] bytes = new byte[2];
shortToByteArray(bytes, 0, value);
return bytes;
}
public static int shortToByteArray(byte[] bytes, int offset, int value) {
bytes[offset+1] = (byte)value;
bytes[offset] = (byte)(value >> 8);
return 2;
}
public static int shortToLittleEndianByteArray(byte[] bytes, int offset, int value) {
bytes[offset] = (byte)value;
bytes[offset+1] = (byte)(value >> 8);
return 2;
}
public static byte[] mediumToByteArray(int value) {
byte[] bytes = new byte[3];
mediumToByteArray(bytes, 0, value);
return bytes;
}
public static int mediumToByteArray(byte[] bytes, int offset, int value) {
bytes[offset + 2] = (byte)value;
bytes[offset + 1] = (byte)(value >> 8);
bytes[offset] = (byte)(value >> 16);
return 3;
}
public static byte[] intToByteArray(int value) {
byte[] bytes = new byte[4];
intToByteArray(bytes, 0, value);
return bytes;
}
public static int intToByteArray(byte[] bytes, int offset, int value) {
bytes[offset + 3] = (byte)value;
bytes[offset + 2] = (byte)(value >> 8);
bytes[offset + 1] = (byte)(value >> 16);
bytes[offset] = (byte)(value >> 24);
return 4;
}
public static int intToLittleEndianByteArray(byte[] bytes, int offset, int value) {
bytes[offset] = (byte)value;
bytes[offset+1] = (byte)(value >> 8);
bytes[offset+2] = (byte)(value >> 16);
bytes[offset+3] = (byte)(value >> 24);
return 4;
}
public static byte[] longToByteArray(long l) {
byte[] bytes = new byte[8];
longToByteArray(bytes, 0, l);
return bytes;
}
public static int longToByteArray(byte[] bytes, int offset, long value) {
bytes[offset + 7] = (byte)value;
bytes[offset + 6] = (byte)(value >> 8);
bytes[offset + 5] = (byte)(value >> 16);
bytes[offset + 4] = (byte)(value >> 24);
bytes[offset + 3] = (byte)(value >> 32);
bytes[offset + 2] = (byte)(value >> 40);
bytes[offset + 1] = (byte)(value >> 48);
bytes[offset] = (byte)(value >> 56);
return 8;
}
public static int longTo4ByteArray(byte[] bytes, int offset, long value) {
bytes[offset + 3] = (byte)value;
bytes[offset + 2] = (byte)(value >> 8);
bytes[offset + 1] = (byte)(value >> 16);
bytes[offset + 0] = (byte)(value >> 24);
return 4;
}
public static int byteArrayToShort(byte[] bytes) {
return byteArrayToShort(bytes, 0);
}
public static int byteArrayToShort(byte[] bytes, int offset) {
return
(bytes[offset] & 0xff) << 8 | (bytes[offset + 1] & 0xff);
}
// The SSL patented 3-byte Value.
public static int byteArrayToMedium(byte[] bytes, int offset) {
return
(bytes[offset] & 0xff) << 16 |
(bytes[offset + 1] & 0xff) << 8 |
(bytes[offset + 2] & 0xff);
}
public static int byteArrayToInt(byte[] bytes) {
return byteArrayToInt(bytes, 0);
}
public static int byteArrayToInt(byte[] bytes, int offset) {
return
(bytes[offset] & 0xff) << 24 |
(bytes[offset + 1] & 0xff) << 16 |
(bytes[offset + 2] & 0xff) << 8 |
(bytes[offset + 3] & 0xff);
}
public static int byteArrayToIntLittleEndian(byte[] bytes, int offset) {
return
(bytes[offset + 3] & 0xff) << 24 |
(bytes[offset + 2] & 0xff) << 16 |
(bytes[offset + 1] & 0xff) << 8 |
(bytes[offset] & 0xff);
}
public static long byteArrayToLong(byte[] bytes) {
return byteArrayToLong(bytes, 0);
}
public static long byteArray4ToLong(byte[] bytes, int offset) {
return
((bytes[offset + 0] & 0xffL) << 24) |
((bytes[offset + 1] & 0xffL) << 16) |
((bytes[offset + 2] & 0xffL) << 8) |
((bytes[offset + 3] & 0xffL));
}
public static long byteArrayToLong(byte[] bytes, int offset) {
return
((bytes[offset] & 0xffL) << 56) |
((bytes[offset + 1] & 0xffL) << 48) |
((bytes[offset + 2] & 0xffL) << 40) |
((bytes[offset + 3] & 0xffL) << 32) |
((bytes[offset + 4] & 0xffL) << 24) |
((bytes[offset + 5] & 0xffL) << 16) |
((bytes[offset + 6] & 0xffL) << 8) |
((bytes[offset + 7] & 0xffL));
}
public static int toIntExact(long value) {
if ((int)value != value) {
throw new ArithmeticException("integer overflow");
}
return (int)value;
}
}

View file

@ -0,0 +1,273 @@
package org.signal.core.util
import android.database.Cursor
import androidx.core.database.getIntOrNull
import androidx.core.database.getLongOrNull
import androidx.core.database.getStringOrNull
import java.util.Optional
fun Cursor.requireString(column: String): String? {
return CursorUtil.requireString(this, column)
}
fun Cursor.requireNonNullString(column: String): String {
return CursorUtil.requireString(this, column)!!
}
fun Cursor.optionalString(column: String): Optional<String> {
return CursorUtil.getString(this, column)
}
fun Cursor.requireInt(column: String): Int {
return CursorUtil.requireInt(this, column)
}
fun Cursor.requireIntOrNull(column: String): Int? {
return this.getIntOrNull(this.getColumnIndexOrThrow(column))
}
fun Cursor.optionalInt(column: String): Optional<Int> {
return CursorUtil.getInt(this, column)
}
fun Cursor.requireFloat(column: String): Float {
return CursorUtil.requireFloat(this, column)
}
fun Cursor.requireLong(column: String): Long {
return CursorUtil.requireLong(this, column)
}
fun Cursor.requireLongOrNull(column: String): Long? {
return this.getLongOrNull(this.getColumnIndexOrThrow(column))
}
fun Cursor.optionalLong(column: String): Optional<Long> {
return CursorUtil.getLong(this, column)
}
fun Cursor.requireBoolean(column: String): Boolean {
return CursorUtil.requireInt(this, column) != 0
}
fun Cursor.optionalBoolean(column: String): Optional<Boolean> {
return CursorUtil.getBoolean(this, column)
}
fun Cursor.requireBlob(column: String): ByteArray? {
return CursorUtil.requireBlob(this, column)
}
fun Cursor.requireNonNullBlob(column: String): ByteArray {
return CursorUtil.requireBlob(this, column)!!
}
fun Cursor.optionalBlob(column: String): Optional<ByteArray> {
return CursorUtil.getBlob(this, column)
}
fun Cursor.isNull(column: String): Boolean {
return CursorUtil.isNull(this, column)
}
fun <T> Cursor.requireObject(column: String, serializer: LongSerializer<T>): T {
return serializer.deserialize(CursorUtil.requireLong(this, column))
}
fun <T> Cursor.requireObject(column: String, serializer: StringSerializer<T>): T {
return serializer.deserialize(CursorUtil.requireString(this, column))
}
fun <T> Cursor.requireObject(column: String, serializer: IntSerializer<T>): T {
return serializer.deserialize(CursorUtil.requireInt(this, column))
}
@JvmOverloads
fun Cursor.readToSingleLong(defaultValue: Long = 0): Long {
return readToSingleLongOrNull() ?: defaultValue
}
fun Cursor.readToSingleLongOrNull(): Long? {
return use {
if (it.moveToFirst()) {
it.getLongOrNull(0)
} else {
null
}
}
}
fun <T> Cursor.readToSingleObject(serializer: BaseSerializer<T, Cursor, *>): T? {
return use {
if (it.moveToFirst()) {
serializer.deserialize(it)
} else {
null
}
}
}
fun <T> Cursor.readToSingleObject(mapper: (Cursor) -> T): T? {
return use {
if (it.moveToFirst()) {
mapper(it)
} else {
null
}
}
}
@JvmOverloads
fun Cursor.readToSingleInt(defaultValue: Int = 0): Int {
return use {
if (it.moveToFirst()) {
it.getInt(0)
} else {
defaultValue
}
}
}
fun Cursor.readToSingleIntOrNull(): Int? {
return use {
if (it.moveToFirst()) {
it.getIntOrNull(0)
} else {
null
}
}
}
fun Cursor.readToSingleBoolean(defaultValue: Boolean = false): Boolean {
return use {
if (it.moveToFirst()) {
it.getInt(0) != 0
} else {
defaultValue
}
}
}
@JvmOverloads
inline fun <T> Cursor.readToList(predicate: (T) -> Boolean = { true }, mapper: (Cursor) -> T): List<T> {
val list = mutableListOf<T>()
use {
while (moveToNext()) {
val record = mapper(this)
if (predicate(record)) {
list += mapper(this)
}
}
}
return list
}
@JvmOverloads
inline fun <K, V> Cursor.readToMap(predicate: (Pair<K, V>) -> Boolean = { true }, mapper: (Cursor) -> Pair<K, V>): Map<K, V> {
return readToList(predicate, mapper).associate { it }
}
/**
* Groups the cursor by the given key, and returns a map of keys to lists of values.
*/
inline fun <K, V> Cursor.groupBy(mapper: (Cursor) -> Pair<K, V>): Map<K, List<V>> {
val map: MutableMap<K, MutableList<V>> = mutableMapOf()
use {
while (moveToNext()) {
val pair = mapper(this)
val list = map.getOrPut(pair.first) { mutableListOf() }
list += pair.second
}
}
return map
}
inline fun <T> Cursor.readToSet(predicate: (T) -> Boolean = { true }, mapper: (Cursor) -> T): Set<T> {
val set = mutableSetOf<T>()
use {
while (moveToNext()) {
val record = mapper(this)
if (predicate(record)) {
set += mapper(this)
}
}
}
return set
}
inline fun <T> Cursor.firstOrNull(predicate: (T) -> Boolean = { true }, mapper: (Cursor) -> T): T? {
use {
while (moveToNext()) {
val record = mapper(this)
if (predicate(record)) {
return record
}
}
}
return null
}
inline fun Cursor.forEach(operation: (Cursor) -> Unit) {
use {
while (moveToNext()) {
operation(this)
}
}
}
inline fun Cursor.forEachIndexed(operation: (Int, Cursor) -> Unit) {
use {
var i = 0
while (moveToNext()) {
operation(i++, this)
}
}
}
fun Cursor.iterable(): Iterable<Cursor> {
return CursorIterable(this)
}
fun Boolean.toInt(): Int = if (this) 1 else 0
/**
* Renders the entire cursor row as a string.
* Not necessarily used in the app, but very useful to have available when debugging.
*/
fun Cursor.rowToString(): String {
val builder = StringBuilder()
for (i in 0 until this.columnCount) {
builder
.append(this.getColumnName(i))
.append("=")
.append(this.getStringOrNull(i))
if (i < this.columnCount - 1) {
builder.append(", ")
}
}
return builder.toString()
}
private class CursorIterable(private val cursor: Cursor) : Iterable<Cursor> {
override fun iterator(): Iterator<Cursor> {
return CursorIterator(cursor)
}
}
private class CursorIterator(private val cursor: Cursor) : Iterator<Cursor> {
override fun hasNext(): Boolean {
return !cursor.isClosed && cursor.count > 0 && !cursor.isLast && !cursor.isAfterLast
}
override fun next(): Cursor {
return if (cursor.moveToNext()) {
cursor
} else {
throw NoSuchElementException()
}
}
}

View file

@ -0,0 +1,107 @@
package org.signal.core.util;
import android.database.Cursor;
import androidx.annotation.NonNull;
import androidx.annotation.Nullable;
import java.util.Optional;
import java.util.function.Function;
public final class CursorUtil {
private CursorUtil() {}
public static String requireString(@NonNull Cursor cursor, @NonNull String column) {
return cursor.getString(cursor.getColumnIndexOrThrow(column));
}
public static int requireInt(@NonNull Cursor cursor, @NonNull String column) {
return cursor.getInt(cursor.getColumnIndexOrThrow(column));
}
public static float requireFloat(@NonNull Cursor cursor, @NonNull String column) {
return cursor.getFloat(cursor.getColumnIndexOrThrow(column));
}
public static long requireLong(@NonNull Cursor cursor, @NonNull String column) {
return cursor.getLong(cursor.getColumnIndexOrThrow(column));
}
public static boolean requireBoolean(@NonNull Cursor cursor, @NonNull String column) {
return requireInt(cursor, column) != 0;
}
public static byte[] requireBlob(@NonNull Cursor cursor, @NonNull String column) {
return cursor.getBlob(cursor.getColumnIndexOrThrow(column));
}
public static boolean isNull(@NonNull Cursor cursor, @NonNull String column) {
return cursor.isNull(cursor.getColumnIndexOrThrow(column));
}
public static boolean requireMaskedBoolean(@NonNull Cursor cursor, @NonNull String column, int position) {
return Bitmask.read(requireLong(cursor, column), position);
}
public static int requireMaskedInt(@NonNull Cursor cursor, @NonNull String column, int position, int flagBitSize) {
return Conversions.toIntExact(Bitmask.read(requireLong(cursor, column), position, flagBitSize));
}
public static Optional<String> getString(@NonNull Cursor cursor, @NonNull String column) {
if (cursor.getColumnIndex(column) < 0) {
return Optional.empty();
} else {
return Optional.ofNullable(requireString(cursor, column));
}
}
public static Optional<Integer> getInt(@NonNull Cursor cursor, @NonNull String column) {
if (cursor.getColumnIndex(column) < 0) {
return Optional.empty();
} else {
return Optional.of(requireInt(cursor, column));
}
}
public static Optional<Long> getLong(@NonNull Cursor cursor, @NonNull String column) {
if (cursor.getColumnIndex(column) < 0) {
return Optional.empty();
} else {
return Optional.of(requireLong(cursor, column));
}
}
public static Optional<Boolean> getBoolean(@NonNull Cursor cursor, @NonNull String column) {
if (cursor.getColumnIndex(column) < 0) {
return Optional.empty();
} else {
return Optional.of(requireBoolean(cursor, column));
}
}
public static Optional<byte[]> getBlob(@NonNull Cursor cursor, @NonNull String column) {
if (cursor.getColumnIndex(column) < 0) {
return Optional.empty();
} else {
return Optional.ofNullable(requireBlob(cursor, column));
}
}
/**
* Reads each column as a string, and concatenates them together into a single string separated by |
*/
public static String readRowAsString(@NonNull Cursor cursor) {
StringBuilder row = new StringBuilder();
for (int i = 0, len = cursor.getColumnCount(); i < len; i++) {
row.append(cursor.getString(i));
if (i < len - 1) {
row.append(" | ");
}
}
return row.toString();
}
}

View file

@ -0,0 +1,7 @@
package org.signal.core.util;
import androidx.annotation.NonNull;
public interface DatabaseId {
@NonNull String serialize();
}

View file

@ -0,0 +1,73 @@
package org.signal.core.util;
import android.content.res.Resources;
import androidx.annotation.Dimension;
import androidx.annotation.Px;
/**
* Core utility for converting different dimensional values.
*/
public enum DimensionUnit {
PIXELS {
@Override
@Px
public float toPixels(@Px float pixels) {
return pixels;
}
@Override
@Dimension(unit = Dimension.DP)
public float toDp(@Px float pixels) {
return pixels / Resources.getSystem().getDisplayMetrics().density;
}
@Override
@Dimension(unit = Dimension.SP)
public float toSp(@Px float pixels) {
return pixels / Resources.getSystem().getDisplayMetrics().scaledDensity;
}
},
DP {
@Override
@Px
public float toPixels(@Dimension(unit = Dimension.DP) float dp) {
return dp * Resources.getSystem().getDisplayMetrics().density;
}
@Override
@Dimension(unit = Dimension.DP)
public float toDp(@Dimension(unit = Dimension.DP) float dp) {
return dp;
}
@Override
@Dimension(unit = Dimension.SP)
public float toSp(@Dimension(unit = Dimension.DP) float dp) {
return PIXELS.toSp(toPixels(dp));
}
},
SP {
@Override
@Px
public float toPixels(@Dimension(unit = Dimension.SP) float sp) {
return sp * Resources.getSystem().getDisplayMetrics().scaledDensity;
}
@Override
@Dimension(unit = Dimension.DP)
public float toDp(@Dimension(unit = Dimension.SP) float sp) {
return PIXELS.toDp(toPixels(sp));
}
@Override
@Dimension(unit = Dimension.SP)
public float toSp(@Dimension(unit = Dimension.SP) float sp) {
return sp;
}
};
public abstract float toPixels(float value);
public abstract float toDp(float value);
public abstract float toSp(float value);
}

View file

@ -0,0 +1,27 @@
package org.signal.core.util
import androidx.annotation.Px
/**
* Converts the given Float DP value into Pixels.
*/
@get:Px
val Float.dp: Float get() = DimensionUnit.DP.toPixels(this)
/**
* Converts the given Int DP value into Pixels
*/
@get:Px
val Int.dp: Int get() = this.toFloat().dp.toInt()
/**
* Converts the given Float SP value into Pixels.
*/
@get:Px
val Float.sp: Float get() = DimensionUnit.SP.toPixels(this)
/**
* Converts the given Int SP value into Pixels
*/
@get:Px
val Int.sp: Int get() = this.toFloat().sp.toInt()

View file

@ -0,0 +1,72 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util
import android.app.usage.StorageStatsManager
import android.content.Context
import android.os.Build
import android.os.StatFs
import android.os.storage.StorageManager
import androidx.annotation.RequiresApi
object DiskUtil {
/**
* Gets the remaining storage usable by the application.
*
* @param context The application context
*/
@JvmStatic
fun getAvailableSpace(context: Context): ByteSize {
return if (Build.VERSION.SDK_INT >= 26) {
getAvailableStorageBytesApi26(context).bytes
} else {
return getAvailableStorageBytesLegacy(context).bytes
}
}
/**
* Gets the total disk size of the volume used by the application.
*
* @param context The application context
*/
@JvmStatic
fun getTotalDiskSize(context: Context): ByteSize {
return if (Build.VERSION.SDK_INT >= 26) {
getTotalDiskSizeApi26(context).bytes
} else {
return getTotalDiskSizeLegacy(context).bytes
}
}
@RequiresApi(26)
private fun getAvailableStorageBytesApi26(context: Context): Long {
val storageManager = context.getSystemService(Context.STORAGE_SERVICE) as StorageManager
val storageStatsManager = context.getSystemService(Context.STORAGE_STATS_SERVICE) as StorageStatsManager
val appStorageUuid = storageManager.getUuidForPath(context.filesDir)
return storageStatsManager.getFreeBytes(appStorageUuid)
}
private fun getAvailableStorageBytesLegacy(context: Context): Long {
val stat = StatFs(context.filesDir.absolutePath)
return stat.availableBytes
}
@RequiresApi(26)
private fun getTotalDiskSizeApi26(context: Context): Long {
val storageManager = context.getSystemService(Context.STORAGE_SERVICE) as StorageManager
val storageStatsManager = context.getSystemService(Context.STORAGE_STATS_SERVICE) as StorageStatsManager
val appStorageUuid = storageManager.getUuidForPath(context.filesDir)
return storageStatsManager.getTotalBytes(appStorageUuid)
}
private fun getTotalDiskSizeLegacy(context: Context): Long {
val stat = StatFs(context.filesDir.absolutePath)
return stat.totalBytes
}
}

View file

@ -0,0 +1,78 @@
package org.signal.core.util;
import android.annotation.SuppressLint;
import android.annotation.TargetApi;
import android.graphics.PorterDuff;
import android.graphics.PorterDuffColorFilter;
import android.graphics.drawable.Drawable;
import android.os.Build;
import android.text.InputFilter;
import android.widget.EditText;
import android.widget.TextView;
import androidx.annotation.ColorInt;
import androidx.annotation.NonNull;
import androidx.annotation.RequiresApi;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
public final class EditTextUtil {
private EditTextUtil() {
}
public static void addGraphemeClusterLimitFilter(EditText text, int maximumGraphemes) {
List<InputFilter> filters = new ArrayList<>(Arrays.asList(text.getFilters()));
filters.add(new GraphemeClusterLimitFilter(maximumGraphemes));
text.setFilters(filters.toArray(new InputFilter[0]));
}
public static void setCursorColor(@NonNull EditText text, @ColorInt int colorInt) {
if (Build.VERSION.SDK_INT >= 29) {
Drawable drawable = text.getTextCursorDrawable();
if (drawable == null) {
return;
}
Drawable cursorDrawable = drawable.mutate();
cursorDrawable.setColorFilter(new PorterDuffColorFilter(colorInt, PorterDuff.Mode.SRC_IN));
text.setTextCursorDrawable(cursorDrawable);
} else {
setCursorColorViaReflection(text, colorInt);
}
}
/**
* Note: This is only ever called in API 28 and less.
*/
@SuppressLint("SoonBlockedPrivateApi")
private static void setCursorColorViaReflection(EditText editText, int color) {
try {
Field fCursorDrawableRes = TextView.class.getDeclaredField("mCursorDrawableRes");
fCursorDrawableRes.setAccessible(true);
int mCursorDrawableRes = fCursorDrawableRes.getInt(editText);
Field fEditor = TextView.class.getDeclaredField("mEditor");
fEditor.setAccessible(true);
Object editor = fEditor.get(editText);
Class<?> clazz = editor.getClass();
Field fCursorDrawable = clazz.getDeclaredField("mCursorDrawable");
fCursorDrawable.setAccessible(true);
Drawable[] drawables = new Drawable[2];
drawables[0] = editText.getContext().getResources().getDrawable(mCursorDrawableRes);
drawables[1] = editText.getContext().getResources().getDrawable(mCursorDrawableRes);
drawables[0].setColorFilter(color, PorterDuff.Mode.SRC_IN);
drawables[1].setColorFilter(color, PorterDuff.Mode.SRC_IN);
fCursorDrawable.set(editor, drawables);
} catch (Throwable ignored) {
}
}
}

View file

@ -0,0 +1,114 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util
import kotlin.math.ceil
import kotlin.math.floor
import kotlin.time.Duration.Companion.nanoseconds
import kotlin.time.DurationUnit
/**
* Used to track performance metrics for large clusters of similar events.
* For instance, if you were doing a backup restore and had to important many different kinds of data in an unknown order, you could
* use this to learn stats around how long each kind of data takes to import.
*
* It is assumed that all events are happening serially with no delays in between.
*
* The timer tracks things at nanosecond granularity, but presents data as fractional milliseconds for readability.
*/
class EventTimer {
private val durationsByGroup: MutableMap<String, MutableList<Long>> = mutableMapOf()
private var startTime = System.nanoTime()
private var lastTimeNanos: Long = startTime
fun reset() {
startTime = System.nanoTime()
lastTimeNanos = startTime
durationsByGroup.clear()
}
/**
* Indicates an event in the specified group has finished.
*/
fun emit(group: String) {
val now = System.nanoTime()
val duration = now - lastTimeNanos
durationsByGroup.getOrPut(group) { mutableListOf() } += duration
lastTimeNanos = now
}
/**
* Stops the timer and returns a mapping of group -> [EventMetrics], which will tell you various statistics around timings for that group.
*/
fun stop(): EventTimerResults {
val data: Map<String, EventMetrics> = durationsByGroup
.mapValues { entry ->
val sorted: List<Long> = entry.value.sorted()
EventMetrics(
totalTime = sorted.sum().nanoseconds.toDouble(DurationUnit.MILLISECONDS),
eventCount = sorted.size,
sortedDurationNanos = sorted
)
}
return EventTimerResults(data)
}
class EventTimerResults(data: Map<String, EventMetrics>) : Map<String, EventMetrics> by data {
val summary by lazy {
val builder = StringBuilder()
builder.append("[overall] totalTime: ${data.values.map { it.totalTime }.sum().roundedString(2)} ")
for (entry in data) {
builder.append("[${entry.key}] totalTime: ${entry.value.totalTime.roundedString(2)}, count: ${entry.value.eventCount}, p50: ${entry.value.p(50)}, p90: ${entry.value.p(90)}, p99: ${entry.value.p(99)} ")
}
builder.toString()
}
}
data class EventMetrics(
/** The sum of all event durations, in fractional milliseconds. */
val totalTime: Double,
/** Total number of events observed. */
val eventCount: Int,
private val sortedDurationNanos: List<Long>
) {
/**
* Returns the percentile of the duration data (e.g. p50, p90) as a formatted string containing fractional milliseconds rounded to the requested number of decimal places.
*/
fun p(percentile: Int, decimalPlaces: Int = 2): String {
return pNanos(percentile).nanoseconds.toDouble(DurationUnit.MILLISECONDS).roundedString(decimalPlaces)
}
private fun pNanos(percentile: Int): Long {
if (sortedDurationNanos.isEmpty()) {
return 0L
}
val index: Float = (percentile / 100f) * (sortedDurationNanos.size - 1)
val lowerIndex: Int = floor(index).toInt()
val upperIndex: Int = ceil(index).toInt()
if (lowerIndex == upperIndex) {
return sortedDurationNanos[lowerIndex]
}
val interpolationFactor: Float = index - lowerIndex
val lowerValue: Long = sortedDurationNanos[lowerIndex]
val upperValue: Long = sortedDurationNanos[upperIndex]
return floor(lowerValue + (upperValue - lowerValue) * interpolationFactor).toLong()
}
}
}

View file

@ -0,0 +1,96 @@
package org.signal.core.util;
import androidx.annotation.NonNull;
import org.signal.core.util.logging.Scrubber;
import java.io.ByteArrayOutputStream;
import java.io.PrintStream;
public final class ExceptionUtil {
private ExceptionUtil() {}
/**
* Joins the stack trace of the inferred call site with the original exception. This is
* useful for when exceptions are thrown inside of asynchronous systems (like runnables in an
* executor) where you'd otherwise lose important parts of the stack trace. This lets you save a
* throwable at the entry point, and then combine it with any caught exceptions later.
*
* The resulting stack trace will look like this:
*
* Inferred
* Stack
* Trace
* [[ Inferred Trace ]]
* [[ Original Trace ]]
* Original
* Stack
* Trace
*
* @return The provided original exception, for convenience.
*/
public static <E extends Throwable> E joinStackTrace(@NonNull E original, @NonNull Throwable inferred) {
StackTraceElement[] combinedTrace = joinStackTrace(original.getStackTrace(), inferred.getStackTrace());
original.setStackTrace(combinedTrace);
return original;
}
/**
* See {@link #joinStackTrace(Throwable, Throwable)}
*/
public static StackTraceElement[] joinStackTrace(@NonNull StackTraceElement[] originalTrace, @NonNull StackTraceElement[] inferredTrace) {
StackTraceElement[] combinedTrace = new StackTraceElement[originalTrace.length + inferredTrace.length + 2];
System.arraycopy(originalTrace, 0, combinedTrace, 0, originalTrace.length);
combinedTrace[originalTrace.length] = new StackTraceElement("[[ ↑↑ Original Trace ↑↑ ]]", "", "", 0);
combinedTrace[originalTrace.length + 1] = new StackTraceElement("[[ ↓↓ Inferred Trace ↓↓ ]]", "", "", 0);
System.arraycopy(inferredTrace, 0, combinedTrace, originalTrace.length + 2, inferredTrace.length);
return combinedTrace;
}
/**
* Joins the stack trace with the exception's {@link Throwable#getMessage()}.
*
* The resulting stack trace will look like this:
*
* Original
* Stack
* Trace
* [[ Original Trace ]]
* [[ Exception Message ]]
* Exception Message
*
* @return The provided original exception, for convenience.
*/
public static @NonNull <E extends Throwable> E joinStackTraceAndMessage(@NonNull E original) {
StackTraceElement[] originalTrace = original.getStackTrace();
StackTraceElement[] combinedTrace = new StackTraceElement[originalTrace.length + 3];
System.arraycopy(originalTrace, 0, combinedTrace, 0, originalTrace.length);
String message = Scrubber.scrub(original.getMessage() != null ? original.getMessage() : "null").toString();
if (message.startsWith("Context.startForegroundService")) {
try {
String service = message.substring(message.lastIndexOf('.') + 1, message.length() - 1);
message = service + " did not call startForeground";
} catch (Exception ignored) {}
}
combinedTrace[originalTrace.length] = new StackTraceElement("[[ ↑↑ Original Trace ↑↑ ]]", "", "", 0);
combinedTrace[originalTrace.length + 1] = new StackTraceElement("[[ ↓↓ Exception Message ↓↓ ]]", "", "", 0);
combinedTrace[originalTrace.length + 2] = new StackTraceElement(message, "", "", 0);
original.setStackTrace(combinedTrace);
return original;
}
public static @NonNull String convertThrowableToString(@NonNull Throwable throwable) {
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
throwable.printStackTrace(new PrintStream(outputStream));
return outputStream.toString();
}
}

View file

@ -0,0 +1,39 @@
package org.signal.core.util
import android.graphics.Bitmap
import android.graphics.Canvas
import android.graphics.Color
import android.graphics.Paint
import android.graphics.PorterDuff
import kotlin.math.abs
object FontUtil {
private const val SAMPLE_EMOJI = "\uD83C\uDF0D" // 🌍
/**
* Certain platforms cannot render emoji above a certain font size.
*
* This will attempt to render an emoji at the specified font size and tell you if it's possible.
* It does this by rendering an emoji into a 1x1 bitmap and seeing if the resulting pixel is non-transparent.
*
* https://stackoverflow.com/a/50988748
*/
@JvmStatic
fun canRenderEmojiAtFontSize(size: Float): Boolean {
val bitmap: Bitmap = Bitmap.createBitmap(1, 1, Bitmap.Config.ARGB_8888)
val canvas = Canvas(bitmap)
val paint = Paint()
paint.textSize = size
paint.textAlign = Paint.Align.CENTER
val ascent: Float = abs(paint.ascent())
val descent: Float = abs(paint.descent())
val halfHeight = (ascent + descent) / 2.0f
canvas.drawColor(Color.TRANSPARENT, PorterDuff.Mode.CLEAR)
canvas.drawText(SAMPLE_EMOJI, 0.5f, 0.5f + halfHeight - descent, paint)
return bitmap.getPixel(0, 0) != 0
}
}

View file

@ -0,0 +1,54 @@
package org.signal.core.util;
import android.text.InputFilter;
import android.text.Spanned;
import org.signal.core.util.logging.Log;
/**
* This filter will constrain edits not to make the number of character breaks of the text
* greater than the specified maximum.
* <p>
* This means it will limit to a maximum number of grapheme clusters.
*/
public final class GraphemeClusterLimitFilter implements InputFilter {
private static final String TAG = Log.tag(GraphemeClusterLimitFilter.class);
private final BreakIteratorCompat breakIteratorCompat;
private final int max;
public GraphemeClusterLimitFilter(int max) {
this.breakIteratorCompat = BreakIteratorCompat.getInstance();
this.max = max;
}
@Override
public CharSequence filter(CharSequence source, int start, int end, Spanned dest, int dstart, int dend) {
CharSequence sourceFragment = source.subSequence(start, end);
CharSequence head = dest.subSequence(0, dstart);
CharSequence tail = dest.subSequence(dend, dest.length());
breakIteratorCompat.setText(String.format("%s%s%s", head, sourceFragment, tail));
int length = breakIteratorCompat.countBreaks();
if (length > max) {
breakIteratorCompat.setText(sourceFragment);
int sourceLength = breakIteratorCompat.countBreaks();
CharSequence trimmedSource = breakIteratorCompat.take(sourceLength - (length - max));
breakIteratorCompat.setText(String.format("%s%s%s", head, trimmedSource, tail));
int newExpectedCount = breakIteratorCompat.countBreaks();
if (newExpectedCount > max) {
Log.w(TAG, "Failed to create string under the required length " + newExpectedCount);
return "";
}
return trimmedSource;
}
return source;
}
}

View file

@ -0,0 +1,23 @@
package org.signal.core.util
import android.content.Intent
import android.os.Build
import android.os.Parcelable
fun <T : Parcelable> Intent.getParcelableExtraCompat(key: String, clazz: Class<T>): T? {
return if (Build.VERSION.SDK_INT >= 33) {
this.getParcelableExtra(key, clazz)
} else {
@Suppress("DEPRECATION")
this.getParcelableExtra(key)
}
}
fun <T : Parcelable> Intent.getParcelableArrayListExtraCompat(key: String, clazz: Class<T>): ArrayList<T>? {
return if (Build.VERSION.SDK_INT >= 33) {
this.getParcelableArrayListExtra(key, clazz)
} else {
@Suppress("DEPRECATION")
this.getParcelableArrayListExtra(key)
}
}

View file

@ -0,0 +1,23 @@
package org.signal.core.util;
import java.util.concurrent.LinkedBlockingDeque;
public class LinkedBlockingLifoQueue<E> extends LinkedBlockingDeque<E> {
@Override
public void put(E runnable) throws InterruptedException {
super.putFirst(runnable);
}
@Override
public boolean add(E runnable) {
super.addFirst(runnable);
return true;
}
@Override
public boolean offer(E runnable) {
super.addFirst(runnable);
return true;
}
}

View file

@ -0,0 +1,39 @@
package org.signal.core.util;
import androidx.annotation.NonNull;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.stream.Stream;
public final class ListUtil {
private ListUtil() {}
public static <E> List<List<E>> chunk(@NonNull List<E> list, int chunkSize) {
List<List<E>> chunks = new ArrayList<>(list.size() / chunkSize);
for (int i = 0; i < list.size(); i += chunkSize) {
List<E> chunk = list.subList(i, Math.min(list.size(), i + chunkSize));
chunks.add(chunk);
}
return chunks;
}
@SafeVarargs
public static <T> List<T> concat(Collection<T>... items) {
final List<T> concat = new ArrayList<>(Stream.of(items).map(Collection::size).reduce(0, Integer::sum));
for (Collection<T> list : items) {
concat.addAll(list);
}
return concat;
}
public static <T> List<T> emptyIfNull(List<T> list) {
return list == null ? Collections.emptyList() : list;
}
}

View file

@ -0,0 +1,30 @@
package org.signal.core.util;
import android.os.Build;
import androidx.annotation.NonNull;
import java.util.Map;
import java.util.function.Function;
public final class MapUtil {
private MapUtil() {}
@NonNull
public static <K, V> V getOrDefault(@NonNull Map<K, V> map, @NonNull K key, @NonNull V defaultValue) {
if (Build.VERSION.SDK_INT >= 24) {
//noinspection ConstantConditions
return map.getOrDefault(key, defaultValue);
} else {
V v = map.get(key);
return v == null ? defaultValue : v;
}
}
@NonNull
public static <K, V, M> M mapOrDefault(@NonNull Map<K, V> map, @NonNull K key, @NonNull Function<V, M> mapper, @NonNull M defaultValue) {
V v = map.get(key);
return v == null ? defaultValue : mapper.apply(v);
}
}

View file

@ -0,0 +1,151 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util
import android.app.ActivityManager
import android.content.Context
import android.os.Debug
import android.os.Handler
import org.signal.core.util.concurrent.SignalExecutors
import org.signal.core.util.logging.Log
import kotlin.time.Duration.Companion.milliseconds
import kotlin.time.Duration.Companion.seconds
object MemoryTracker {
private val TAG = Log.tag(MemoryTracker::class.java)
private val runtime: Runtime = Runtime.getRuntime()
private val activityMemoryInfo: ActivityManager.MemoryInfo = ActivityManager.MemoryInfo()
private val debugMemoryInfo: Debug.MemoryInfo = Debug.MemoryInfo()
private val handler: Handler = Handler(SignalExecutors.getAndStartHandlerThread("MemoryTracker", ThreadUtil.PRIORITY_BACKGROUND_THREAD).looper)
private val POLLING_INTERVAL = 5.seconds.inWholeMilliseconds
private var running = false
private lateinit var previousAppHeadUsage: AppHeapUsage
private var increaseMemoryCount = 0
@JvmStatic
fun start() {
Log.d(TAG, "Beginning memory monitoring.")
running = true
previousAppHeadUsage = getAppJvmHeapUsage()
increaseMemoryCount = 0
handler.postDelayed(this::poll, POLLING_INTERVAL)
}
@JvmStatic
fun stop() {
Log.d(TAG, "Ending memory monitoring.")
running = false
handler.removeCallbacksAndMessages(null)
}
fun poll() {
val currentHeapUsage = getAppJvmHeapUsage()
if (currentHeapUsage.currentTotalBytes != previousAppHeadUsage.currentTotalBytes) {
if (currentHeapUsage.currentTotalBytes > previousAppHeadUsage.currentTotalBytes) {
Log.d(TAG, "The system increased our app JVM heap from ${previousAppHeadUsage.currentTotalBytes.byteDisplay()} to ${currentHeapUsage.currentTotalBytes.byteDisplay()}")
} else {
Log.d(TAG, "The system decreased our app JVM heap from ${previousAppHeadUsage.currentTotalBytes.byteDisplay()} to ${currentHeapUsage.currentTotalBytes.byteDisplay()}")
}
}
if (currentHeapUsage.usedBytes >= previousAppHeadUsage.usedBytes) {
increaseMemoryCount++
} else {
Log.d(TAG, "Used memory has decreased from ${previousAppHeadUsage.usedBytes.byteDisplay()} to ${currentHeapUsage.usedBytes.byteDisplay()}")
increaseMemoryCount = 0
}
if (increaseMemoryCount > 0 && increaseMemoryCount % 5 == 0) {
Log.d(TAG, "Used memory has increased or stayed the same for the last $increaseMemoryCount intervals (${increaseMemoryCount * POLLING_INTERVAL.milliseconds.inWholeSeconds} seconds). Using: ${currentHeapUsage.usedBytes.byteDisplay()}, Free: ${currentHeapUsage.freeBytes.byteDisplay()}, CurrentTotal: ${currentHeapUsage.currentTotalBytes.byteDisplay()}, MaxPossible: ${currentHeapUsage.maxPossibleBytes.byteDisplay()}")
}
previousAppHeadUsage = currentHeapUsage
if (running) {
handler.postDelayed(this::poll, POLLING_INTERVAL)
}
}
/**
* Gives us basic memory usage data for our app JVM heap usage. Very fast, ~10 micros on an emulator.
*/
fun getAppJvmHeapUsage(): AppHeapUsage {
return AppHeapUsage(
freeBytes = runtime.freeMemory(),
currentTotalBytes = runtime.totalMemory(),
maxPossibleBytes = runtime.maxMemory()
)
}
/**
* This gives us details stats, but it takes an appreciable amount of time. On an emulator, it can take ~30ms.
* As a result, we don't want to be calling this regularly for most users.
*/
fun getDetailedMemoryStats(): DetailedMemoryStats {
Debug.getMemoryInfo(debugMemoryInfo)
return DetailedMemoryStats(
appJavaHeapUsageKb = debugMemoryInfo.getMemoryStat("summary.java-heap")?.toLongOrNull(),
appNativeHeapUsageKb = debugMemoryInfo.getMemoryStat("summary.native-heap")?.toLongOrNull(),
codeUsageKb = debugMemoryInfo.getMemoryStat("summary.code")?.toLongOrNull(),
stackUsageKb = debugMemoryInfo.getMemoryStat("summary.stack")?.toLongOrNull(),
graphicsUsageKb = debugMemoryInfo.getMemoryStat("summary.graphics")?.toLongOrNull(),
appOtherUsageKb = debugMemoryInfo.getMemoryStat("summary.private-other")?.toLongOrNull()
)
}
fun getSystemNativeMemoryUsage(context: Context): NativeMemoryUsage {
val activityManager: ActivityManager = context.getSystemService(Context.ACTIVITY_SERVICE) as ActivityManager
activityManager.getMemoryInfo(activityMemoryInfo)
return NativeMemoryUsage(
freeBytes = activityMemoryInfo.availMem,
totalBytes = activityMemoryInfo.totalMem,
lowMemory = activityMemoryInfo.lowMemory,
lowMemoryThreshold = activityMemoryInfo.threshold
)
}
private fun Long.byteDisplay(): String {
return "$this (${this.bytes.inMebiBytes.roundedString(2)} MiB)"
}
data class AppHeapUsage(
/** The number of bytes that are free to use. */
val freeBytes: Long,
/** The current total number of bytes our app could use. This can increase over time as the system increases our allocation. */
val currentTotalBytes: Long,
/** The maximum number of bytes that our app could ever be given. */
val maxPossibleBytes: Long
) {
/** The number of bytes that our app is currently using. */
val usedBytes: Long
get() = currentTotalBytes - freeBytes
}
data class NativeMemoryUsage(
val freeBytes: Long,
val totalBytes: Long,
val lowMemory: Boolean,
val lowMemoryThreshold: Long
) {
val usedBytes: Long
get() = totalBytes - freeBytes
}
data class DetailedMemoryStats(
val appJavaHeapUsageKb: Long?,
val appNativeHeapUsageKb: Long?,
val codeUsageKb: Long?,
val graphicsUsageKb: Long?,
val stackUsageKb: Long?,
val appOtherUsageKb: Long?
)
}

View file

@ -0,0 +1,129 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util
import java.util.Queue
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.ConcurrentLinkedQueue
import kotlin.math.ceil
import kotlin.math.floor
import kotlin.time.Duration.Companion.nanoseconds
import kotlin.time.DurationUnit
/**
* Used to track performance metrics for large clusters of similar events that are happening simultaneously.
*
* Very similar to [EventTimer], but with no assumptions around threading,
*
* The timer tracks things at nanosecond granularity, but presents data as fractional milliseconds for readability.
*/
class ParallelEventTimer {
val durationsByGroup: MutableMap<String, Queue<Long>> = ConcurrentHashMap()
private var startTime = System.nanoTime()
fun reset() {
durationsByGroup.clear()
startTime = System.nanoTime()
}
/**
* Begin an event associated with a group. You must call [EventStopper.stopEvent] on the returned object in order to indicate the action has completed.
*/
fun beginEvent(group: String): EventStopper {
val start = System.nanoTime()
return EventStopper {
val duration = System.nanoTime() - start
durationsByGroup.computeIfAbsent(group) { ConcurrentLinkedQueue() } += duration
}
}
/**
* Time an event associated with a group.
*/
inline fun <E> timeEvent(group: String, operation: () -> E): E {
val start = System.nanoTime()
val result = operation()
val duration = System.nanoTime() - start
durationsByGroup.computeIfAbsent(group) { ConcurrentLinkedQueue() } += duration
return result
}
/**
* Stops the timer and returns a mapping of group -> [EventMetrics], which will tell you various statistics around timings for that group.
* It is assumed that all events have been stopped by the time this has been called.
*/
fun stop(): EventTimerResults {
val totalDuration = System.nanoTime() - startTime
val data: Map<String, EventMetrics> = durationsByGroup
.mapValues { entry ->
val sorted: List<Long> = entry.value.sorted()
EventMetrics(
totalEventTime = sorted.sum().nanoseconds.toDouble(DurationUnit.MILLISECONDS),
eventCount = sorted.size,
sortedDurationNanos = sorted
)
}
return EventTimerResults(totalDuration.nanoseconds.toDouble(DurationUnit.MILLISECONDS), data)
}
class EventTimerResults(totalWallTime: Double, data: Map<String, EventMetrics>) : Map<String, EventMetrics> by data {
val summary by lazy {
val builder = StringBuilder()
builder.append("[overall] totalWallTime: ${totalWallTime.roundedString(2)}, totalEventTime: ${data.values.map { it.totalEventTime}.sum().roundedString(2)} ")
for (entry in data) {
builder.append("[${entry.key}] totalEventTime: ${entry.value.totalEventTime.roundedString(2)}, count: ${entry.value.eventCount}, p50: ${entry.value.p(50)}, p90: ${entry.value.p(90)}, p99: ${entry.value.p(99)} ")
}
builder.toString()
}
}
fun interface EventStopper {
fun stopEvent()
}
data class EventMetrics(
/** The sum of all event times, in fractional milliseconds. If running operations in parallel, this will likely be larger than [totalWallTime]. */
val totalEventTime: Double,
/** Total number of events observed. */
val eventCount: Int,
private val sortedDurationNanos: List<Long>
) {
/**
* Returns the percentile of the duration data (e.g. p50, p90) as a formatted string containing fractional milliseconds rounded to the requested number of decimal places.
*/
fun p(percentile: Int, decimalPlaces: Int = 2): String {
return pNanos(percentile).nanoseconds.toDouble(DurationUnit.MILLISECONDS).roundedString(decimalPlaces)
}
private fun pNanos(percentile: Int): Long {
if (sortedDurationNanos.isEmpty()) {
return 0L
}
val index: Float = (percentile / 100f) * (sortedDurationNanos.size - 1)
val lowerIndex: Int = floor(index).toInt()
val upperIndex: Int = ceil(index).toInt()
if (lowerIndex == upperIndex) {
return sortedDurationNanos[lowerIndex]
}
val interpolationFactor: Float = index - lowerIndex
val lowerValue: Long = sortedDurationNanos[lowerIndex]
val upperValue: Long = sortedDurationNanos[upperIndex]
return floor(lowerValue + (upperValue - lowerValue) * interpolationFactor).toLong()
}
}
}

View file

@ -0,0 +1,23 @@
package org.signal.core.util
import android.os.Build
import android.os.Parcel
import android.os.Parcelable
fun <T : Parcelable> Parcel.readParcelableCompat(clazz: Class<T>): T? {
return if (Build.VERSION.SDK_INT >= 33) {
this.readParcelable(clazz.classLoader, clazz)
} else {
@Suppress("DEPRECATION")
this.readParcelable(clazz.classLoader)
}
}
fun <T : java.io.Serializable> Parcel.readSerializableCompat(clazz: Class<T>): T? {
return if (Build.VERSION.SDK_INT >= 33) {
this.readSerializable(clazz.classLoader, clazz)
} else {
@Suppress("DEPRECATION", "UNCHECKED_CAST")
this.readSerializable() as T
}
}

View file

@ -0,0 +1,47 @@
package org.signal.core.util
import android.app.PendingIntent
import android.os.Build
/**
* Wrapper class for lower level API compatibility with the new Pending Intents flags.
*
* This is meant to be a replacement to using PendingIntent flags independently, and should
* end up being the only place in our codebase that accesses these values.
*
* The "default" value is FLAG_MUTABLE
*/
object PendingIntentFlags {
@JvmStatic
fun updateCurrent(): Int {
return mutable() or PendingIntent.FLAG_UPDATE_CURRENT
}
@JvmStatic
fun cancelCurrent(): Int {
return mutable() or PendingIntent.FLAG_CANCEL_CURRENT
}
/**
* Flag indicating that this [PendingIntent] can be used only once. After [PendingIntent.send] is called on it,
* it will be automatically canceled for you and any future attempt to send through it will fail.
*/
@JvmStatic
fun oneShot(): Int {
return immutable() or PendingIntent.FLAG_ONE_SHOT
}
/**
* The backwards compatible "default" value for pending intent flags.
*/
@JvmStatic
fun mutable(): Int {
return if (Build.VERSION.SDK_INT >= 31) PendingIntent.FLAG_MUTABLE else 0
}
@JvmStatic
fun immutable(): Int {
return PendingIntent.FLAG_IMMUTABLE
}
}

View file

@ -0,0 +1,33 @@
package org.signal.core.util;
import android.content.Context;
import android.content.res.Configuration;
import android.content.res.Resources;
import androidx.annotation.NonNull;
import androidx.annotation.StringRes;
import java.util.Locale;
/**
* Gives access to English strings.
*/
public final class ResourceUtil {
private ResourceUtil() {
}
public static Resources getEnglishResources(@NonNull Context context) {
return getResources(context, Locale.ENGLISH);
}
public static Resources getResources(@NonNull Context context, @NonNull Locale locale) {
Configuration configurationLocal = context.getResources().getConfiguration();
Configuration configurationEn = new Configuration(configurationLocal);
configurationEn.setLocale(locale);
return context.createConfigurationContext(configurationEn)
.getResources();
}
}

View file

@ -0,0 +1,61 @@
package org.signal.core.util
/**
* A Result that allows for generic definitions of success/failure values.
*/
sealed class Result<out S, out F> {
data class Failure<out F>(val failure: F) : Result<Nothing, F>()
data class Success<out S>(val success: S) : Result<S, Nothing>()
companion object {
@JvmStatic
fun <S> success(value: S) = Success(value)
@JvmStatic
fun <F> failure(value: F) = Failure(value)
}
/**
* Maps an Result<S, F> to an Result<T, F>. Failure values will pass through, while
* right values will be operated on by the parameter.
*/
fun <T> map(onSuccess: (S) -> T): Result<T, F> {
return when (this) {
is Failure -> this
is Success -> success(onSuccess(success))
}
}
/**
* Allows the caller to operate on the Result such that the correct function is applied
* to the value it contains.
*/
fun <T> either(
onSuccess: (S) -> T,
onFailure: (F) -> T
): T {
return when (this) {
is Success -> onSuccess(success)
is Failure -> onFailure(failure)
}
}
}
/**
* Maps an Result<L, R> to an Result<L, T>. Failure values will pass through, while
* right values will be operated on by the parameter.
*
* Note this is an extension method in order to make the generics happy.
*/
fun <T, S, F> Result<S, F>.flatMap(onSuccess: (S) -> Result<T, F>): Result<T, F> {
return when (this) {
is Result.Success -> onSuccess(success)
is Result.Failure -> this
}
}
/**
* Try is a specialization of Result where the Failure is fixed to Throwable.
*/
typealias Try<S> = Result<S, Throwable>

View file

@ -0,0 +1,658 @@
package org.signal.core.util
import android.content.ContentValues
import android.database.Cursor
import android.database.sqlite.SQLiteDatabase
import androidx.core.content.contentValuesOf
import androidx.sqlite.db.SupportSQLiteDatabase
import androidx.sqlite.db.SupportSQLiteQueryBuilder
import androidx.sqlite.db.SupportSQLiteStatement
import org.signal.core.util.SqlUtil.ForeignKeyViolation
import org.signal.core.util.logging.Log
import kotlin.time.Duration
import kotlin.time.Duration.Companion.seconds
private val TAG = "SQLiteDatabaseExtensions"
/**
* Begins a transaction on the `this` database, runs the provided [block] providing the `this` value as it's argument
* within the transaction, and then ends the transaction successfully.
*
* @return The value returned by [block] if any
*/
inline fun <T : SupportSQLiteDatabase, R> T.withinTransaction(block: (T) -> R): R {
beginTransaction()
try {
val toReturn = block(this)
if (inTransaction()) {
setTransactionSuccessful()
}
return toReturn
} finally {
if (inTransaction()) {
endTransaction()
}
}
}
fun SupportSQLiteDatabase.getTableRowCount(table: String): Int {
return this.query("SELECT COUNT(*) FROM $table").use {
if (it.moveToFirst()) {
it.getInt(0)
} else {
0
}
}
}
fun SupportSQLiteDatabase.getAllTables(): List<String> {
return SqlUtil.getAllTables(this)
}
/**
* Returns a list of objects that represent the table definitions in the database. Basically the table name and then the SQL that was used to create it.
*/
fun SupportSQLiteDatabase.getAllTableDefinitions(): List<CreateStatement> {
return this
.select("name", "sql")
.from("sqlite_schema")
.where("type = ? AND sql NOT NULL AND name != ?", "table", "sqlite_sequence")
.run()
.readToList { cursor ->
CreateStatement(
name = cursor.requireNonNullString("name"),
statement = cursor.requireNonNullString("sql").replace(" ", "")
)
}
.filterNot { it.name.startsWith("sqlite_stat") }
.sortedBy { it.name }
}
/**
* Returns a list of objects that represent the index definitions in the database. Basically the index name and then the SQL that was used to create it.
*/
fun SupportSQLiteDatabase.getAllIndexDefinitions(): List<CreateStatement> {
return this
.select("name", "sql")
.from("sqlite_schema")
.where("type = ? AND sql NOT NULL", "index")
.run()
.readToList { cursor ->
CreateStatement(
name = cursor.requireNonNullString("name"),
statement = cursor.requireNonNullString("sql")
)
}
.sortedBy { it.name }
}
/**
* Retrieves the names of all triggers, sorted alphabetically.
*/
fun SupportSQLiteDatabase.getAllTriggerDefinitions(): List<CreateStatement> {
return this
.select("name", "sql")
.from("sqlite_schema")
.where("type = ? AND sql NOT NULL", "trigger")
.run()
.readToList {
CreateStatement(
name = it.requireNonNullString("name"),
statement = it.requireNonNullString("sql")
)
}
.sortedBy { it.name }
}
fun SupportSQLiteDatabase.getForeignKeys(): List<ForeignKeyConstraint> {
return SqlUtil.getAllTables(this)
.map { table ->
this.query("PRAGMA foreign_key_list($table)").readToList { cursor ->
ForeignKeyConstraint(
table = table,
column = cursor.requireNonNullString("from"),
dependsOnTable = cursor.requireNonNullString("table"),
dependsOnColumn = cursor.requireNonNullString("to"),
onDelete = cursor.requireString("on_delete") ?: "NOTHING"
)
}
}
.flatten()
}
fun SupportSQLiteDatabase.areForeignKeyConstraintsEnabled(): Boolean {
return this.query("PRAGMA foreign_keys", arrayOf()).use { cursor ->
cursor.moveToFirst() && cursor.getInt(0) != 0
}
}
/**
* Provides a list of all foreign key violations present.
* If a [targetTable] is specified, results will be limited to that table specifically.
* Otherwise, the check will be performed across all tables.
*/
@JvmOverloads
fun SupportSQLiteDatabase.getForeignKeyViolations(targetTable: String? = null): List<ForeignKeyViolation> {
return SqlUtil.getForeignKeyViolations(this, targetTable)
}
/**
* For tables that have an autoincrementing primary key, this will reset the key to start back at 1.
* IMPORTANT: This is quite dangerous! Only do this if you're effectively resetting the entire database.
*/
fun SupportSQLiteDatabase.resetAutoIncrementValue(targetTable: String) {
SqlUtil.resetAutoIncrementValue(this, targetTable)
}
/**
* Does a full WAL checkpoint (TRUNCATE mode, where the log is for sure flushed and the log is zero'd out).
* Will try up to [maxAttempts] times. Can technically fail if the database is too active and the checkpoint
* can't complete in a reasonable amount of time.
*
* See: https://www.sqlite.org/pragma.html#pragma_wal_checkpoint
*/
fun SupportSQLiteDatabase.fullWalCheckpoint(maxAttempts: Int = 3): Boolean {
var attempts = 0
while (attempts < maxAttempts) {
if (this.walCheckpoint()) {
return true
}
attempts++
}
return false
}
private fun SupportSQLiteDatabase.walCheckpoint(): Boolean {
return this.query("PRAGMA wal_checkpoint(TRUNCATE)").use { cursor ->
cursor.moveToFirst() && cursor.getInt(0) == 0
}
}
fun SupportSQLiteDatabase.getIndexes(): List<Index> {
return this.query("SELECT name, tbl_name FROM sqlite_master WHERE type='index' ORDER BY name ASC").readToList { cursor ->
val indexName = cursor.requireNonNullString("name")
Index(
name = indexName,
table = cursor.requireNonNullString("tbl_name"),
columns = this.query("PRAGMA index_info($indexName)").readToList { it.requireNonNullString("name") }
)
}
}
fun SupportSQLiteDatabase.forceForeignKeyConstraintsEnabled(enabled: Boolean, timeout: Duration = 10.seconds) {
val startTime = System.currentTimeMillis()
while (true) {
try {
this.setForeignKeyConstraintsEnabled(enabled)
break
} catch (e: IllegalStateException) {
if (System.currentTimeMillis() - startTime > timeout.inWholeMilliseconds) {
throw IllegalStateException("Failed to force foreign keys to '$enabled' within the timeout of $timeout", e)
}
Log.w(TAG, "Failed to set foreign keys because we're in a transaction. Waiting 100ms then trying again.")
ThreadUtil.sleep(100)
}
}
}
/**
* Checks if a row exists that matches the query.
*/
fun SupportSQLiteDatabase.exists(table: String): ExistsBuilderPart1 {
return ExistsBuilderPart1(this, table)
}
/**
* Begins a SELECT statement with a helpful builder pattern.
*/
fun SupportSQLiteDatabase.select(vararg columns: String): SelectBuilderPart1 {
return SelectBuilderPart1(this, arrayOf(*columns))
}
/**
* Begins a COUNT statement with a helpful builder pattern.
*/
fun SupportSQLiteDatabase.count(): SelectBuilderPart1 {
return SelectBuilderPart1(this, SqlUtil.COUNT)
}
/**
* Begins an UPDATE statement with a helpful builder pattern.
* Requires a WHERE clause as a way of mitigating mistakes. If you'd like to update all items in the table, use [updateAll].
*/
fun SupportSQLiteDatabase.update(tableName: String): UpdateBuilderPart1 {
return UpdateBuilderPart1(this, tableName)
}
fun SupportSQLiteDatabase.updateAll(tableName: String): UpdateAllBuilderPart1 {
return UpdateAllBuilderPart1(this, tableName)
}
/**
* Begins a DELETE statement with a helpful builder pattern.
* Requires a WHERE clause as a way of mitigating mistakes. If you'd like to delete all items in the table, use [deleteAll].
*/
fun SupportSQLiteDatabase.delete(tableName: String): DeleteBuilderPart1 {
return DeleteBuilderPart1(this, tableName)
}
/**
* Deletes all data in the table.
*/
fun SupportSQLiteDatabase.deleteAll(tableName: String): Int {
return this.delete(tableName, null, arrayOfNulls<String>(0))
}
/**
* Begins an INSERT statement with a helpful builder pattern.
*/
fun SupportSQLiteDatabase.insertInto(tableName: String): InsertBuilderPart1 {
return InsertBuilderPart1(this, tableName)
}
/**
* Bind an arbitrary value to an index. It will handle calling the correct bind method based on the class type.
* @param index The index you want to bind to. Important: Indexes start at 1, not 0.
*/
fun SupportSQLiteStatement.bindValue(index: Int, value: Any?) {
when (value) {
null -> this.bindNull(index)
is DatabaseId -> this.bindString(index, value.serialize())
is Boolean -> this.bindLong(index, value.toInt().toLong())
is ByteArray -> this.bindBlob(index, value)
is Number -> {
if (value.toLong() == value || value.toInt() == value || value.toShort() == value || value.toByte() == value) {
this.bindLong(index, value.toLong())
} else {
this.bindDouble(index, value.toDouble())
}
}
else -> this.bindString(index, value.toString())
}
}
class SelectBuilderPart1(
private val db: SupportSQLiteDatabase,
private val columns: Array<String>
) {
fun from(tableName: String): SelectBuilderPart2 {
return SelectBuilderPart2(db, columns, tableName)
}
}
class SelectBuilderPart2(
private val db: SupportSQLiteDatabase,
private val columns: Array<String>,
private val tableName: String
) {
fun where(where: String, vararg whereArgs: Any): SelectBuilderPart3 {
return SelectBuilderPart3(db, columns, tableName, where, SqlUtil.buildArgs(*whereArgs))
}
fun where(where: String, whereArgs: Array<String>): SelectBuilderPart3 {
return SelectBuilderPart3(db, columns, tableName, where, whereArgs)
}
fun orderBy(orderBy: String): SelectBuilderPart4a {
return SelectBuilderPart4a(db, columns, tableName, "", arrayOf(), orderBy)
}
fun limit(limit: Int): SelectBuilderPart4b {
return SelectBuilderPart4b(db, columns, tableName, "", arrayOf(), limit.toString())
}
fun run(): Cursor {
return db.query(
SupportSQLiteQueryBuilder
.builder(tableName)
.columns(columns)
.create()
)
}
}
class SelectBuilderPart3(
private val db: SupportSQLiteDatabase,
private val columns: Array<String>,
private val tableName: String,
private val where: String,
private val whereArgs: Array<String>
) {
fun orderBy(orderBy: String): SelectBuilderPart4a {
return SelectBuilderPart4a(db, columns, tableName, where, whereArgs, orderBy)
}
fun limit(limit: Int): SelectBuilderPart4b {
return SelectBuilderPart4b(db, columns, tableName, where, whereArgs, limit.toString())
}
fun limit(limit: String): SelectBuilderPart4b {
return SelectBuilderPart4b(db, columns, tableName, where, whereArgs, limit)
}
fun limit(limit: Int, offset: Int): SelectBuilderPart4b {
return SelectBuilderPart4b(db, columns, tableName, where, whereArgs, "$offset,$limit")
}
fun groupBy(groupBy: String): SelectBuilderPart4c {
return SelectBuilderPart4c(db, columns, tableName, where, whereArgs, groupBy)
}
fun run(): Cursor {
return db.query(
SupportSQLiteQueryBuilder
.builder(tableName)
.columns(columns)
.selection(where, whereArgs)
.create()
)
}
}
class SelectBuilderPart4a(
private val db: SupportSQLiteDatabase,
private val columns: Array<String>,
private val tableName: String,
private val where: String,
private val whereArgs: Array<String>,
private val orderBy: String
) {
fun limit(limit: Int): SelectBuilderPart5 {
return SelectBuilderPart5(db, columns, tableName, where, whereArgs, orderBy, limit.toString())
}
fun limit(limit: String): SelectBuilderPart5 {
return SelectBuilderPart5(db, columns, tableName, where, whereArgs, orderBy, limit)
}
fun limit(limit: Int, offset: Int): SelectBuilderPart5 {
return SelectBuilderPart5(db, columns, tableName, where, whereArgs, orderBy, "$offset,$limit")
}
fun run(): Cursor {
return db.query(
SupportSQLiteQueryBuilder
.builder(tableName)
.columns(columns)
.selection(where, whereArgs)
.orderBy(orderBy)
.create()
)
}
}
class SelectBuilderPart4b(
private val db: SupportSQLiteDatabase,
private val columns: Array<String>,
private val tableName: String,
private val where: String,
private val whereArgs: Array<String>,
private val limit: String
) {
fun orderBy(orderBy: String): SelectBuilderPart5 {
return SelectBuilderPart5(db, columns, tableName, where, whereArgs, orderBy, limit)
}
fun run(): Cursor {
return db.query(
SupportSQLiteQueryBuilder
.builder(tableName)
.columns(columns)
.selection(where, whereArgs)
.limit(limit)
.create()
)
}
}
class SelectBuilderPart4c(
private val db: SupportSQLiteDatabase,
private val columns: Array<String>,
private val tableName: String,
private val where: String,
private val whereArgs: Array<String>,
private val groupBy: String
) {
fun run(): Cursor {
return db.query(
SupportSQLiteQueryBuilder
.builder(tableName)
.columns(columns)
.selection(where, whereArgs)
.groupBy(groupBy)
.create()
)
}
}
class SelectBuilderPart5(
private val db: SupportSQLiteDatabase,
private val columns: Array<String>,
private val tableName: String,
private val where: String,
private val whereArgs: Array<String>,
private val orderBy: String,
private val limit: String
) {
fun run(): Cursor {
return db.query(
SupportSQLiteQueryBuilder
.builder(tableName)
.columns(columns)
.selection(where, whereArgs)
.orderBy(orderBy)
.limit(limit)
.create()
)
}
}
class UpdateBuilderPart1(
private val db: SupportSQLiteDatabase,
private val tableName: String
) {
fun values(values: ContentValues): UpdateBuilderPart2 {
return UpdateBuilderPart2(db, tableName, values)
}
fun values(vararg values: Pair<String, Any?>): UpdateBuilderPart2 {
return UpdateBuilderPart2(db, tableName, contentValuesOf(*values))
}
}
class UpdateBuilderPart2(
private val db: SupportSQLiteDatabase,
private val tableName: String,
private val values: ContentValues
) {
fun where(where: String, vararg whereArgs: Any): UpdateBuilderPart3 {
require(where.isNotBlank())
return UpdateBuilderPart3(db, tableName, values, where, whereArgs.toArgs())
}
fun where(where: String, whereArgs: Array<String>): UpdateBuilderPart3 {
require(where.isNotBlank())
return UpdateBuilderPart3(db, tableName, values, where, whereArgs)
}
}
class UpdateBuilderPart3(
private val db: SupportSQLiteDatabase,
private val tableName: String,
private val values: ContentValues,
private val where: String,
private val whereArgs: Array<out Any?>
) {
@JvmOverloads
fun run(): Int {
val query = StringBuilder("UPDATE $tableName SET ")
val contentValuesKeys = values.keySet()
for ((index, column) in contentValuesKeys.withIndex()) {
query.append(column).append(" = ?")
if (index < contentValuesKeys.size - 1) {
query.append(", ")
}
}
query.append(" WHERE ").append(where)
val statement = db.compileStatement(query.toString())
var bindIndex = 1
for (key in contentValuesKeys) {
statement.bindValue(bindIndex, values.get(key))
bindIndex++
}
for (arg in whereArgs) {
statement.bindValue(bindIndex, arg)
bindIndex++
}
return statement.use { it.executeUpdateDelete() }
}
}
class UpdateAllBuilderPart1(
private val db: SupportSQLiteDatabase,
private val tableName: String
) {
fun values(values: ContentValues): UpdateAllBuilderPart2 {
return UpdateAllBuilderPart2(db, tableName, values)
}
fun values(vararg values: Pair<String, Any?>): UpdateAllBuilderPart2 {
return UpdateAllBuilderPart2(db, tableName, contentValuesOf(*values))
}
}
class UpdateAllBuilderPart2(
private val db: SupportSQLiteDatabase,
private val tableName: String,
private val values: ContentValues
) {
@JvmOverloads
fun run(conflictStrategy: Int = SQLiteDatabase.CONFLICT_NONE): Int {
return db.update(tableName, conflictStrategy, values, null, emptyArray<String>())
}
}
class DeleteBuilderPart1(
private val db: SupportSQLiteDatabase,
private val tableName: String
) {
fun where(where: String, vararg whereArgs: Any): DeleteBuilderPart2 {
require(where.isNotBlank())
return DeleteBuilderPart2(db, tableName, where, SqlUtil.buildArgs(*whereArgs))
}
fun where(where: String, whereArgs: Array<String>): DeleteBuilderPart2 {
require(where.isNotBlank())
return DeleteBuilderPart2(db, tableName, where, whereArgs)
}
}
class DeleteBuilderPart2(
private val db: SupportSQLiteDatabase,
private val tableName: String,
private val where: String,
private val whereArgs: Array<String>
) {
fun run(): Int {
return db.delete(tableName, where, whereArgs)
}
}
class ExistsBuilderPart1(
private val db: SupportSQLiteDatabase,
private val tableName: String
) {
fun where(where: String, vararg whereArgs: Any): ExistsBuilderPart2 {
return ExistsBuilderPart2(db, tableName, where, SqlUtil.buildArgs(*whereArgs))
}
fun where(where: String, whereArgs: Array<String>): ExistsBuilderPart2 {
return ExistsBuilderPart2(db, tableName, where, whereArgs)
}
fun run(): Boolean {
return db.query("SELECT EXISTS(SELECT 1 FROM $tableName)", arrayOf()).use { cursor ->
cursor.moveToFirst() && cursor.getInt(0) == 1
}
}
}
class ExistsBuilderPart2(
private val db: SupportSQLiteDatabase,
private val tableName: String,
private val where: String,
private val whereArgs: Array<String>
) {
fun run(): Boolean {
return db.query("SELECT EXISTS(SELECT 1 FROM $tableName WHERE $where)", SqlUtil.buildArgs(*whereArgs)).use { cursor ->
cursor.moveToFirst() && cursor.getInt(0) == 1
}
}
}
class InsertBuilderPart1(
private val db: SupportSQLiteDatabase,
private val tableName: String
) {
fun values(values: ContentValues): InsertBuilderPart2 {
return InsertBuilderPart2(db, tableName, values)
}
fun values(vararg values: Pair<String, Any?>): InsertBuilderPart2 {
return InsertBuilderPart2(db, tableName, contentValuesOf(*values))
}
}
class InsertBuilderPart2(
private val db: SupportSQLiteDatabase,
private val tableName: String,
private val values: ContentValues
) {
fun run(conflictStrategy: Int = SQLiteDatabase.CONFLICT_IGNORE): Long {
return db.insert(tableName, conflictStrategy, values)
}
}
/**
* Helper function to massage passed-in arguments into a better form to give to the database.
*/
private fun Array<out Any?>.toArgs(): Array<Any?> {
return this
.map {
when (it) {
is DatabaseId -> it.serialize()
else -> it
}
}
.toTypedArray()
}
data class ForeignKeyConstraint(
val table: String,
val column: String,
val dependsOnTable: String,
val dependsOnColumn: String,
val onDelete: String
)
data class Index(
val name: String,
val table: String,
val columns: List<String>
)
data class CreateStatement(
val name: String,
val statement: String
)

View file

@ -0,0 +1,41 @@
package org.signal.core.util
import android.content.ContentValues
import android.database.Cursor
/**
* Generalized serializer for finer control
*/
interface BaseSerializer<Data, Input, Output> {
fun serialize(data: Data): Output
fun deserialize(input: Input): Data
}
/**
* Generic serialization interface for use with database and store operations.
*/
interface Serializer<T, R> : BaseSerializer<T, R, R>
/**
* Serializer specifically for working with SQLite
*/
interface DatabaseSerializer<Data> : BaseSerializer<Data, Cursor, ContentValues>
interface StringSerializer<T> : Serializer<T, String>
interface IntSerializer<T> : Serializer<T, Int>
interface LongSerializer<T> : Serializer<T, Long>
interface ByteSerializer<T> : Serializer<T, ByteArray>
object StringStringSerializer : StringSerializer<String?> {
override fun serialize(data: String?): String {
return data ?: ""
}
override fun deserialize(data: String): String {
return data
}
}

View file

@ -0,0 +1,238 @@
// Copyright 2010 Square, Inc.
// Modified 2020 Signal
package org.signal.core.util;
import android.hardware.Sensor;
import android.hardware.SensorEvent;
import android.hardware.SensorEventListener;
import android.hardware.SensorManager;
/**
* Detects phone shaking. If more than 75% of the samples taken in the past 0.5s are
* accelerating, the device is a) shaking, or b) free falling 1.84m (h =
* 1/2*g*t^2*3/4).
*
* @author Bob Lee (bob@squareup.com)
* @author Eric Burke (eric@squareup.com)
*/
public class ShakeDetector implements SensorEventListener {
private static final int SHAKE_THRESHOLD = 13;
/** Listens for shakes. */
public interface Listener {
/** Called on the main thread when the device is shaken. */
void onShakeDetected();
}
private final SampleQueue queue = new SampleQueue();
private final Listener listener;
private SensorManager sensorManager;
private Sensor accelerometer;
public ShakeDetector(Listener listener) {
this.listener = listener;
}
/**
* Starts listening for shakes on devices with appropriate hardware.
*
* @return true if the device supports shake detection.
*/
public boolean start(SensorManager sensorManager) {
if (accelerometer != null) {
return true;
}
accelerometer = sensorManager.getDefaultSensor(Sensor.TYPE_ACCELEROMETER);
if (accelerometer != null) {
this.sensorManager = sensorManager;
sensorManager.registerListener(this, accelerometer, SensorManager.SENSOR_DELAY_NORMAL);
}
return accelerometer != null;
}
/**
* Stops listening. Safe to call when already stopped. Ignored on devices without appropriate
* hardware.
*/
public void stop() {
if (accelerometer != null) {
queue.clear();
sensorManager.unregisterListener(this, accelerometer);
sensorManager = null;
accelerometer = null;
}
}
@Override
public void onSensorChanged(SensorEvent event) {
boolean accelerating = isAccelerating(event);
long timestamp = event.timestamp;
queue.add(timestamp, accelerating);
if (queue.isShaking()) {
queue.clear();
listener.onShakeDetected();
}
}
/** Returns true if the device is currently accelerating. */
private boolean isAccelerating(SensorEvent event) {
float ax = event.values[0];
float ay = event.values[1];
float az = event.values[2];
// Instead of comparing magnitude to ACCELERATION_THRESHOLD,
// compare their squares. This is equivalent and doesn't need the
// actual magnitude, which would be computed using (expensive) Math.sqrt().
final double magnitudeSquared = ax * ax + ay * ay + az * az;
return magnitudeSquared > SHAKE_THRESHOLD * SHAKE_THRESHOLD;
}
/** Queue of samples. Keeps a running average. */
static class SampleQueue {
/** Window size in ns. Used to compute the average. */
private static final long MAX_WINDOW_SIZE = 500000000; // 0.5s
private static final long MIN_WINDOW_SIZE = MAX_WINDOW_SIZE >> 1; // 0.25s
/**
* Ensure the queue size never falls below this size, even if the device
* fails to deliver this many events during the time window. The LG Ally
* is one such device.
*/
private static final int MIN_QUEUE_SIZE = 4;
private final SamplePool pool = new SamplePool();
private Sample oldest;
private Sample newest;
private int sampleCount;
private int acceleratingCount;
/**
* Adds a sample.
*
* @param timestamp in nanoseconds of sample
* @param accelerating true if > {@link #SHAKE_THRESHOLD}.
*/
void add(long timestamp, boolean accelerating) {
purge(timestamp - MAX_WINDOW_SIZE);
Sample added = pool.acquire();
added.timestamp = timestamp;
added.accelerating = accelerating;
added.next = null;
if (newest != null) {
newest.next = added;
}
newest = added;
if (oldest == null) {
oldest = added;
}
sampleCount++;
if (accelerating) {
acceleratingCount++;
}
}
/** Removes all samples from this queue. */
void clear() {
while (oldest != null) {
Sample removed = oldest;
oldest = removed.next;
pool.release(removed);
}
newest = null;
sampleCount = 0;
acceleratingCount = 0;
}
/** Purges samples with timestamps older than cutoff. */
void purge(long cutoff) {
while (sampleCount >= MIN_QUEUE_SIZE && oldest != null && cutoff - oldest.timestamp > 0) {
Sample removed = oldest;
if (removed.accelerating) {
acceleratingCount--;
}
sampleCount--;
oldest = removed.next;
if (oldest == null) {
newest = null;
}
pool.release(removed);
}
}
/**
* Returns true if we have enough samples and more than 3/4 of those samples
* are accelerating.
*/
boolean isShaking() {
return newest != null &&
oldest != null &&
newest.timestamp - oldest.timestamp >= MIN_WINDOW_SIZE &&
acceleratingCount >= (sampleCount >> 1) + (sampleCount >> 2);
}
}
/** An accelerometer sample. */
static class Sample {
/** Time sample was taken. */
long timestamp;
/** If acceleration > {@link #SHAKE_THRESHOLD}. */
boolean accelerating;
/** Next sample in the queue or pool. */
Sample next;
}
/** Pools samples. Avoids garbage collection. */
static class SamplePool {
private Sample head;
/** Acquires a sample from the pool. */
Sample acquire() {
Sample acquired = head;
if (acquired == null) {
acquired = new Sample();
} else {
head = acquired.next;
}
return acquired;
}
/** Returns a sample to the pool. */
void release(Sample sample) {
sample.next = head;
head = sample;
}
}
@Override
public void onAccuracyChanged(Sensor sensor, int accuracy) {
}
}

View file

@ -0,0 +1,518 @@
package org.signal.core.util
import android.content.ContentValues
import android.text.TextUtils
import androidx.annotation.VisibleForTesting
import androidx.sqlite.db.SupportSQLiteDatabase
import org.signal.core.util.logging.Log
import java.lang.Exception
import java.util.LinkedList
import java.util.Locale
import java.util.stream.Collectors
object SqlUtil {
private val TAG = Log.tag(SqlUtil::class.java)
/** The maximum number of arguments (i.e. question marks) allowed in a SQL statement. */
const val MAX_QUERY_ARGS = 999
@JvmField
val COUNT = arrayOf("COUNT(*)")
@JvmStatic
fun tableExists(db: SupportSQLiteDatabase, table: String): Boolean {
db.query("SELECT name FROM sqlite_master WHERE type=? AND name=?", arrayOf("table", table)).use { cursor ->
return cursor != null && cursor.moveToNext()
}
}
@JvmStatic
fun getAllTables(db: SupportSQLiteDatabase): List<String> {
val tables: MutableList<String> = LinkedList()
db.query("SELECT name FROM sqlite_master WHERE type=?", arrayOf("table")).use { cursor ->
while (cursor.moveToNext()) {
tables.add(cursor.getString(0))
}
}
return tables
}
/**
* Returns the total number of changes that have been made since the creation of this database connection.
*
* IMPORTANT: Due to how connection pooling is handled in the app, the only way to have this return useful numbers is to call it within a transaction.
*/
fun getTotalChanges(db: SupportSQLiteDatabase): Long {
return db.query("SELECT total_changes()", arrayOf()).readToSingleLong()
}
@JvmStatic
fun getAllTriggers(db: SupportSQLiteDatabase): List<String> {
val tables: MutableList<String> = LinkedList()
db.query("SELECT name FROM sqlite_master WHERE type=?", arrayOf("trigger")).use { cursor ->
while (cursor.moveToNext()) {
tables.add(cursor.getString(0))
}
}
return tables
}
@JvmStatic
fun getNextAutoIncrementId(db: SupportSQLiteDatabase, table: String): Long {
db.query("SELECT * FROM sqlite_sequence WHERE name = ?", arrayOf(table)).use { cursor ->
if (cursor.moveToFirst()) {
val current = cursor.requireLong("seq")
return current + 1
} else if (db.query("SELECT COUNT(*) FROM $table").readToSingleLong(defaultValue = 0) == 0L) {
Log.w(TAG, "No entries exist in $table. Returning 1.")
return 1
} else if (columnExists(db, table, "_id")) {
Log.w(TAG, "There are entries in $table, but we couldn't get the auto-incrementing id? Using the max _id in the table.")
val current = db.query("SELECT MAX(_id) FROM $table").readToSingleLong(defaultValue = 0)
return current + 1
} else {
Log.w(TAG, "No autoincrement _id, non-empty table, no _id column!")
throw IllegalArgumentException("Table must have an auto-incrementing primary key!")
}
}
}
/**
* Given a table, this will return a set of tables that it has a foreign key dependency on.
*/
@JvmStatic
fun getForeignKeyDependencies(db: SupportSQLiteDatabase, table: String): Set<String> {
return db.query("PRAGMA foreign_key_list($table)")
.readToSet { cursor ->
cursor.requireNonNullString("table")
}
}
/**
* Provides a list of all foreign key violations present.
* If a [targetTable] is specified, results will be limited to that table specifically.
* Otherwise, the check will be performed across all tables.
*/
@JvmStatic
@JvmOverloads
fun getForeignKeyViolations(db: SupportSQLiteDatabase, targetTable: String? = null): List<ForeignKeyViolation> {
val tableString = if (targetTable != null) "($targetTable)" else ""
return db.query("PRAGMA foreign_key_check$tableString").readToList { cursor ->
val table = cursor.requireNonNullString("table")
ForeignKeyViolation(
table = table,
violatingRowId = cursor.requireLongOrNull("rowid"),
dependsOnTable = cursor.requireNonNullString("parent"),
column = getForeignKeyViolationColumn(db, table, cursor.requireLong("fkid"))
)
}
}
/**
* For tables that have an autoincrementing primary key, this will reset the key to start back at 1.
* IMPORTANT: This is quite dangerous! Only do this if you're effectively resetting the entire database.
*/
@JvmStatic
fun resetAutoIncrementValue(db: SupportSQLiteDatabase, targetTable: String) {
db.execSQL("DELETE FROM sqlite_sequence WHERE name=?", arrayOf(targetTable))
}
@JvmStatic
fun isEmpty(db: SupportSQLiteDatabase, table: String): Boolean {
db.query("SELECT COUNT(*) FROM $table", arrayOf()).use { cursor ->
return if (cursor.moveToFirst()) {
cursor.getInt(0) == 0
} else {
true
}
}
}
@JvmStatic
fun columnExists(db: SupportSQLiteDatabase, table: String, column: String): Boolean {
db.query("PRAGMA table_info($table)", arrayOf()).use { cursor ->
val nameColumnIndex = cursor.getColumnIndexOrThrow("name")
while (cursor.moveToNext()) {
val name = cursor.getString(nameColumnIndex)
if (name == column) {
return true
}
}
}
return false
}
@JvmStatic
fun buildArgs(vararg objects: Any?): Array<String> {
return objects.map {
when (it) {
null -> throw NullPointerException("Cannot have null arg!")
is DatabaseId -> it.serialize()
else -> it.toString()
}
}.toTypedArray()
}
@JvmStatic
fun buildArgs(objects: Collection<Any?>): Array<String> {
return objects.map {
when (it) {
null -> throw NullPointerException("Cannot have null arg!")
is DatabaseId -> it.serialize()
else -> it.toString()
}
}.toTypedArray()
}
@JvmStatic
fun buildArgs(argument: Long): Array<String> {
return arrayOf(argument.toString())
}
/**
* Builds a case-insensitive GLOB pattern for fuzzy text queries. Works with all unicode
* characters.
*
* Ex:
* cat -> [cC][aA][tT]
*/
@JvmStatic
fun buildCaseInsensitiveGlobPattern(query: String): String {
if (TextUtils.isEmpty(query)) {
return "*"
}
val pattern = StringBuilder()
var i = 0
val len = query.codePointCount(0, query.length)
while (i < len) {
val point = StringUtil.codePointToString(query.codePointAt(i))
pattern.append("[")
pattern.append(point.lowercase(Locale.getDefault()))
pattern.append(point.uppercase(Locale.getDefault()))
pattern.append(getAccentuatedCharRegex(point.lowercase(Locale.getDefault())))
pattern.append("]")
i++
}
return "*$pattern*"
}
private fun getAccentuatedCharRegex(query: String): String {
return when (query) {
"a" -> "À-Åà-åĀ-ąǍǎǞ-ǡǺ-ǻȀ-ȃȦȧȺɐ-ɒḀḁẚẠ-ặ"
"b" -> "ßƀ-ƅɃɓḂ-ḇ"
"c" -> "çÇĆ-čƆ-ƈȻȼɔḈḉ"
"d" -> "ÐðĎ-đƉ-ƍȡɖɗḊ-ḓ"
"e" -> "È-Ëè-ëĒ-ěƎ-ƐǝȄ-ȇȨȩɆɇɘ-ɞḔ-ḝẸ-ệ"
"f" -> "ƑƒḞḟ"
"g" -> "Ĝ-ģƓǤ-ǧǴǵḠḡ"
"h" -> "Ĥ-ħƕǶȞȟḢ-ḫẖ"
"i" -> "Ì-Ïì-ïĨ-ıƖƗǏǐȈ-ȋɨɪḬ-ḯỈ-ị"
"j" -> "ĴĵǰȷɈɉɟ"
"k" -> "Ķ-ĸƘƙǨǩḰ-ḵ"
"l" -> "Ĺ-łƚȴȽɫ-ɭḶ-ḽ"
"m" -> "Ɯɯ-ɱḾ-ṃ"
"n" -> "ÑñŃ-ŋƝƞǸǹȠȵɲ-ɴṄ-ṋ"
"o" -> "Ò-ÖØò-öøŌ-őƟ-ơǑǒǪ-ǭǾǿȌ-ȏȪ-ȱṌ-ṓỌ-ợ"
"p" -> "ƤƥṔ-ṗ"
"q" -> ""
"r" -> "Ŕ-řƦȐ-ȓɌɍṘ-ṟ"
"s" -> "Ś-šƧƨȘșȿṠ-ṩ"
"t" -> "Ţ-ŧƫ-ƮȚțȾṪ-ṱẗ"
"u" -> "Ù-Üù-üŨ-ųƯ-ƱǓ-ǜȔ-ȗɄṲ-ṻỤ-ự"
"v" -> "ƲɅṼ-ṿ"
"w" -> "ŴŵẀ-ẉẘ"
"x" -> "Ẋ-ẍ"
"y" -> "ÝýÿŶ-ŸƔƳƴȲȳɎɏẎẏỲ-ỹỾỿẙ"
"z" -> "Ź-žƵƶɀẐ-ẕ"
"α" -> "\u0386\u0391\u03AC\u03B1\u1F00-\u1F0F\u1F70\u1F71\u1F80-\u1F8F\u1FB0-\u1FB4\u1FB6-\u1FBC"
"ε" -> "\u0388\u0395\u03AD\u03B5\u1F10-\u1F15\u1F18-\u1F1D\u1F72\u1F73\u1FC8\u1FC9"
"η" -> "\u0389\u0397\u03AE\u03B7\u1F20-\u1F2F\u1F74\u1F75\u1F90-\u1F9F\u1F20-\u1F2F\u1F74\u1F75\u1F90-\u1F9F\u1fc2\u1fc3\u1fc4\u1fc6\u1FC7\u1FCA\u1FCB\u1FCC"
"ι" -> "\u038A\u0390\u0399\u03AA\u03AF\u03B9\u03CA\u1F30-\u1F3F\u1F76\u1F77\u1FD0-\u1FD3\u1FD6-\u1FDB"
"ο" -> "\u038C\u039F\u03BF\u03CC\u1F40-\u1F45\u1F48-\u1F4D\u1F78\u1F79\u1FF8\u1FF9"
"σ" -> "\u03A3\u03C2\u03C3"
"ς" -> "\u03A3\u03C2\u03C3"
"υ" -> "\u038E\u03A5\u03AB\u03C5\u03CB\u03CD\u1F50-\u1F57\u1F59\u1F5B\u1F5D\u1F5F\u1F7A\u1F7B\u1FE0-\u1FE3\u1FE6-\u1FEB"
"ω" -> "\u038F\u03A9\u03C9\u03CE\u1F60-\u1F6F\u1F7C\u1F7D\u1FA0-\u1FAF\u1FF2-\u1FF4\u1FF6\u1FF7\u1FFA-\u1FFC"
else -> ""
}
}
/**
* Returns an updated query and args pairing that will only update rows that would *actually*
* change. In other words, if [SupportSQLiteDatabase.update]
* returns > 0, then you know something *actually* changed.
*/
@JvmStatic
fun buildTrueUpdateQuery(
selection: String,
args: Array<String>,
contentValues: ContentValues
): Query {
val qualifier = StringBuilder()
val valueSet = contentValues.valueSet()
val fullArgs: MutableList<String> = ArrayList(args.size + valueSet.size)
fullArgs.addAll(args)
var i = 0
for ((key, value) in valueSet) {
if (value != null) {
if (value is ByteArray) {
qualifier.append("hex(").append(key).append(") != ? OR ").append(key).append(" IS NULL")
fullArgs.add(Hex.toStringCondensed(value).uppercase(Locale.US))
} else {
qualifier.append(key).append(" != ? OR ").append(key).append(" IS NULL")
fullArgs.add(value.toString())
}
} else {
qualifier.append(key).append(" NOT NULL")
}
if (i != valueSet.size - 1) {
qualifier.append(" OR ")
}
i++
}
return Query("($selection) AND ($qualifier)", fullArgs.toTypedArray())
}
/**
* A convenient way of making queries in the form: WHERE [column] IN (?, ?, ..., ?)
* Handles breaking it
*/
@JvmOverloads
@JvmStatic
fun buildCollectionQuery(
column: String,
values: Collection<Any?>,
prefix: String = "",
maxSize: Int = MAX_QUERY_ARGS,
collectionOperator: CollectionOperator = CollectionOperator.IN
): List<Query> {
return if (values.isEmpty()) {
emptyList()
} else {
values
.chunked(maxSize)
.map { batch -> buildSingleCollectionQuery(column, batch, prefix, collectionOperator) }
}
}
/**
* A convenient way of making queries that are _equivalent_ to `WHERE [column] IN (?, ?, ..., ?)`
* Under the hood, it uses JSON1 functions which can both be surprisingly faster than normal (?, ?, ?) lists, as well as removes the [MAX_QUERY_ARGS] limit.
* This means chunking isn't necessary for any practical collection length.
*/
@JvmStatic
fun buildFastCollectionQuery(
column: String,
values: Collection<Any?>
): Query {
require(!values.isEmpty()) { "Must have values!" }
return Query("$column IN (SELECT e.value FROM json_each(?) e)", arrayOf(jsonEncode(buildArgs(values))))
}
/**
* A convenient way of making queries in the form: WHERE [column] IN (?, ?, ..., ?)
*
* Important: Should only be used if you know the number of values is < 1000. Otherwise you risk creating a SQL statement this is too large.
* Prefer [buildCollectionQuery] when possible.
*/
@JvmOverloads
@JvmStatic
fun buildSingleCollectionQuery(
column: String,
values: Collection<Any?>,
prefix: String = "",
collectionOperator: CollectionOperator = CollectionOperator.IN
): Query {
require(!values.isEmpty()) { "Must have values!" }
val query = StringBuilder()
val args = arrayOfNulls<Any>(values.size)
var i = 0
for (value in values) {
query.append("?")
args[i] = value
if (i != values.size - 1) {
query.append(", ")
}
i++
}
return Query("$prefix $column ${collectionOperator.sql} ($query)".trim(), buildArgs(*args))
}
@JvmStatic
fun buildCustomCollectionQuery(query: String, argList: List<Array<String>>): List<Query> {
return buildCustomCollectionQuery(query, argList, MAX_QUERY_ARGS)
}
@JvmStatic
@VisibleForTesting
fun buildCustomCollectionQuery(query: String, argList: List<Array<String>>, maxQueryArgs: Int): List<Query> {
val batchSize: Int = maxQueryArgs / argList[0].size
return ListUtil.chunk(argList, batchSize)
.stream()
.map { argBatch -> buildSingleCustomCollectionQuery(query, argBatch) }
.collect(Collectors.toList())
}
private fun buildSingleCustomCollectionQuery(query: String, argList: List<Array<String>>): Query {
val outputQuery = StringBuilder()
val outputArgs: MutableList<String> = mutableListOf()
var i = 0
val len = argList.size
while (i < len) {
outputQuery.append("(").append(query).append(")")
if (i < len - 1) {
outputQuery.append(" OR ")
}
val args = argList[i]
for (arg in args) {
outputArgs += arg
}
i++
}
return Query(outputQuery.toString(), outputArgs.toTypedArray())
}
@JvmStatic
fun buildQuery(where: String, vararg args: Any): Query {
return Query(where, buildArgs(*args))
}
@JvmStatic
fun appendArg(args: Array<String>, addition: String): Array<String> {
return args.toMutableList().apply {
add(addition)
}.toTypedArray()
}
@JvmStatic
fun appendArgs(args: Array<String>, vararg objects: Any?): Array<String> {
return args + buildArgs(objects)
}
@JvmStatic
fun buildBulkInsert(tableName: String, columns: Array<String>, contentValues: List<ContentValues>, onConflict: String? = null): List<Query> {
return buildBulkInsert(tableName, columns, contentValues, MAX_QUERY_ARGS)
}
@JvmStatic
@VisibleForTesting
fun buildBulkInsert(tableName: String, columns: Array<String>, contentValues: List<ContentValues>, maxQueryArgs: Int, onConflict: String? = null): List<Query> {
val batchSize = maxQueryArgs / columns.size
return contentValues
.chunked(batchSize)
.map { batch: List<ContentValues> -> buildSingleBulkInsert(tableName, columns, batch) }
.toList()
}
fun buildSingleBulkInsert(tableName: String, columns: Array<String>, contentValues: List<ContentValues>, onConflict: String? = null): Query {
val conflictString = onConflict?.let { " OR $onConflict" } ?: ""
val builder = StringBuilder()
builder.append("INSERT$conflictString INTO ").append(tableName).append(" (")
val columnString = columns.joinToString(separator = ", ")
builder.append(columnString)
builder.append(") VALUES ")
val placeholders = contentValues
.map { values ->
columns
.map { column ->
if (values[column] != null) {
if (values[column] is ByteArray) {
"X'${Hex.toStringCondensed(values[column] as ByteArray).uppercase()}'"
} else {
"?"
}
} else {
"null"
}
}
.joinToString(separator = ", ", prefix = "(", postfix = ")")
}
.joinToString(separator = ", ")
builder.append(placeholders)
val query = builder.toString()
val args: MutableList<String> = mutableListOf()
for (values in contentValues) {
for (column in columns) {
val value = values[column]
if (value != null && value !is ByteArray) {
args += value.toString()
}
}
}
return Query(query, args.toTypedArray())
}
/** Helper that gets the specific column for a foreign key violation */
private fun getForeignKeyViolationColumn(db: SupportSQLiteDatabase, table: String, id: Long): String? {
try {
db.query("PRAGMA foreign_key_list($table)").forEach { cursor ->
if (cursor.requireLong("id") == id) {
return cursor.requireString("from")
}
}
} catch (e: Exception) {
Log.w(TAG, "Failed to find violation details for id: $id")
}
return null
}
/** Simple encoding of a string array as a json array */
private fun jsonEncode(strings: Array<String>): String {
return strings.joinToString(prefix = "[", postfix = "]", separator = ",") { "\"$it\"" }
}
class Query(val where: String, val whereArgs: Array<String>) {
infix fun and(other: Query): Query {
return if (where.isNotEmpty() && other.where.isNotEmpty()) {
Query("($where) AND (${other.where})", whereArgs + other.whereArgs)
} else if (where.isNotEmpty()) {
this
} else {
other
}
}
}
data class ForeignKeyViolation(
/** The table that declared the REFERENCES clause. */
val table: String,
/** The rowId of the message in [table] that violates the constraint. Will not be present if the table has now rowId. */
val violatingRowId: Long?,
/** The table that [table] has a dependency on. */
val dependsOnTable: String,
/** The column from [table] that has the constraint. A separate query needs to be made to get this, so it's best-effor. */
val column: String?
)
enum class CollectionOperator(val sql: String) {
IN("IN"),
NOT_IN("NOT IN")
}
}

View file

@ -0,0 +1,259 @@
package org.signal.core.util
import android.text.SpannableStringBuilder
import okio.ByteString
import okio.ByteString.Companion.toByteString
import okio.utf8Size
import java.io.ByteArrayOutputStream
import java.io.IOException
import java.nio.charset.StandardCharsets
object StringUtil {
private val WHITESPACE: Set<Char> = setOf(
'\u200E', // left-to-right mark
'\u200F', // right-to-left mark
'\u2007', // figure space
'\u200B', // zero-width space
'\u2800' // braille blank
)
/**
* Trims a name string to fit into the byte length requirement.
*
*
* This method treats a surrogate pair and a grapheme cluster a single character
* See examples in tests defined in StringUtilText_trimToFit.
*/
@JvmStatic
fun trimToFit(name: String?, maxByteLength: Int): String {
if (name.isNullOrEmpty()) {
return ""
}
if (name.utf8Size() <= maxByteLength) {
return name
}
try {
ByteArrayOutputStream().use { stream ->
for (graphemeCharacter in CharacterIterable(name)) {
val bytes = graphemeCharacter.toByteArray(StandardCharsets.UTF_8)
if (stream.size() + bytes.size <= maxByteLength) {
stream.write(bytes)
} else {
break
}
}
return stream.toString()
}
} catch (e: IOException) {
throw AssertionError(e)
}
}
/**
* @return A charsequence with no leading or trailing whitespace. Only creates a new charsequence
* if it has to.
*/
@JvmStatic
fun trim(charSequence: CharSequence): CharSequence {
if (charSequence.isEmpty()) {
return charSequence
}
var start = 0
var end = charSequence.length - 1
while (start < charSequence.length && Character.isWhitespace(charSequence[start])) {
start++
}
while (end >= 0 && end > start && Character.isWhitespace(charSequence[end])) {
end--
}
return if (start > 0 || end < charSequence.length - 1) {
charSequence.subSequence(start, end + 1)
} else {
charSequence
}
}
/**
* @return True if the string is empty, or if it contains nothing but whitespace characters.
* Accounts for various unicode whitespace characters.
*/
@JvmStatic
fun isVisuallyEmpty(value: String?): Boolean {
if (value.isNullOrEmpty()) {
return true
}
return indexOfFirstNonEmptyChar(value) == -1
}
/**
* @return String without any leading or trailing whitespace.
* Accounts for various unicode whitespace characters.
*/
@JvmStatic
fun trimToVisualBounds(value: String): String {
val start = indexOfFirstNonEmptyChar(value)
if (start == -1) {
return ""
}
val end = indexOfLastNonEmptyChar(value)
return value.substring(start, end + 1)
}
private fun indexOfFirstNonEmptyChar(value: String): Int {
val length = value.length
for (i in 0 until length) {
if (!isVisuallyEmpty(value[i])) {
return i
}
}
return -1
}
private fun indexOfLastNonEmptyChar(value: String): Int {
for (i in value.length - 1 downTo 0) {
if (!isVisuallyEmpty(value[i])) {
return i
}
}
return -1
}
/**
* @return True if the character is invisible or whitespace. Accounts for various unicode
* whitespace characters.
*/
fun isVisuallyEmpty(c: Char): Boolean {
return Character.isWhitespace(c) || WHITESPACE.contains(c)
}
/**
* @return A string representation of the provided unicode code point.
*/
fun codePointToString(codePoint: Int): String {
return String(Character.toChars(codePoint))
}
/**
* @return True if the text is null or has a length of 0, otherwise false.
*/
@JvmStatic
fun isEmpty(text: String?): Boolean {
return text.isNullOrEmpty()
}
/**
* Trims a [CharSequence] of starting and trailing whitespace. Behavior matches
* [String.trim] to preserve expectations around results.
*/
@JvmStatic
fun trimSequence(text: CharSequence): CharSequence {
var length = text.length
var startIndex = 0
while ((startIndex < length) && (text[startIndex] <= ' ')) {
startIndex++
}
while ((startIndex < length) && (text[length - 1] <= ' ')) {
length--
}
return if ((startIndex > 0 || length < text.length)) text.subSequence(startIndex, length) else text
}
/**
* If the {@param text} exceeds the {@param maxChars} it is trimmed in the middle so that the result is exactly {@param maxChars} long including an added
* ellipsis character.
*
*
* Otherwise the string is returned untouched.
*
*
* When {@param maxChars} is even, one more character is kept from the end of the string than the start.
*/
@JvmStatic
fun abbreviateInMiddle(text: CharSequence?, maxChars: Int): CharSequence? {
if (text == null || text.length <= maxChars) {
return text
}
val start = (maxChars - 1) / 2
val end = (maxChars - 1) - start
return text.subSequence(0, start).toString() + "" + text.subSequence(text.length - end, text.length)
}
/**
* @return The number of graphemes in the provided string.
*/
@JvmStatic
fun getGraphemeCount(text: CharSequence): Int {
val iterator = BreakIteratorCompat.getInstance()
iterator.setText(text)
return iterator.countBreaks()
}
@JvmStatic
fun replace(text: CharSequence, toReplace: Char, replacement: String?): CharSequence {
var updatedText: SpannableStringBuilder? = null
for (i in text.length - 1 downTo 0) {
if (text[i] == toReplace) {
if (updatedText == null) {
updatedText = SpannableStringBuilder.valueOf(text)
}
updatedText!!.replace(i, i + 1, replacement)
}
}
return updatedText ?: text
}
@JvmStatic
fun startsWith(text: CharSequence, substring: CharSequence): Boolean {
if (substring.length > text.length) {
return false
}
for (i in substring.indices) {
if (text[i] != substring[i]) {
return false
}
}
return true
}
@JvmStatic
fun endsWith(text: CharSequence, substring: CharSequence): Boolean {
if (substring.length > text.length) {
return false
}
var textIndex = text.length - 1
var substringIndex = substring.length - 1
while (substringIndex >= 0) {
if (text[textIndex] != substring[substringIndex]) {
return false
}
substringIndex--
textIndex--
}
return true
}
fun String?.toByteString(): ByteString? {
return this?.toByteArray()?.toByteString()
}
}

View file

@ -0,0 +1,47 @@
package org.signal.core.util
import androidx.sqlite.db.SupportSQLiteProgram
import androidx.sqlite.db.SupportSQLiteQuery
fun SupportSQLiteQuery.toAndroidQuery(): SqlUtil.Query {
val program = CapturingSqliteProgram(this.argCount)
this.bindTo(program)
return SqlUtil.Query(this.sql, program.args())
}
private class CapturingSqliteProgram(count: Int) : SupportSQLiteProgram {
private val args: Array<String?> = arrayOfNulls(count)
fun args(): Array<String> {
return args.filterNotNull().toTypedArray()
}
override fun close() {
}
override fun bindNull(index: Int) {
throw UnsupportedOperationException()
}
override fun bindLong(index: Int, value: Long) {
args[index - 1] = value.toString()
}
override fun bindDouble(index: Int, value: Double) {
args[index - 1] = value.toString()
}
override fun bindString(index: Int, value: String) {
args[index - 1] = value
}
override fun bindBlob(index: Int, value: ByteArray) {
throw UnsupportedOperationException()
}
override fun clearBindings() {
for (i in args.indices) {
args[i] = null
}
}
}

View file

@ -0,0 +1,115 @@
package org.signal.core.util;
import android.os.Handler;
import android.os.Looper;
import android.os.Process;
import androidx.annotation.NonNull;
import androidx.annotation.VisibleForTesting;
import java.util.concurrent.CountDownLatch;
/**
* Thread related utility functions.
*/
public final class ThreadUtil {
/**
* Default background thread priority.
*/
public static final int PRIORITY_BACKGROUND_THREAD = Process.THREAD_PRIORITY_BACKGROUND;
/**
* Important background thread priority. This is slightly lower priority than the UI thread. Use for critical work that should run as fast as
* possible, but shouldn't block the UI (e.g. message sends)
*/
public static final int PRIORITY_IMPORTANT_BACKGROUND_THREAD = Process.THREAD_PRIORITY_DEFAULT + Process.THREAD_PRIORITY_LESS_FAVORABLE;
/**
* As important as the UI thread. Use for absolutely critical UI blocking tasks/threads. For example fetching data for display in a recyclerview, or
* anything that will block UI.
*/
public static final int PRIORITY_UI_BLOCKING_THREAD = Process.THREAD_PRIORITY_DEFAULT;
private static volatile Handler handler;
@VisibleForTesting
public static volatile boolean enforceAssertions = true;
private ThreadUtil() {}
private static Handler getHandler() {
if (handler == null) {
synchronized (ThreadUtil.class) {
if (handler == null) {
handler = new Handler(Looper.getMainLooper());
}
}
}
return handler;
}
public static boolean isMainThread() {
return Looper.myLooper() == Looper.getMainLooper();
}
public static void assertMainThread() {
if (!isMainThread() && enforceAssertions) {
throw new AssertionError("Must run on main thread.");
}
}
public static void assertNotMainThread() {
if (isMainThread() && enforceAssertions) {
throw new AssertionError("Cannot run on main thread.");
}
}
public static void postToMain(final @NonNull Runnable runnable) {
getHandler().post(runnable);
}
public static void runOnMain(final @NonNull Runnable runnable) {
if (isMainThread()) runnable.run();
else getHandler().post(runnable);
}
public static void runOnMainDelayed(final @NonNull Runnable runnable, long delayMillis) {
getHandler().postDelayed(runnable, delayMillis);
}
public static void cancelRunnableOnMain(@NonNull Runnable runnable) {
getHandler().removeCallbacks(runnable);
}
public static void runOnMainSync(final @NonNull Runnable runnable) {
if (isMainThread()) {
runnable.run();
} else {
final CountDownLatch sync = new CountDownLatch(1);
runOnMain(() -> {
try {
runnable.run();
} finally {
sync.countDown();
}
});
try {
sync.await();
} catch (InterruptedException ie) {
throw new AssertionError(ie);
}
}
}
public static void sleep(long millis) {
try {
Thread.sleep(millis);
} catch (InterruptedException e) {
throw new AssertionError(e);
}
}
public static void interruptableSleep(long millis) {
try {
Thread.sleep(millis);
} catch (InterruptedException ignored) { }
}
}

View file

@ -0,0 +1,23 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util
import android.content.res.ColorStateList
import android.graphics.PorterDuff
import android.graphics.PorterDuffColorFilter
import androidx.annotation.ColorInt
import androidx.appcompat.widget.Toolbar
import androidx.core.view.MenuItemCompat
import androidx.core.view.forEach
fun Toolbar.setActionItemTint(@ColorInt tint: Int) {
menu.forEach {
MenuItemCompat.setIconTintList(it, ColorStateList.valueOf(tint))
}
navigationIcon?.colorFilter = PorterDuffColorFilter(tint, PorterDuff.Mode.SRC_ATOP)
overflowIcon?.colorFilter = PorterDuffColorFilter(tint, PorterDuff.Mode.SRC_ATOP)
}

View file

@ -0,0 +1,79 @@
package org.signal.core.util;
import android.content.Context;
import android.content.res.Configuration;
import android.content.res.Resources;
import android.os.Build;
import androidx.annotation.NonNull;
import androidx.annotation.StringRes;
import java.util.Locale;
/**
* Allows you to detect if a string resource is readable by the user according to their language settings.
*/
public final class TranslationDetection {
private final Resources resourcesLocal;
private final Resources resourcesEn;
private final Configuration configurationLocal;
/**
* @param context Do not pass Application context, as this may not represent the users selected in-app locale.
*/
public TranslationDetection(@NonNull Context context) {
this.resourcesLocal = context.getResources();
this.configurationLocal = resourcesLocal.getConfiguration();
this.resourcesEn = ResourceUtil.getEnglishResources(context);
}
/**
* @param context Can be Application context.
* @param usersLocale Locale of user.
*/
public TranslationDetection(@NonNull Context context, @NonNull Locale usersLocale) {
this.resourcesLocal = ResourceUtil.getResources(context.getApplicationContext(), usersLocale);
this.configurationLocal = resourcesLocal.getConfiguration();
this.resourcesEn = ResourceUtil.getEnglishResources(context);
}
/**
* Returns true if any of these are true:
* - The current locale is English.
* - In a multi-locale capable device, the device supports any English locale in any position.
* - The text for the current locale does not Equal the English.
*/
public boolean textExistsInUsersLanguage(@StringRes int resId) {
if (configSupportsEnglish()) {
return true;
}
String stringEn = resourcesEn.getString(resId);
String stringLocal = resourcesLocal.getString(resId);
return !stringEn.equals(stringLocal);
}
public boolean textExistsInUsersLanguage(@StringRes int... resIds) {
for (int resId : resIds) {
if (!textExistsInUsersLanguage(resId)) {
return false;
}
}
return true;
}
protected boolean configSupportsEnglish() {
if (configurationLocal.locale.getLanguage().equals("en")) {
return true;
}
if (Build.VERSION.SDK_INT >= 24) {
Locale firstMatch = configurationLocal.getLocales().getFirstMatch(new String[]{"en"});
return firstMatch != null && firstMatch.getLanguage().equals("en");
}
return false;
}
}

View file

@ -0,0 +1,14 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util.androidx
import androidx.documentfile.provider.DocumentFile
/**
* Information about a file within the storage. Useful because default [DocumentFile] implementations
* re-query info on each access.
*/
data class DocumentFileInfo(val documentFile: DocumentFile, val name: String, val size: Long)

View file

@ -0,0 +1,221 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util.androidx
import android.content.Context
import android.provider.DocumentsContract
import androidx.documentfile.provider.DocumentFile
import androidx.documentfile.provider.isTreeDocumentFile
import org.signal.core.util.ThreadUtil
import org.signal.core.util.logging.Log
import org.signal.core.util.readToList
import org.signal.core.util.requireLong
import org.signal.core.util.requireNonNullString
import org.signal.core.util.requireString
import java.io.InputStream
import java.io.OutputStream
import kotlin.time.Duration.Companion.seconds
/**
* Collection of helper and optimizized operations for working with [DocumentFile]s.
*/
object DocumentFileUtil {
private val TAG = Log.tag(DocumentFileUtil::class)
private val FILE_PROJECTION = arrayOf(DocumentsContract.Document.COLUMN_DOCUMENT_ID, DocumentsContract.Document.COLUMN_DISPLAY_NAME, DocumentsContract.Document.COLUMN_SIZE)
private const val FILE_SELECTION = "${DocumentsContract.Document.COLUMN_DISPLAY_NAME} = ?"
private const val LIST_FILES_SELECTION = "${DocumentsContract.Document.COLUMN_MIME_TYPE} != ?"
private val LIST_FILES_SELECTION_ARGS = arrayOf(DocumentsContract.Document.MIME_TYPE_DIR)
private const val MAX_STORAGE_ATTEMPTS: Int = 5
private val WAIT_FOR_SCOPED_STORAGE: LongArray = longArrayOf(0, 2.seconds.inWholeMilliseconds, 10.seconds.inWholeMilliseconds, 20.seconds.inWholeMilliseconds, 30.seconds.inWholeMilliseconds)
/** Returns true if the directory represented by the [DocumentFile] has a child with [name]. */
fun DocumentFile.hasFile(name: String): Boolean {
return findFile(name) != null
}
/** Returns the [DocumentFile] for a newly created binary file or null if unable or it already exists */
fun DocumentFile.newFile(name: String): DocumentFile? {
return if (hasFile(name)) {
Log.w(TAG, "Attempt to create new file ($name) but it already exists")
null
} else {
createFile("application/octet-stream", name)
}
}
/** Returns a [DocumentFile] for directory by [name], creating it if it doesn't already exist */
fun DocumentFile.mkdirp(name: String): DocumentFile? {
return findFile(name) ?: createDirectory(name)
}
/** Open an [OutputStream] to the file represented by the [DocumentFile] */
fun DocumentFile.outputStream(context: Context): OutputStream? {
return context.contentResolver.openOutputStream(uri)
}
/** Open an [InputStream] to the file represented by the [DocumentFile] */
@JvmStatic
fun DocumentFile.inputStream(context: Context): InputStream? {
return context.contentResolver.openInputStream(uri)
}
/**
* Will attempt to find the named [file] in the [root] directory and delete it if found.
*
* @return true if found and deleted, false if the file couldn't be deleted, and null if not found
*/
fun DocumentFile.delete(context: Context, file: String): Boolean? {
return findFile(context, file)?.documentFile?.delete()
}
/**
* Will attempt to find the name [fileName] in the [root] directory and return useful information if found using
* a single [Context.getContentResolver] query.
*
* Recommend using this over [DocumentFile.findFile] to prevent excess queries for all files and names.
*
* If direct queries fail to find the file, will fallback to using [DocumentFile.findFile].
*/
fun DocumentFile.findFile(context: Context, fileName: String): DocumentFileInfo? {
val child: List<DocumentFileInfo> = if (isTreeDocumentFile()) {
val childrenUri = DocumentsContract.buildChildDocumentsUriUsingTree(uri, DocumentsContract.getDocumentId(uri))
try {
context
.contentResolver
.query(childrenUri, FILE_PROJECTION, FILE_SELECTION, arrayOf(fileName), null)
?.readToList(predicate = { it.name == fileName }) { cursor ->
val uri = DocumentsContract.buildDocumentUriUsingTree(uri, cursor.requireString(DocumentsContract.Document.COLUMN_DOCUMENT_ID))
val displayName = cursor.requireNonNullString(DocumentsContract.Document.COLUMN_DISPLAY_NAME)
val length = cursor.requireLong(DocumentsContract.Document.COLUMN_SIZE)
DocumentFileInfo(DocumentFile.fromSingleUri(context, uri)!!, displayName, length)
} ?: emptyList()
} catch (e: Exception) {
Log.d(TAG, "Unable to find file directly on ${javaClass.simpleName}, falling back to OS", e)
emptyList()
}
} else {
emptyList()
}
return if (child.size == 1) {
child[0]
} else {
Log.w(TAG, "Did not find single file, found (${child.size}), falling back to OS")
this.findFile(fileName)?.let { DocumentFileInfo(it, it.name!!, it.length()) }
}
}
/**
* List file names and sizes in the [DocumentFile] by directly querying the content resolver ourselves. The system
* implementation makes a separate query for each name and length method call and gets expensive over 1000's of files.
*
* Will fallback to the provided document file's implementation of [DocumentFile.listFiles] if unable to do it directly.
*/
fun DocumentFile.listFiles(context: Context): List<DocumentFileInfo> {
if (isTreeDocumentFile()) {
val childrenUri = DocumentsContract.buildChildDocumentsUriUsingTree(uri, DocumentsContract.getDocumentId(uri))
try {
val results = context
.contentResolver
.query(childrenUri, FILE_PROJECTION, LIST_FILES_SELECTION, LIST_FILES_SELECTION_ARGS, null)
?.use { cursor ->
val results = ArrayList<DocumentFileInfo>(cursor.count)
while (cursor.moveToNext()) {
val uri = DocumentsContract.buildDocumentUriUsingTree(uri, cursor.requireString(DocumentsContract.Document.COLUMN_DOCUMENT_ID))
val displayName = cursor.requireString(DocumentsContract.Document.COLUMN_DISPLAY_NAME)
val length = cursor.requireLong(DocumentsContract.Document.COLUMN_SIZE)
if (displayName != null) {
results.add(DocumentFileInfo(DocumentFile.fromSingleUri(context, uri)!!, displayName, length))
}
}
results
}
if (results != null) {
return results
} else {
Log.w(TAG, "Content provider returned null for query on ${javaClass.simpleName}, falling back to OS")
}
} catch (e: Exception) {
Log.d(TAG, "Unable to query files directly on ${javaClass.simpleName}, falling back to OS", e)
}
}
return listFiles()
.asSequence()
.filter { it.isFile }
.mapNotNull { file -> file.name?.let { DocumentFileInfo(file, it, file.length()) } }
.toList()
}
/**
* System implementation swallows the exception and we are having problems with the rename. This inlines the
* same call and logs the exception. Note this implementation does not update the passed in document file like
* the system implementation. Do not use the provided document file after calling this method.
*
* @return true if rename successful
*/
@JvmStatic
fun DocumentFile.renameTo(context: Context, displayName: String): Boolean {
if (isTreeDocumentFile()) {
Log.d(TAG, "Renaming document directly")
try {
val result = DocumentsContract.renameDocument(context.contentResolver, uri, displayName)
return result != null
} catch (e: Exception) {
Log.w(TAG, "Unable to rename document file, falling back to OS", e)
return renameTo(displayName)
}
} else {
return renameTo(displayName)
}
}
/**
* Historically, we've seen issues with [DocumentFile] operations not working on the first try. This
* retry loop will retry those operations with a varying backoff in attempt to make them work.
*/
@JvmStatic
fun <T> retryDocumentFileOperation(operation: DocumentFileOperation<T>): OperationResult {
var attempts = 0
var operationResult = operation.operation(attempts, MAX_STORAGE_ATTEMPTS)
while (attempts < MAX_STORAGE_ATTEMPTS && !operationResult.isSuccess()) {
ThreadUtil.sleep(WAIT_FOR_SCOPED_STORAGE[attempts])
attempts++
operationResult = operation.operation(attempts, MAX_STORAGE_ATTEMPTS)
}
return operationResult
}
/** Operation to perform in a retry loop via [retryDocumentFileOperation] that could fail based on timing */
fun interface DocumentFileOperation<T> {
fun operation(attempt: Int, maxAttempts: Int): OperationResult
}
/** Result of a single operation in a retry loop via [retryDocumentFileOperation] */
sealed interface OperationResult {
fun isSuccess(): Boolean {
return this is Success
}
/** The operation completed successful */
data class Success(val value: Boolean) : OperationResult
/** Retry the operation */
data object Retry : OperationResult
}
}

View file

@ -0,0 +1,43 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util.billing
import android.app.Activity
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.emptyFlow
/**
* Variant interface for the BillingApi.
*/
interface BillingApi {
/**
* Listenable stream of billing purchase results. It's up to the user
* to call queryPurchases after subscription.
*/
fun getBillingPurchaseResults(): Flow<BillingPurchaseResult> = emptyFlow()
suspend fun getApiAvailability(): BillingResponseCode = BillingResponseCode.FEATURE_NOT_SUPPORTED
/**
* Queries the Billing API for product pricing. This value should be cached by
* the implementor for 24 hours.
*/
suspend fun queryProduct(): BillingProduct? = null
/**
* Queries the user's current purchases. This enqueues a check and will
* propagate it to the normal callbacks in the api.
*/
suspend fun queryPurchases(): BillingPurchaseResult = BillingPurchaseResult.None
suspend fun launchBillingFlow(activity: Activity) = Unit
/**
* Empty implementation, to be used when play services are available but
* GooglePlayBillingApi is not available.
*/
object Empty : BillingApi
}

View file

@ -0,0 +1,28 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util.billing
import android.content.Context
/**
* Provides a dependency model by which the billing api can request different resources.
*/
interface BillingDependencies {
/**
* Application context
*/
val context: Context
/**
* Get the product id from the donations configuration object.
*/
suspend fun getProductId(): String
/**
* Get the base plan id from the donations configuration object.
*/
suspend fun getBasePlanId(): String
}

View file

@ -0,0 +1,10 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util.billing
class BillingError(
val billingResponseCode: Int
) : Exception("$billingResponseCode")

View file

@ -0,0 +1,15 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util.billing
import org.signal.core.util.money.FiatMoney
/**
* Represents a purchasable product from the Google Play Billing API
*/
data class BillingProduct(
val price: FiatMoney
)

View file

@ -0,0 +1,41 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util.billing
/**
* Sealed class hierarchy representing the different success
* and error states of google play billing purchases.
*/
sealed interface BillingPurchaseResult {
data class Success(
val purchaseState: BillingPurchaseState,
val purchaseToken: String,
val isAcknowledged: Boolean,
val purchaseTime: Long,
val isAutoRenewing: Boolean
) : BillingPurchaseResult {
override fun toString(): String {
return """
BillingPurchaseResult {
purchaseState: $purchaseState
purchaseToken: <redacted>
purchaseTime: $purchaseTime
isAcknowledged: $isAcknowledged
isAutoRenewing: $isAutoRenewing
}
""".trimIndent()
}
}
data object UserCancelled : BillingPurchaseResult
data object None : BillingPurchaseResult
data object TryAgainLater : BillingPurchaseResult
data object AlreadySubscribed : BillingPurchaseResult
data object FeatureNotSupported : BillingPurchaseResult
data object GenericError : BillingPurchaseResult
data object NetworkError : BillingPurchaseResult
data object BillingUnavailable : BillingPurchaseResult
}

View file

@ -0,0 +1,15 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util.billing
/**
* BillingPurchaseState which aligns with the Google Play Billing purchased state.
*/
enum class BillingPurchaseState {
UNSPECIFIED,
PURCHASED,
PENDING
}

View file

@ -0,0 +1,42 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util.billing
import org.signal.core.util.logging.Log
enum class BillingResponseCode(val code: Int) {
UNKNOWN(code = Int.MIN_VALUE),
SERVICE_TIMEOUT(code = -3),
FEATURE_NOT_SUPPORTED(code = -2),
SERVICE_DISCONNECTED(code = -1),
OK(code = 0),
USER_CANCELED(code = 1),
SERVICE_UNAVAILABLE(code = 2),
BILLING_UNAVAILABLE(code = 3),
ITEM_UNAVAILABLE(code = 4),
DEVELOPER_ERROR(code = 5),
ERROR(code = 6),
ITEM_ALREADY_OWNED(code = 7),
ITEM_NOT_OWNED(code = 8),
NETWORK_ERROR(code = 12);
val isSuccess: Boolean get() = this == OK
companion object {
private val TAG = Log.tag(BillingResponseCode::class)
fun fromBillingLibraryResponseCode(responseCode: Int): BillingResponseCode {
val code = BillingResponseCode.entries.firstOrNull { responseCode == it.code } ?: UNKNOWN
if (code == UNKNOWN) {
Log.w(TAG, "Unknown response code: $code")
}
return code
}
}
}

View file

@ -0,0 +1,132 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util.concurrent
import android.os.Debug
import android.os.Looper
import androidx.annotation.MainThread
import org.signal.core.util.ThreadUtil
import org.signal.core.util.logging.Log
import java.lang.IllegalStateException
import java.lang.RuntimeException
import java.text.SimpleDateFormat
import java.util.Date
import java.util.Locale
import kotlin.time.Duration
import kotlin.time.Duration.Companion.milliseconds
import kotlin.time.Duration.Companion.seconds
/**
* Attempts to detect ANR's by posting runnables to the main thread and detecting if they've been run within the [anrThreshold].
* If an ANR is detected, it is logged, and the [anrSaver] is called with the series of thread dumps that were taken of the main thread.
*
* The detection of an ANR will cause an internal user to crash.
*/
object AnrDetector {
private val TAG = Log.tag(AnrDetector::class.java)
private var thread: AnrDetectorThread? = null
private val dateFormat = SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS zzz", Locale.US)
@JvmStatic
@MainThread
fun start(anrThreshold: Long = 5.seconds.inWholeMilliseconds, isInternal: () -> Boolean, anrSaver: (String) -> Unit) {
thread?.end()
thread = null
thread = AnrDetectorThread(anrThreshold.milliseconds, isInternal, anrSaver)
thread!!.start()
}
@JvmStatic
@MainThread
fun stop() {
thread?.end()
thread = null
}
private class AnrDetectorThread(
private val anrThreshold: Duration,
private val isInternal: () -> Boolean,
private val anrSaver: (String) -> Unit
) : Thread("signal-anr") {
@Volatile
private var uiRan = false
private val uiRunnable = Runnable {
uiRan = true
}
@Volatile
private var stopped = false
override fun run() {
while (!stopped) {
uiRan = false
ThreadUtil.postToMain(uiRunnable)
val intervalCount = 5
val intervalDuration = anrThreshold.inWholeMilliseconds / intervalCount
if (intervalDuration == 0L) {
throw IllegalStateException("ANR threshold is too small!")
}
val dumps = mutableListOf<String>()
for (i in 1..intervalCount) {
if (stopped) {
Log.i(TAG, "Thread shutting down during intervals.")
return
}
ThreadUtil.sleep(intervalDuration)
if (!uiRan) {
dumps += getMainThreadDump()
} else {
dumps.clear()
}
}
if (!uiRan && !Debug.isDebuggerConnected() && !Debug.waitingForDebugger()) {
Log.w(TAG, "Failed to post to main in ${anrThreshold.inWholeMilliseconds} ms! Likely ANR!")
val dumpString = dumps.joinToString(separator = "\n\n")
Log.w(TAG, "Main thread dumps:\n$dumpString")
ThreadUtil.cancelRunnableOnMain(uiRunnable)
anrSaver(dumpString)
if (isInternal()) {
Log.e(TAG, "Internal user -- crashing!")
throw SignalAnrException()
}
}
dumps.clear()
}
Log.i(TAG, "Thread shutting down.")
}
fun end() {
stopped = true
}
private fun getMainThreadDump(): String {
val dump: Map<Thread, Array<StackTraceElement>> = Thread.getAllStackTraces()
val mainThread = Looper.getMainLooper().thread
val date = dateFormat.format(Date())
val dumpString = dump[mainThread]?.joinToString(separator = "\n") ?: "Not available."
return "--- $date:\n$dumpString"
}
}
private class SignalAnrException : RuntimeException()
}

View file

@ -0,0 +1,146 @@
package org.signal.core.util.concurrent
import android.os.Handler
import org.signal.core.util.logging.Log
import java.util.concurrent.ExecutorService
import java.util.concurrent.ThreadPoolExecutor
/**
* A class that polls active threads at a set interval and logs when multiple threads are BLOCKED.
*/
class DeadlockDetector(private val handler: Handler, private val pollingInterval: Long) {
private var running = false
private val previouslyBlocked: MutableSet<Long> = mutableSetOf()
private val waitingStates: Set<Thread.State> = setOf(Thread.State.WAITING, Thread.State.TIMED_WAITING)
@Volatile
var lastThreadDump: Map<Thread, Array<StackTraceElement>>? = null
@Volatile
var lastThreadDumpTime: Long = -1
fun start() {
Log.d(TAG, "Beginning deadlock monitoring.")
running = true
handler.postDelayed(this::poll, pollingInterval)
}
fun stop() {
Log.d(TAG, "Ending deadlock monitoring.")
running = false
handler.removeCallbacksAndMessages(null)
}
private fun poll() {
val time: Long = System.currentTimeMillis()
val threads: Map<Thread, Array<StackTraceElement>> = Thread.getAllStackTraces()
val blocked: Map<Thread, Array<StackTraceElement>> = threads
.filter { entry ->
val thread: Thread = entry.key
val stack: Array<StackTraceElement> = entry.value
thread.state == Thread.State.BLOCKED || (thread.state.isWaiting() && stack.hasPotentialLock())
}
.filter { entry -> !BLOCK_BLOCKLIST.contains(entry.key.name) }
val blockedIds: Set<Long> = blocked.keys.map(Thread::getId).toSet()
val stillBlocked: Set<Long> = blockedIds.intersect(previouslyBlocked)
if (blocked.size > 1) {
Log.w(TAG, buildLogString("Found multiple blocked threads! Possible deadlock.", blocked))
lastThreadDump = threads
lastThreadDumpTime = time
} else if (stillBlocked.isNotEmpty()) {
val stillBlockedMap: Map<Thread, Array<StackTraceElement>> = stillBlocked
.map { blockedId ->
val key: Thread = blocked.keys.first { it.id == blockedId }
val value: Array<StackTraceElement> = blocked[key]!!
Pair(key, value)
}
.toMap()
Log.w(TAG, buildLogString("Found a long block! Blocked for at least $pollingInterval ms.", stillBlockedMap))
lastThreadDump = threads
lastThreadDumpTime = time
}
val fullExecutors: List<ExecutorInfo> = CHECK_FULLNESS_EXECUTORS.filter { isExecutorFull(it.executor) }
if (fullExecutors.isNotEmpty()) {
fullExecutors.forEach { executorInfo ->
val fullMap: Map<Thread, Array<StackTraceElement>> = threads
.filter { it.key.name.startsWith(executorInfo.namePrefix) }
.toMap()
val executor: ThreadPoolExecutor = executorInfo.executor as ThreadPoolExecutor
Log.w(TAG, buildLogString("Found a full executor! ${executor.activeCount}/${executor.maximumPoolSize} threads active with ${executor.queue.size} tasks queued.", fullMap))
}
lastThreadDump = threads
lastThreadDumpTime = time
}
previouslyBlocked.clear()
previouslyBlocked.addAll(blockedIds)
if (running) {
handler.postDelayed(this::poll, pollingInterval)
}
}
private data class ExecutorInfo(
val executor: ExecutorService,
val namePrefix: String
)
private fun Thread.State.isWaiting(): Boolean {
return waitingStates.contains(this)
}
private fun Array<StackTraceElement>.hasPotentialLock(): Boolean {
return any {
it.methodName.startsWith("lock") || (it.methodName.startsWith("waitForConnection") && !it.className.contains("IncomingMessageObserver"))
}
}
companion object {
private val TAG = Log.tag(DeadlockDetector::class.java)
private val CHECK_FULLNESS_EXECUTORS: Set<ExecutorInfo> = setOf(
ExecutorInfo(SignalExecutors.BOUNDED, "signal-bounded-"),
ExecutorInfo(SignalExecutors.BOUNDED_IO, "signal-io-bounded")
)
private const val CONCERNING_QUEUE_THRESHOLD = 4
private val BLOCK_BLOCKLIST = setOf("HeapTaskDaemon")
private fun buildLogString(description: String, blocked: Map<Thread, Array<StackTraceElement>>): String {
val stringBuilder = StringBuilder()
stringBuilder.append(description).append("\n")
for (entry in blocked) {
stringBuilder.append("-- [${entry.key.id}] ${entry.key.name} | ${entry.key.state}\n")
val stackTrace: Array<StackTraceElement> = entry.value
for (element in stackTrace) {
stringBuilder.append("$element\n")
}
stringBuilder.append("\n")
}
return stringBuilder.toString()
}
private fun isExecutorFull(executor: ExecutorService): Boolean {
return if (executor is ThreadPoolExecutor) {
executor.queue.size > CONCERNING_QUEUE_THRESHOLD
} else {
false
}
}
}
}

View file

@ -0,0 +1,87 @@
package org.signal.core.util.concurrent;
import androidx.annotation.NonNull;
import java.util.Iterator;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Queue;
import java.util.concurrent.Executor;
import java.util.stream.Collectors;
/**
* A serial executor that will order pending tasks by a specified priority, and will only keep a single task of a given priority, preferring the latest.
*
* So imagine a world where the following tasks were all enqueued (meaning they're all waiting to be executed):
*
* execute(0, runnableA);
* execute(3, runnableC1);
* execute(3, runnableC2);
* execute(2, runnableB);
*
* You'd expect the execution order to be:
* - runnableC2
* - runnableB
* - runnableA
*
* (We order by priority, and C1 was replaced by C2)
*/
public final class LatestPrioritizedSerialExecutor {
private final Queue<PriorityRunnable> tasks;
private final Executor executor;
private Runnable active;
public LatestPrioritizedSerialExecutor(@NonNull Executor executor) {
this.executor = executor;
this.tasks = new PriorityQueue<>();
}
/**
* Execute with a priority. Higher priorities are executed first.
*/
public synchronized void execute(int priority, @NonNull Runnable r) {
Iterator<PriorityRunnable> iterator = tasks.iterator();
while (iterator.hasNext()) {
if (iterator.next().getPriority() == priority) {
iterator.remove();
}
}
tasks.offer(new PriorityRunnable(priority) {
@Override
public void run() {
try {
r.run();
} finally {
scheduleNext();
}
}
});
if (active == null) {
scheduleNext();
}
}
private synchronized void scheduleNext() {
if ((active = tasks.poll()) != null) {
executor.execute(active);
}
}
private abstract static class PriorityRunnable implements Runnable, Comparable<PriorityRunnable> {
private final int priority;
public PriorityRunnable(int priority) {
this.priority = priority;
}
public int getPriority() {
return priority;
}
@Override
public final int compareTo(PriorityRunnable other) {
return other.getPriority() - this.getPriority();
}
}
}

View file

@ -0,0 +1,23 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util.concurrent
import io.reactivex.rxjava3.core.Observable
import io.reactivex.rxjava3.core.Observer
import io.reactivex.rxjava3.subjects.BehaviorSubject
/**
* An Observer that provides instant access to the latest emitted value.
* Basically a read-only version of [BehaviorSubject].
*/
class LatestValueObservable<T : Any>(private val subject: BehaviorSubject<T>) : Observable<T>() {
val value: T?
get() = subject.value
override fun subscribeActual(observer: Observer<in T>) {
subject.subscribe(observer)
}
}

View file

@ -0,0 +1,48 @@
package org.signal.core.util.concurrent
import androidx.lifecycle.DefaultLifecycleObserver
import androidx.lifecycle.Lifecycle
import androidx.lifecycle.LifecycleOwner
import io.reactivex.rxjava3.disposables.CompositeDisposable
import io.reactivex.rxjava3.disposables.Disposable
/**
* A lifecycle-aware [Disposable] that, after being bound to a lifecycle, will automatically dispose all contained disposables at the proper time.
*/
class LifecycleDisposable : DefaultLifecycleObserver {
val disposables: CompositeDisposable = CompositeDisposable()
fun bindTo(lifecycleOwner: LifecycleOwner): LifecycleDisposable {
return bindTo(lifecycleOwner.lifecycle)
}
fun bindTo(lifecycle: Lifecycle): LifecycleDisposable {
lifecycle.addObserver(this)
return this
}
fun add(disposable: Disposable): LifecycleDisposable {
disposables.add(disposable)
return this
}
fun addAll(vararg disposable: Disposable): LifecycleDisposable {
disposables.addAll(*disposable)
return this
}
fun clear() {
disposables.clear()
}
override fun onDestroy(owner: LifecycleOwner) {
owner.lifecycle.removeObserver(this)
disposables.clear()
}
operator fun plusAssign(disposable: Disposable) {
add(disposable)
}
}
fun Disposable.addTo(lifecycleDisposable: LifecycleDisposable): Disposable = apply { lifecycleDisposable.add(this) }

View file

@ -0,0 +1,36 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util.concurrent
import java.util.concurrent.CountDownLatch
import java.util.concurrent.ExecutorService
import java.util.concurrent.Semaphore
object LimitedWorker {
/**
* Call [worker] on a thread from [executor] for each element in [input] using only up to [maxThreads] concurrently.
*
* This method will block until all work is completed. There is no guarantee that the same threads
* will be used but that only up to [maxThreads] will be actively doing work.
*/
@JvmStatic
fun <T> execute(executor: ExecutorService, maxThreads: Int, input: Collection<T>, worker: (T) -> Unit) {
val doneWorkLatch = CountDownLatch(input.size)
val semaphore = Semaphore(maxThreads)
for (work in input) {
semaphore.acquire()
executor.execute {
worker(work)
semaphore.release()
doneWorkLatch.countDown()
}
}
doneWorkLatch.await()
}
}

View file

@ -0,0 +1,40 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util.concurrent
import io.reactivex.rxjava3.core.Maybe
import io.reactivex.rxjava3.exceptions.Exceptions
import io.reactivex.rxjava3.plugins.RxJavaPlugins
/**
* Kotlin 1.8 started respecting RxJava nullability annotations but RxJava has some oddities where it breaks those rules.
* This essentially re-implements [Maybe.fromCallable] with an emitter so we don't have to do it everywhere ourselves.
*/
object MaybeCompat {
fun <T : Any> fromCallable(callable: () -> T?): Maybe<T> {
return Maybe.create { emitter ->
val result = try {
callable()
} catch (e: Throwable) {
Exceptions.throwIfFatal(e)
if (!emitter.isDisposed) {
emitter.onError(e)
} else {
RxJavaPlugins.onError(e)
}
return@create
}
if (!emitter.isDisposed) {
if (result == null) {
emitter.onComplete()
} else {
emitter.onSuccess(result)
}
}
}
}
}

View file

@ -0,0 +1,75 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
@file:JvmName("RxExtensions")
package org.signal.core.util.concurrent
import androidx.lifecycle.LifecycleOwner
import io.reactivex.rxjava3.core.Completable
import io.reactivex.rxjava3.core.Flowable
import io.reactivex.rxjava3.core.Observable
import io.reactivex.rxjava3.core.Single
import io.reactivex.rxjava3.disposables.CompositeDisposable
import io.reactivex.rxjava3.kotlin.addTo
import io.reactivex.rxjava3.kotlin.subscribeBy
import io.reactivex.rxjava3.subjects.Subject
fun <T : Any> Flowable<T>.observe(viewLifecycleOwner: LifecycleOwner, onNext: (T) -> Unit) {
val lifecycleDisposable = LifecycleDisposable()
lifecycleDisposable.bindTo(viewLifecycleOwner)
lifecycleDisposable += subscribeBy(onNext = onNext)
}
fun Completable.observe(viewLifecycleOwner: LifecycleOwner, onComplete: () -> Unit) {
val lifecycleDisposable = LifecycleDisposable()
lifecycleDisposable.bindTo(viewLifecycleOwner)
lifecycleDisposable += subscribeBy(onComplete = onComplete)
}
fun <S : Subject<T>, T : Any> Observable<T>.subscribeWithSubject(
subject: S,
disposables: CompositeDisposable
): S {
subscribeBy(
onNext = subject::onNext,
onError = subject::onError,
onComplete = subject::onComplete
).addTo(disposables)
return subject
}
fun <S : Subject<T>, T : Any> Single<T>.subscribeWithSubject(
subject: S,
disposables: CompositeDisposable
): S {
subscribeBy(
onSuccess = {
subject.onNext(it)
subject.onComplete()
},
onError = subject::onError
).addTo(disposables)
return subject
}
/**
* Skips the first item emitted from the flowable, but only if it matches the provided [predicate].
*/
fun <T : Any> Flowable<T>.skipFirstIf(predicate: (T) -> Boolean): Flowable<T> {
return this
.scan(Pair<Boolean, T?>(false, null)) { acc, item ->
val firstItemInList = !acc.first
if (firstItemInList && predicate(item)) {
true to null
} else {
true to item
}
}
.filter { it.second != null }
.map { it.second!! }
}

View file

@ -0,0 +1,40 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util.concurrent
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.Dispatchers
/**
* [Dispatchers] wrapper to allow tests to inject test dispatchers.
*/
object SignalDispatchers {
private var dispatcherProvider: DispatcherProvider = DefaultDispatcherProvider
fun setDispatcherProvider(dispatcherProvider: DispatcherProvider = DefaultDispatcherProvider) {
this.dispatcherProvider = dispatcherProvider
}
val Main get() = dispatcherProvider.main
val IO get() = dispatcherProvider.io
val Default get() = dispatcherProvider.default
val Unconfined get() = dispatcherProvider.unconfined
interface DispatcherProvider {
val main: CoroutineDispatcher
val io: CoroutineDispatcher
val default: CoroutineDispatcher
val unconfined: CoroutineDispatcher
}
private object DefaultDispatcherProvider : DispatcherProvider {
override val main: CoroutineDispatcher = Dispatchers.Main
override val io: CoroutineDispatcher = Dispatchers.IO
override val default: CoroutineDispatcher = Dispatchers.Default
override val unconfined: CoroutineDispatcher = Dispatchers.Unconfined
}
}

View file

@ -0,0 +1,105 @@
package org.signal.core.util.concurrent;
import android.os.HandlerThread;
import android.os.Process;
import androidx.annotation.NonNull;
import org.signal.core.util.ThreadUtil;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
public final class SignalExecutors {
public static final ExecutorService UNBOUNDED = Executors.newCachedThreadPool(new NumberedThreadFactory("signal-unbounded", ThreadUtil.PRIORITY_BACKGROUND_THREAD));
public static final ExecutorService BOUNDED = Executors.newFixedThreadPool(4, new NumberedThreadFactory("signal-bounded", ThreadUtil.PRIORITY_BACKGROUND_THREAD));
public static final ExecutorService SERIAL = Executors.newSingleThreadExecutor(new NumberedThreadFactory("signal-serial", ThreadUtil.PRIORITY_BACKGROUND_THREAD));
public static final ExecutorService BOUNDED_IO = newCachedBoundedExecutor("signal-io-bounded", ThreadUtil.PRIORITY_IMPORTANT_BACKGROUND_THREAD, 1, 32, 30);
private SignalExecutors() {}
public static ExecutorService newCachedSingleThreadExecutor(final String name, int priority) {
ThreadPoolExecutor executor = new ThreadPoolExecutor(1, 1, 15, TimeUnit.SECONDS, new LinkedBlockingQueue<>(), r -> new Thread(r, name) {
@Override public void run() {
Process.setThreadPriority(priority);
super.run();
}
});
executor.allowCoreThreadTimeOut(true);
return executor;
}
/**
* ThreadPoolExecutor will only create a new thread if the provided queue returns false from
* offer(). That means if you give it an unbounded queue, it'll only ever create 1 thread, no
* matter how long the queue gets.
* <p>
* But if you bound the queue and submit more runnables than there are threads, your task is
* rejected and throws an exception.
* <p>
* So we make a queue that will always return false if it's non-empty to ensure new threads get
* created. Then, if a task gets rejected, we simply add it to the queue.
*/
public static ExecutorService newCachedBoundedExecutor(final String name, int priority, int minThreads, int maxThreads, int timeoutSeconds) {
ThreadPoolExecutor threadPool = new ThreadPoolExecutor(minThreads,
maxThreads,
timeoutSeconds,
TimeUnit.SECONDS,
new LinkedBlockingQueue<>() {
@Override
public boolean offer(Runnable runnable) {
if (isEmpty()) {
return super.offer(runnable);
} else {
return false;
}
}
}, new NumberedThreadFactory(name, priority));
threadPool.setRejectedExecutionHandler((runnable, executor) -> {
try {
executor.getQueue().put(runnable);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
});
return threadPool;
}
public static HandlerThread getAndStartHandlerThread(@NonNull String name, int priority) {
HandlerThread handlerThread = new HandlerThread(name, priority);
handlerThread.start();
return handlerThread;
}
public static class NumberedThreadFactory implements ThreadFactory {
private final int priority;
private final String baseName;
private final AtomicInteger counter;
public NumberedThreadFactory(@NonNull String baseName, int priority) {
this.priority = priority;
this.baseName = baseName;
this.counter = new AtomicInteger();
}
@Override
public Thread newThread(@NonNull Runnable r) {
return new Thread(r, baseName + "-" + counter.getAndIncrement()) {
@Override
public void run() {
Process.setThreadPriority(priority);
super.run();
}
};
}
}
}

View file

@ -0,0 +1,103 @@
package org.signal.core.util.concurrent;
import android.os.AsyncTask;
import androidx.annotation.NonNull;
import androidx.lifecycle.Lifecycle;
import androidx.lifecycle.LifecycleEventObserver;
import androidx.lifecycle.LifecycleOwner;
import org.signal.core.util.ThreadUtil;
import org.signal.core.util.concurrent.SignalExecutors;
import java.util.concurrent.Executor;
import io.reactivex.rxjava3.observers.DefaultObserver;
public class SimpleTask {
/**
* Runs a task in the background and passes the result of the computation to a task that is run
* on the main thread. Will only invoke the {@code foregroundTask} if the provided {@link Lifecycle}
* is in a valid (i.e. visible) state at that time. In this way, it is very similar to
* {@link AsyncTask}, but is safe in that you can guarantee your task won't be called when your
* view is in an invalid state.
*/
public static <E> void run(@NonNull Lifecycle lifecycle, @NonNull BackgroundTask<E> backgroundTask, @NonNull ForegroundTask<E> foregroundTask) {
if (!isValid(lifecycle)) {
return;
}
SignalExecutors.BOUNDED.execute(() -> {
final E result = backgroundTask.run();
if (isValid(lifecycle)) {
ThreadUtil.runOnMain(() -> {
if (isValid(lifecycle)) {
foregroundTask.run(result);
}
});
}
});
}
/**
* Runs a task in the background and passes the result of the computation to a task that is run
* on the main thread. Will only invoke the {@code foregroundTask} if the provided {@link Lifecycle}
* is or enters in the future a valid (i.e. visible) state. In this way, it is very similar to
* {@link AsyncTask}, but is safe in that you can guarantee your task won't be called when your
* view is in an invalid state.
*/
public static <E> void runWhenValid(@NonNull Lifecycle lifecycle, @NonNull BackgroundTask<E> backgroundTask, @NonNull ForegroundTask<E> foregroundTask) {
lifecycle.addObserver(new LifecycleEventObserver() {
@Override public void onStateChanged(@NonNull LifecycleOwner lifecycleOwner, @NonNull Lifecycle.Event event) {
if (isValid(lifecycle)) {
lifecycle.removeObserver(this);
SignalExecutors.BOUNDED.execute(() -> {
final E result = backgroundTask.run();
if (isValid(lifecycle)) {
ThreadUtil.runOnMain(() -> {
if (isValid(lifecycle)) {
foregroundTask.run(result);
}
});
}
});
}
}
});
}
/**
* Runs a task in the background and passes the result of the computation to a task that is run on
* the main thread. Essentially {@link AsyncTask}, but lambda-compatible.
*/
public static <E> void run(@NonNull BackgroundTask<E> backgroundTask, @NonNull ForegroundTask<E> foregroundTask) {
run(SignalExecutors.BOUNDED, backgroundTask, foregroundTask);
}
/**
* Runs a task on the specified {@link Executor} and passes the result of the computation to a
* task that is run on the main thread. Essentially {@link AsyncTask}, but lambda-compatible.
*/
public static <E> void run(@NonNull Executor executor, @NonNull BackgroundTask<E> backgroundTask, @NonNull ForegroundTask<E> foregroundTask) {
executor.execute(() -> {
final E result = backgroundTask.run();
ThreadUtil.runOnMain(() -> foregroundTask.run(result));
});
}
private static boolean isValid(@NonNull Lifecycle lifecycle) {
return lifecycle.getCurrentState().isAtLeast(Lifecycle.State.CREATED);
}
public interface BackgroundTask<E> {
E run();
}
public interface ForegroundTask<E> {
void run(E result);
}
}

View file

@ -0,0 +1,60 @@
package org.signal.core.util.logging
import android.annotation.SuppressLint
import java.util.concurrent.CountDownLatch
import java.util.concurrent.Executor
import java.util.concurrent.Executors
@SuppressLint("LogNotSignal")
object AndroidLogger : Log.Logger() {
private val serialExecutor: Executor = Executors.newSingleThreadExecutor { Thread(it, "signal-logcat") }
override fun v(tag: String, message: String?, t: Throwable?, keepLonger: Boolean) {
serialExecutor.execute {
android.util.Log.v(tag, message.scrub(), t)
}
}
override fun d(tag: String, message: String?, t: Throwable?, keepLonger: Boolean) {
serialExecutor.execute {
android.util.Log.d(tag, message.scrub(), t)
}
}
override fun i(tag: String, message: String?, t: Throwable?, keepLonger: Boolean) {
serialExecutor.execute {
android.util.Log.i(tag, message.scrub(), t)
}
}
override fun w(tag: String, message: String?, t: Throwable?, keepLonger: Boolean) {
serialExecutor.execute {
android.util.Log.w(tag, message.scrub(), t)
}
}
override fun e(tag: String, message: String?, t: Throwable?, keepLonger: Boolean) {
serialExecutor.execute {
android.util.Log.e(tag, message.scrub(), t)
}
}
override fun flush() {
val latch = CountDownLatch(1)
serialExecutor.execute {
latch.countDown()
}
try {
latch.await()
} catch (e: InterruptedException) {
android.util.Log.w("AndroidLogger", "Interrupted while waiting for flush()", e)
}
}
private fun String?.scrub(): String? {
return this?.let { Scrubber.scrub(it).toString() }
}
}

View file

@ -0,0 +1,106 @@
package org.signal.core.util.money;
import androidx.annotation.NonNull;
import java.math.BigDecimal;
import java.text.NumberFormat;
import java.util.Currency;
import java.util.HashSet;
import java.util.Locale;
import java.util.Objects;
import java.util.Set;
public class FiatMoney {
private static final Set<String> SPECIAL_CASE_MULTIPLICANDS = new HashSet<>() {{
add("UGX");
add("ISK");
}};
private final BigDecimal amount;
private final Currency currency;
private final long timestamp;
public FiatMoney(@NonNull BigDecimal amount, @NonNull Currency currency) {
this(amount, currency, 0);
}
public FiatMoney(@NonNull BigDecimal amount, @NonNull Currency currency, long timestamp) {
this.amount = amount;
this.currency = currency;
this.timestamp = timestamp;
}
public @NonNull BigDecimal getAmount() {
return amount;
}
public @NonNull Currency getCurrency() {
return currency;
}
public long getTimestamp() {
return timestamp;
}
/**
* @return amount, rounded to the default fractional amount.
*/
public @NonNull String getDefaultPrecisionString() {
return getDefaultPrecisionString(Locale.getDefault());
}
/**
* @return amount, rounded to the default fractional amount.
*/
public @NonNull String getDefaultPrecisionString(@NonNull Locale locale) {
NumberFormat formatter = NumberFormat.getInstance(locale);
formatter.setMinimumFractionDigits(currency.getDefaultFractionDigits());
formatter.setGroupingUsed(false);
return formatter.format(amount);
}
/**
* Note: This special cases SPECIAL_CASE_MULTIPLICANDS members to act as two decimal.
*
* @return amount, in smallest possible units (cents, yen, etc.)
*/
public @NonNull String getMinimumUnitPrecisionString() {
NumberFormat formatter = NumberFormat.getInstance(Locale.US);
formatter.setMaximumFractionDigits(0);
formatter.setGroupingUsed(false);
String currencyCode = currency.getCurrencyCode();
BigDecimal multiplicand = BigDecimal.TEN.pow(SPECIAL_CASE_MULTIPLICANDS.contains(currencyCode) ? 2 : currency.getDefaultFractionDigits());
return formatter.format(amount.multiply(multiplicand));
}
/**
* Transforms the given currency / amount pair from a signal network amount to a FiatMoney, accounting for the special
* cased multiplicands for ISK and UGX
*/
public static @NonNull FiatMoney fromSignalNetworkAmount(@NonNull BigDecimal amount, @NonNull Currency currency) {
String currencyCode = currency.getCurrencyCode();
int shift = SPECIAL_CASE_MULTIPLICANDS.contains(currencyCode) ? 2: currency.getDefaultFractionDigits();
BigDecimal shiftedAmount = amount.movePointLeft(shift);
return new FiatMoney(shiftedAmount, currency);
}
public static boolean equals(FiatMoney left, FiatMoney right) {
return Objects.equals(left.amount, right.amount) &&
Objects.equals(left.currency, right.currency) &&
Objects.equals(left.timestamp, right.timestamp);
}
@Override
public String toString() {
return "FiatMoney{" +
"amount=" + amount +
", currency=" + currency +
", timestamp=" + timestamp +
'}';
}
}

View file

@ -0,0 +1,24 @@
package org.signal.core.util.money
import java.util.Currency
/**
* Utility methods for java.util.Currency
*
* This is prefixed with "Platform" as there are several different Currency classes
* available in the app, and this utility class is specifically for dealing with
* java.util.Currency
*/
object PlatformCurrencyUtil {
val USD: Currency = Currency.getInstance("USD")
/**
* Note: Adding this as an extension method of Currency causes some confusion in
* AndroidStudio due to a separate Currency class from the AndroidSDK having
* an extension method of the same signature.
*/
fun getAvailableCurrencyCodes(): Set<String> {
return Currency.getAvailableCurrencies().map { it.currencyCode }.toSet()
}
}

View file

@ -0,0 +1,236 @@
package org.signal.core.util.tracing;
import android.os.SystemClock;
import androidx.annotation.NonNull;
import androidx.annotation.Nullable;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import okio.ByteString;
/**
* A class to create Perfetto-compatible traces. Currently keeps the entire trace in memory to
* avoid weirdness with synchronizing to disk.
* <p>
* Some general info on how the Perfetto format works:
* - The file format is just a Trace proto (see Trace.proto)
* - The Trace proto is just a series of TracePackets
* - TracePackets can describe:
* - Threads
* - Start of a method
* - End of a method
* - (And a bunch of other stuff that's not relevant to use at this point)
* <p>
* We keep a circular buffer of TracePackets for method calls, and we keep a separate list of
* TracePackets for threads so we don't lose any of those.
* <p>
* Serializing is just a matter of throwing all the TracePackets we have into a proto.
* <p>
* Note: This class aims to be largely-thread-safe, but prioritizes speed and memory efficiency
* above all else. These methods are going to be called very quickly from every thread imaginable,
* and we want to create as little overhead as possible. The idea being that it's ok if we don't,
* for example, keep a perfect circular buffer size if it allows us to reduce overhead. The only
* cost of screwing up would be dropping a trace packet or something, which, while sad, won't affect
* how the app functions
*/
public final class Tracer {
public static final class TrackId {
public static final long DB_LOCK = -8675309;
private static final String DB_LOCK_NAME = "Database Lock";
}
private static final Tracer INSTANCE = new Tracer();
private static final int TRUSTED_SEQUENCE_ID = 1;
private static final byte[] SYNCHRONIZATION_MARKER = toByteArray(UUID.fromString("82477a76-b28d-42ba-81dc-33326d57a079"));
private static final long SYNCHRONIZATION_INTERVAL = TimeUnit.SECONDS.toNanos(3);
private final Clock clock;
private final Map<Long, TracePacket> threadPackets;
private final Queue<TracePacket> eventPackets;
private final AtomicInteger eventCount;
private long lastSyncTime;
private long maxBufferSize;
private Tracer() {
this.clock = SystemClock::elapsedRealtimeNanos;
this.threadPackets = new ConcurrentHashMap<>();
this.eventPackets = new ConcurrentLinkedQueue<>();
this.eventCount = new AtomicInteger(0);
this.maxBufferSize = 3_500;
}
public static @NonNull Tracer getInstance() {
return INSTANCE;
}
public void setMaxBufferSize(long maxBufferSize) {
this.maxBufferSize = maxBufferSize;
}
public void start(@NonNull String methodName) {
start(methodName, Thread.currentThread().getId(), null);
}
public void start(@NonNull String methodName, long trackId) {
start(methodName, trackId, null);
}
public void start(@NonNull String methodName, @NonNull String key, @Nullable String value) {
start(methodName, Thread.currentThread().getId(), key, value);
}
public void start(@NonNull String methodName, long trackId, @NonNull String key, @Nullable String value) {
start(methodName, trackId, Collections.singletonMap(key, value));
}
public void start(@NonNull String methodName, @Nullable Map<String, String> values) {
start(methodName, Thread.currentThread().getId(), values);
}
public void start(@NonNull String methodName, long trackId, @Nullable Map<String, String> values) {
long time = clock.getTimeNanos();
if (time - lastSyncTime > SYNCHRONIZATION_INTERVAL) {
addPacket(forSynchronization(time));
lastSyncTime = time;
}
if (!threadPackets.containsKey(trackId)) {
threadPackets.put(trackId, forTrackId(trackId));
}
addPacket(forMethodStart(methodName, time, trackId, values));
}
public void end(@NonNull String methodName) {
addPacket(forMethodEnd(methodName, clock.getTimeNanos(), Thread.currentThread().getId()));
}
public void end(@NonNull String methodName, long trackId) {
addPacket(forMethodEnd(methodName, clock.getTimeNanos(), trackId));
}
public @NonNull byte[] serialize() {
List<TracePacket> packets = new ArrayList<>();
packets.addAll(threadPackets.values());
packets.addAll(eventPackets);
packets.add(forSynchronization(clock.getTimeNanos()));
return new Trace.Builder().packet(packets).build().encode();
}
/**
* Attempts to add a packet to our list while keeping the size of our circular buffer in-check.
* The tracking of the event count is not perfectly thread-safe, but doing it in a thread-safe
* way would likely involve adding a lock, which we really don't want to do, since it'll add
* unnecessary overhead.
* <p>
* Note that we keep track of the event count separately because
* {@link ConcurrentLinkedQueue#size()} is NOT a constant-time operation.
*/
private void addPacket(@NonNull TracePacket packet) {
eventPackets.add(packet);
int size = eventCount.incrementAndGet();
for (int i = size; i > maxBufferSize; i--) {
eventPackets.poll();
eventCount.decrementAndGet();
}
}
private TracePacket forTrackId(long id) {
if (id == TrackId.DB_LOCK) {
return forTrack(id, TrackId.DB_LOCK_NAME);
} else {
Thread currentThread = Thread.currentThread();
return forTrack(currentThread.getId(), currentThread.getName());
}
}
private static TracePacket forTrack(long id, String name) {
return new TracePacket.Builder()
.trusted_packet_sequence_id(TRUSTED_SEQUENCE_ID)
.track_descriptor(new TrackDescriptor.Builder()
.uuid(id)
.name(name).build())
.build();
}
private static TracePacket forMethodStart(@NonNull String name, long time, long threadId, @Nullable Map<String, String> values) {
TrackEvent.Builder event = new TrackEvent.Builder()
.track_uuid(threadId)
.name(name)
.type(TrackEvent.Type.TYPE_SLICE_BEGIN);
List<DebugAnnotation> debugAnnotations = new LinkedList<>();
if (values != null) {
for (Map.Entry<String, String> entry : values.entrySet()) {
debugAnnotations.add(debugAnnotation(entry.getKey(), entry.getValue()));
}
}
event.debug_annotations(debugAnnotations);
return new TracePacket.Builder()
.trusted_packet_sequence_id(TRUSTED_SEQUENCE_ID)
.timestamp(time)
.track_event(event.build())
.build();
}
private static DebugAnnotation debugAnnotation(@NonNull String key, @Nullable String value) {
return new DebugAnnotation.Builder()
.name(key)
.string_value(value != null ? value : "")
.build();
}
private static TracePacket forMethodEnd(@NonNull String name, long time, long threadId) {
return new TracePacket.Builder()
.trusted_packet_sequence_id(TRUSTED_SEQUENCE_ID)
.timestamp(time)
.track_event(new TrackEvent.Builder()
.track_uuid(threadId)
.name(name)
.type(TrackEvent.Type.TYPE_SLICE_END)
.build())
.build();
}
private static TracePacket forSynchronization(long time) {
return new TracePacket.Builder()
.trusted_packet_sequence_id(TRUSTED_SEQUENCE_ID)
.timestamp(time)
.synchronization_marker(ByteString.of(SYNCHRONIZATION_MARKER))
.build();
}
public static byte[] toByteArray(UUID uuid) {
ByteBuffer buffer = ByteBuffer.wrap(new byte[16]);
buffer.putLong(uuid.getMostSignificantBits());
buffer.putLong(uuid.getLeastSignificantBits());
return buffer.array();
}
private interface Clock {
long getTimeNanos();
}
}

View file

@ -0,0 +1,151 @@
/*
* Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
syntax = "proto2";
package signal;
option java_package = "org.signal.core.util.tracing";
option java_outer_classname = "TraceProtos";
/*
* Minimal interface needed to work with Perfetto.
*
* https://cs.android.com/android/platform/superproject/+/master:external/perfetto/protos/perfetto/trace/trace.proto
*/
message Trace {
repeated TracePacket packet = 1;
}
message TracePacket {
optional uint64 timestamp = 8;
optional uint32 timestamp_clock_id = 58;
oneof data {
TrackEvent track_event = 11;
TrackDescriptor track_descriptor = 60;
bytes synchronization_marker = 36;
}
oneof optional_trusted_packet_sequence_id {
uint32 trusted_packet_sequence_id = 10;
}
}
message TrackEvent {
repeated uint64 category_iids = 3;
repeated string categories = 22;
repeated DebugAnnotation debug_annotations = 4;
oneof name_field {
uint64 name_iid = 10;
string name = 23;
}
enum Type {
TYPE_UNSPECIFIED = 0;
TYPE_SLICE_BEGIN = 1;
TYPE_SLICE_END = 2;
TYPE_INSTANT = 3;
TYPE_COUNTER = 4;
}
optional Type type = 9;
optional uint64 track_uuid = 11;
optional int64 counter_value = 30;
oneof timestamp {
int64 timestamp_delta_us = 1;
int64 timestamp_absolute_us = 16;
}
oneof thread_time {
int64 thread_time_delta_us = 2;
int64 thread_time_absolute_us = 17;
}
}
message TrackDescriptor {
optional uint64 uuid = 1;
optional uint64 parent_uuid = 5;
optional string name = 2;
optional ThreadDescriptor thread = 4;
optional CounterDescriptor counter = 8;
}
message ThreadDescriptor {
optional int32 pid = 1;
optional int32 tid = 2;
optional string thread_name = 5;
}
message CounterDescriptor {
enum BuiltinCounterType {
COUNTER_UNSPECIFIED = 0;
COUNTER_THREAD_TIME_NS = 1;
COUNTER_THREAD_INSTRUCTION_COUNT = 2;
}
enum Unit {
UNIT_UNSPECIFIED = 0;
UNIT_TIME_NS = 1;
UNIT_COUNT = 2;
UNIT_SIZE_BYTES = 3;
}
optional BuiltinCounterType type = 1;
repeated string categories = 2;
optional Unit unit = 3;
optional int64 unit_multiplier = 4;
optional bool is_incremental = 5;
}
message DebugAnnotation {
message NestedValue {
enum NestedType {
UNSPECIFIED = 0;
DICT = 1;
ARRAY = 2;
}
optional NestedType nested_type = 1;
repeated string dict_keys = 2;
repeated NestedValue dict_values = 3;
repeated NestedValue array_values = 4;
optional int64 int_value = 5;
optional double double_value = 6;
optional bool bool_value = 7;
optional string string_value = 8;
}
oneof name_field {
uint64 name_iid = 1;
string name = 10;
}
oneof value {
bool bool_value = 2;
uint64 uint_value = 3;
int64 int_value = 4;
double double_value = 5;
string string_value = 6;
uint64 pointer_value = 7;
NestedValue nested_value = 8;
}
}

View file

@ -0,0 +1,9 @@
<vector xmlns:android="http://schemas.android.com/apk/res/android"
android:width="24dp"
android:height="24dp"
android:viewportWidth="24"
android:viewportHeight="24">
<path
android:fillColor="#FFFFFFFF"
android:pathData="M22,4.5L16.35,4.5a4.45,4.45 0,0 0,-8.7 0L2,4.5L2,6L3.5,6L4.86,20A2.25,2.25 0,0 0,7.1 22h9.8a2.25,2.25 0,0 0,2.24 -2L20.5,6L22,6ZM12,2.5a3,3 0,0 1,2.82 2L9.18,4.5A3,3 0,0 1,12 2.5ZM17.65,19.83a0.76,0.76 0,0 1,-0.75 0.67L7.1,20.5a0.76,0.76 0,0 1,-0.75 -0.67L5,6L19,6ZM11.25,18L11.25,8h1.5L12.75,18ZM14.5,18L15,8h1.5L16,18ZM8,18 L7.5,8L9,8l0.5,10Z"/>
</vector>

View file

@ -0,0 +1,148 @@
package org.signal.core.util;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
public class BitmaskTest {
@Test
public void read_singleBit() {
assertFalse(Bitmask.read(0b00000000, 0));
assertFalse(Bitmask.read(0b11111101, 1));
assertFalse(Bitmask.read(0b11111011, 2));
assertFalse(Bitmask.read(0b01111111_11111111_11111111_11111111_11111111_11111111_11111111_11111111L, 63));
assertTrue(Bitmask.read(0b00000001, 0));
assertTrue(Bitmask.read(0b00000010, 1));
assertTrue(Bitmask.read(0b00000100, 2));
assertTrue(Bitmask.read(0b10000000_00000000_00000000_00000000_00000000_00000000_00000000_00000000L, 63));
}
@Test
public void read_twoBits() {
assertEquals(0, Bitmask.read(0b11111100, 0, 2));
assertEquals(1, Bitmask.read(0b11111101, 0, 2));
assertEquals(2, Bitmask.read(0b11111110, 0, 2));
assertEquals(3, Bitmask.read(0b11111111, 0, 2));
assertEquals(0, Bitmask.read(0b11110011, 1, 2));
assertEquals(1, Bitmask.read(0b11110111, 1, 2));
assertEquals(2, Bitmask.read(0b11111011, 1, 2));
assertEquals(3, Bitmask.read(0b11111111, 1, 2));
assertEquals(0, Bitmask.read(0b00000000, 2, 2));
assertEquals(1, Bitmask.read(0b00010000, 2, 2));
assertEquals(2, Bitmask.read(0b00100000, 2, 2));
assertEquals(3, Bitmask.read(0b00110000, 2, 2));
assertEquals(0, Bitmask.read(0b00000000_00000000_00000000_00000000_00000000_00000000_00000000_00000000L, 31, 2));
assertEquals(1, Bitmask.read(0b01000000_00000000_00000000_00000000_00000000_00000000_00000000_00000000L, 31, 2));
assertEquals(2, Bitmask.read(0b10000000_00000000_00000000_00000000_00000000_00000000_00000000_00000000L, 31, 2));
assertEquals(3, Bitmask.read(0b11000000_00000000_00000000_00000000_00000000_00000000_00000000_00000000L, 31, 2));
}
@Test
public void read_fourBits() {
assertEquals(0, Bitmask.read(0b00000000_00000000_00000000_00000000_00000000_00000000_00000000_00000000L, 15, 4));
assertEquals(4, Bitmask.read(0b01000000_00000000_00000000_00000000_00000000_00000000_00000000_00000000L, 15, 4));
assertEquals(8, Bitmask.read(0b10000000_00000000_00000000_00000000_00000000_00000000_00000000_00000000L, 15, 4));
assertEquals(15, Bitmask.read(0b11110000_00000000_00000000_00000000_00000000_00000000_00000000_00000000L, 15, 4));
}
@Test(expected = IllegalArgumentException.class)
public void read_error_negativeIndex() {
Bitmask.read(0b0000000, -1);
}
@Test(expected = IllegalArgumentException.class)
public void read_error_indexTooLarge_singleBit() {
Bitmask.read(0b0000000, 64);
}
@Test(expected = IllegalArgumentException.class)
public void read_error_indexTooLarge_twoBits() {
Bitmask.read(0b0000000, 32, 2);
}
@Test
public void update_singleBit() {
assertEquals(0b00000001, Bitmask.update(0b00000000, 0, true));
assertEquals(0b00000010, Bitmask.update(0b00000000, 1, true));
assertEquals(0b00000100, Bitmask.update(0b00000000, 2, true));
assertEquals(0b10000000_00000000_00000000_00000000_00000000_00000000_00000000_00000000L,
Bitmask.update(0b00000000_00000000_00000000_00000000_00000000_00000000_00000000_00000000L, 63, true));
assertEquals(0b11111110, Bitmask.update(0b11111111, 0, false));
assertEquals(0b11111101, Bitmask.update(0b11111111, 1, false));
assertEquals(0b11111011, Bitmask.update(0b11111111, 2, false));
assertEquals(0b01111111_11111111_11111111_11111111_11111111_11111111_11111111_11111111L,
Bitmask.update(0b11111111_11111111_11111111_11111111_11111111_11111111_11111111_11111111L, 63, false));
assertEquals(0b11111111, Bitmask.update(0b11111111, 0, true));
assertEquals(0b11111111, Bitmask.update(0b11111111, 1, true));
assertEquals(0b11111111, Bitmask.update(0b11111111, 2, true));
assertEquals(0b11111111_11111111_11111111_11111111_11111111_11111111_11111111_11111111L,
Bitmask.update(0b11111111_11111111_11111111_11111111_11111111_11111111_11111111_11111111L, 63, true));
assertEquals(0b00000000, Bitmask.update(0b00000000, 0, false));
assertEquals(0b00000000, Bitmask.update(0b00000000, 1, false));
assertEquals(0b00000000, Bitmask.update(0b00000000, 2, false));
assertEquals(0b00000000_00000000_00000000_00000000_00000000_00000000_00000000_00000000L,
Bitmask.update(0b00000000_00000000_00000000_00000000_00000000_00000000_00000000_00000000L, 63, false));
}
@Test
public void update_twoBits() {
assertEquals(0b00000000, Bitmask.update(0b00000000, 0, 2, 0));
assertEquals(0b00000001, Bitmask.update(0b00000000, 0, 2, 1));
assertEquals(0b00000010, Bitmask.update(0b00000000, 0, 2, 2));
assertEquals(0b00000011, Bitmask.update(0b00000000, 0, 2, 3));
assertEquals(0b00000000, Bitmask.update(0b00000000, 1, 2, 0));
assertEquals(0b00000100, Bitmask.update(0b00000000, 1, 2, 1));
assertEquals(0b00001000, Bitmask.update(0b00000000, 1, 2, 2));
assertEquals(0b00001100, Bitmask.update(0b00000000, 1, 2, 3));
assertEquals(0b11111100, Bitmask.update(0b11111111, 0, 2, 0));
assertEquals(0b11111101, Bitmask.update(0b11111111, 0, 2, 1));
assertEquals(0b11111110, Bitmask.update(0b11111111, 0, 2, 2));
assertEquals(0b11111111, Bitmask.update(0b11111111, 0, 2, 3));
assertEquals(0b00000000_00000000_00000000_00000000_00000000_00000000_00000000_00000000L,
Bitmask.update(0b00000000_00000000_00000000_00000000_00000000_00000000_00000000_00000000L, 31, 2, 0));
assertEquals(0b01000000_00000000_00000000_00000000_00000000_00000000_00000000_00000000L,
Bitmask.update(0b00000000_00000000_00000000_00000000_00000000_00000000_00000000_00000000L, 31, 2, 1));
assertEquals(0b10000000_00000000_00000000_00000000_00000000_00000000_00000000_00000000L,
Bitmask.update(0b00000000_00000000_00000000_00000000_00000000_00000000_00000000_00000000L, 31, 2, 2));
assertEquals(0b11000000_00000000_00000000_00000000_00000000_00000000_00000000_00000000L,
Bitmask.update(0b00000000_00000000_00000000_00000000_00000000_00000000_00000000_00000000L, 31, 2, 3));
}
@Test(expected = IllegalArgumentException.class)
public void update_error_negativeIndex() {
Bitmask.update(0b0000000, -1, true);
}
@Test(expected = IllegalArgumentException.class)
public void update_error_indexTooLarge_singleBit() {
Bitmask.update(0b0000000, 64, true);
}
@Test(expected = IllegalArgumentException.class)
public void update_error_indexTooLarge_twoBits() {
Bitmask.update(0b0000000, 32, 2, 0);
}
@Test(expected = IllegalArgumentException.class)
public void update_error_negativeValue() {
Bitmask.update(0b0000000, 0, 2, -1);
}
@Test(expected = IllegalArgumentException.class)
public void update_error_valueTooLarge() {
Bitmask.update(0b0000000, 0, 2, 4);
}
}

View file

@ -0,0 +1,92 @@
package org.signal.core.util
import org.junit.Assert.assertEquals
import org.junit.Test
class BreakIteratorCompatTest {
@Test
fun empty() {
val breakIterator = BreakIteratorCompat.getInstance()
breakIterator.setText("")
assertEquals(BreakIteratorCompat.DONE, breakIterator.next())
}
@Test
fun single() {
val breakIterator = BreakIteratorCompat.getInstance()
breakIterator.setText("a")
assertEquals(1, breakIterator.next())
assertEquals(BreakIteratorCompat.DONE, breakIterator.next())
}
@Test
fun count_empty() {
val breakIterator = BreakIteratorCompat.getInstance()
breakIterator.setText("")
assertEquals(0, breakIterator.countBreaks())
assertEquals(BreakIteratorCompat.DONE, breakIterator.next())
}
@Test
fun count_simple_text() {
val breakIterator = BreakIteratorCompat.getInstance()
breakIterator.setText("abc")
assertEquals(3, breakIterator.countBreaks())
assertEquals(BreakIteratorCompat.DONE, breakIterator.next())
}
@Test
fun two_counts() {
val breakIterator = BreakIteratorCompat.getInstance()
breakIterator.setText("abc")
assertEquals(3, breakIterator.countBreaks())
assertEquals(BreakIteratorCompat.DONE, breakIterator.next())
assertEquals(3, breakIterator.countBreaks())
}
@Test
fun count_multi_character_graphemes() {
val hindi = "समाजो गयेग"
val breakIterator = BreakIteratorCompat.getInstance()
breakIterator.setText(hindi)
assertEquals(7, breakIterator.countBreaks())
assertEquals(BreakIteratorCompat.DONE, breakIterator.next())
}
@Test
fun iterate_multi_character_graphemes() {
val hindi = "समाजो गयेग"
val breakIterator = BreakIteratorCompat.getInstance()
breakIterator.setText(hindi)
assertEquals(listOf("", "मा", "जो", " ", "", "ये", ""), breakIterator.toList())
assertEquals(BreakIteratorCompat.DONE, breakIterator.next())
}
@Test
fun split_multi_character_graphemes() {
val hindi = "समाजो गयेग"
val breakIterator = BreakIteratorCompat.getInstance()
breakIterator.setText(hindi)
assertEquals("समाजो गयेग", breakIterator.take(8))
assertEquals("समाजो गयेग", breakIterator.take(7))
assertEquals("समाजो गये", breakIterator.take(6))
assertEquals("समाजो ग", breakIterator.take(5))
assertEquals("समाजो ", breakIterator.take(4))
assertEquals("समाजो", breakIterator.take(3))
assertEquals("समा", breakIterator.take(2))
assertEquals("", breakIterator.take(1))
assertEquals("", breakIterator.take(0))
assertEquals("", breakIterator.take(-1))
}
}

View file

@ -0,0 +1,339 @@
package org.signal.core.util
import android.app.Application
import android.text.SpannedString
import android.widget.TextView
import assertk.assertThat
import assertk.assertions.isEqualTo
import assertk.assertions.isGreaterThan
import assertk.assertions.isLessThan
import assertk.assertions.isLessThanOrEqualTo
import assertk.assertions.isNull
import okio.utf8Size
import org.junit.Test
import org.junit.runner.RunWith
import org.robolectric.RobolectricTestRunner
import org.robolectric.RuntimeEnvironment
import org.robolectric.annotation.Config
@RunWith(RobolectricTestRunner::class)
@Config(manifest = Config.NONE, application = Application::class)
class ByteLimitInputFilterTest {
@Test
fun `filter - null source, returns null`() {
val filter = ByteLimitInputFilter(10)
val result = filter.filter(null, 0, 0, SpannedString(""), 0, 0)
assertThat(result).isNull()
}
@Test
fun `filter - null dest, returns null`() {
val filter = ByteLimitInputFilter(10)
val result = filter.filter("test", 0, 4, null, 0, 0)
assertThat(result).isNull()
}
@Test
fun `filter - within byte limit, returns null`() {
val filter = ByteLimitInputFilter(10)
val existingText = SpannedString("hi")
val insertText = "test"
val result = filter.testAppend(insertText, existingText)
assertThat(result).isNull()
}
@Test
fun `filter - exact byte limit, returns null`() {
val filter = ByteLimitInputFilter(6)
val dest = SpannedString("hi")
val insertText = "test"
val result = filter.testAppend(insertText, dest)
assertThat(result).isNull()
}
@Test
fun `filter - exceeds byte limit, returns truncated`() {
val filter = ByteLimitInputFilter(5)
val dest = SpannedString("hi")
val insertText = "test"
val result = filter.testAppend(insertText, dest)
assertThat(result.toString()).isEqualTo("tes")
}
@Test
fun `filter - no space available, returns empty`() {
val filter = ByteLimitInputFilter(2)
val dest = SpannedString("hi")
val insertText = "test"
val result = filter.testAppend(insertText, dest)
assertThat(result.toString()).isEqualTo("")
}
@Test
fun `filter - insert at beginning`() {
val filter = ByteLimitInputFilter(6)
val dest = SpannedString("hi")
val insertText = "test"
val result = filter.testPrepend(insertText, dest)
assertThat(result).isNull()
}
@Test
fun `filter - insert at end`() {
val filter = ByteLimitInputFilter(6)
val dest = SpannedString("hi")
val insertText = "test"
val result = filter.testAppend(insertText, dest)
assertThat(result).isNull()
}
@Test
fun `filter - replace text`() {
val filter = ByteLimitInputFilter(6)
val dest = SpannedString("hello")
val insertText = "test"
val result = filter.testReplaceRange(insertText, dest, 1, 4)
assertThat(result).isNull()
}
@Test
fun `filter - unicode characters`() {
val filter = ByteLimitInputFilter(9)
val dest = SpannedString("hi")
val insertText = "café"
val result = filter.testAppend(insertText, dest)
assertThat(result).isNull()
}
@Test
fun `filter - emoji characters`() {
val filter = ByteLimitInputFilter(6)
val dest = SpannedString("hi")
val insertText = "😀😁"
assertThat((insertText + dest).utf8Size()).isGreaterThan(6)
val result = filter.testAppend(insertText, dest)
assertThat(result.toString()).isEqualTo("😀")
}
@Test
fun `filter - mixed unicode and emoji`() {
val filter = ByteLimitInputFilter(15)
val dest = SpannedString("test")
val insertText = "café😀"
val result = filter.testAppend(insertText, dest)
assertThat(result).isNull()
}
@Test
fun `filter - partial source range`() {
val filter = ByteLimitInputFilter(5)
val dest = SpannedString("hi")
val source = "abcdef"
val result = filter.testPartialSource(source, 1, 4, dest, dest.length)
assertThat(result).isNull()
}
@Test
fun `filter - long text truncation`() {
val filter = ByteLimitInputFilter(10)
val dest = SpannedString("")
val longText = "this is a very long text that should be truncated"
val result = filter.testAppend(longText, dest)
assertThat(result.toString()).isEqualTo("this is a ")
}
@Test
fun `filter - ascii characters`() {
val filter = ByteLimitInputFilter(5)
val dest = SpannedString("")
val insertText = "hello"
val result = filter.testAppend(insertText, dest)
assertThat(result).isNull()
}
@Test
fun `filter - surrogate handling`() {
val filter = ByteLimitInputFilter(8)
val dest = SpannedString("hi")
val insertText = "🎉🎊"
val result = filter.testAppend(insertText, dest)
assertThat(result.toString()).isEqualTo("🎉")
}
@Test
fun `filter - empty source`() {
val filter = ByteLimitInputFilter(10)
val dest = SpannedString("test")
val insertText = ""
val result = filter.testInsertAt(insertText, dest, 2)
assertThat(result).isNull()
}
@Test
fun `filter - empty dest`() {
val filter = ByteLimitInputFilter(3)
val dest = SpannedString("")
val insertText = "test"
val result = filter.testAppend(insertText, dest)
assertThat(result.toString()).isEqualTo("tes")
}
@Test
fun `filter - unicode truncation`() {
val filter = ByteLimitInputFilter(4)
val dest = SpannedString("")
val insertText = "café"
val result = filter.testAppend(insertText, dest)
assertThat(result.toString()).isEqualTo("caf")
}
@Test
fun `filter - emoji truncation`() {
val filter = ByteLimitInputFilter(4)
val dest = SpannedString("")
val insertText = "😀a"
val result = filter.testAppend(insertText, dest)
assertThat(result.toString()).isEqualTo("😀")
}
@Test
fun `filter - insert at middle`() {
val filter = ByteLimitInputFilter(7)
val dest = SpannedString("hello")
val insertText = "XY"
val result = filter.testInsertAt(insertText, dest, 2)
assertThat(result).isNull()
}
@Test
fun `filter - insert at middle with truncation`() {
val filter = ByteLimitInputFilter(6)
val dest = SpannedString("hello")
val insertText = "XYZ"
val result = filter.testInsertAt(insertText, dest, 2)
assertThat(result.toString()).isEqualTo("X")
}
@Test
fun `textView integration - append within limit`() {
val textView = TextView(RuntimeEnvironment.getApplication())
textView.filters = arrayOf(ByteLimitInputFilter(10))
textView.setText("hi", TextView.BufferType.EDITABLE)
textView.append("test")
assertThat(textView.text.toString()).isEqualTo("hitest")
}
@Test
fun `textView integration - append exceeds limit`() {
val textView = TextView(RuntimeEnvironment.getApplication())
textView.filters = arrayOf(ByteLimitInputFilter(5))
textView.setText("hi", TextView.BufferType.EDITABLE)
textView.append("test")
assertThat(textView.text.toString()).isEqualTo("hites")
}
@Test
fun `textView integration - replace text with truncation`() {
val textView = TextView(RuntimeEnvironment.getApplication())
textView.filters = arrayOf(ByteLimitInputFilter(8))
textView.setText("hello", TextView.BufferType.EDITABLE)
val editable = textView.editableText
editable.replace(3, 5, "test")
assertThat(textView.text.toString()).isEqualTo("heltest")
}
@Test
fun `textView integration - emoji handling`() {
val textView = TextView(RuntimeEnvironment.getApplication())
textView.filters = arrayOf(ByteLimitInputFilter(10))
textView.setText("hi", TextView.BufferType.EDITABLE)
textView.append("😀😁")
assertThat(textView.text.toString().utf8Size()).isEqualTo(10)
}
@Test
fun `textView integration - unicode characters`() {
val textView = TextView(RuntimeEnvironment.getApplication())
textView.filters = arrayOf(ByteLimitInputFilter(10))
textView.setText("hi", TextView.BufferType.EDITABLE)
textView.append("café")
assertThat(textView.text.toString()).isEqualTo("hicafé")
}
@Test
fun `textView integration - set text directly`() {
val textView = TextView(RuntimeEnvironment.getApplication())
textView.filters = arrayOf(ByteLimitInputFilter(5))
textView.setText("this is a long text", TextView.BufferType.EDITABLE)
assertThat(textView.text.toString()).isEqualTo("this ")
}
@Test
fun `textView integration - fuzzing with mixed character types`() {
val textView = TextView(RuntimeEnvironment.getApplication())
val byteLimit = 100
textView.filters = arrayOf(ByteLimitInputFilter(byteLimit))
val asciiChars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789!@#$%^&*()_+-=[]{}|;:,.<>?"
val unicodeChars = "àáâãäåæçèéêëìíîïñòóôõöøùúûüýÿ"
val emojiChars = "😀😁😂😃😄😅😆😇😈😉😊😋😌😍😎😏😐😑😒😓😔😕😖😗😘😙😚😛😜😝😞😟😠😡😢😣😤😥😦😧😨😩😪😫😬😭😮😯😰😱😲😳😴😵😶😷😸😹😺😻😼😽😾😿🙀🙁🙂"
val japaneseChars = "あいうえおかきくけこさしすせそたちつてとなにぬねのはひふへほまみむめもやゆよらりるれろわをんアイウエオカキクケコサシスセソタチツテトナニヌネノハヒフヘホマミムメモヤユヨラリルレロワヲン日本語漢字平仮名片仮名"
val allChars = asciiChars + unicodeChars + emojiChars + japaneseChars
repeat(100) { iteration ->
textView.setText("", TextView.BufferType.EDITABLE)
val targetLength = 150 + (iteration * 5)
val randomText = StringBuilder().apply {
repeat(targetLength) {
append(allChars.random())
}
}
textView.setText(randomText.toString(), TextView.BufferType.EDITABLE)
val finalText = textView.text.toString()
val actualByteSize = finalText.utf8Size()
assertThat(actualByteSize).isLessThanOrEqualTo((byteLimit).toLong())
if (randomText.toString().utf8Size() > byteLimit) {
assertThat(finalText.length).isLessThan(randomText.length)
}
}
}
private fun ByteLimitInputFilter.testAppend(insertText: String, dest: SpannedString): CharSequence? {
return this.filter(insertText, 0, insertText.length, dest, dest.length, dest.length)
}
private fun ByteLimitInputFilter.testPrepend(insertText: String, dest: SpannedString): CharSequence? {
return this.filter(insertText, 0, insertText.length, dest, 0, 0)
}
private fun ByteLimitInputFilter.testInsertAt(insertText: String, dest: SpannedString, position: Int): CharSequence? {
return this.filter(insertText, 0, insertText.length, dest, position, position)
}
private fun ByteLimitInputFilter.testReplaceRange(insertText: String, dest: SpannedString, startPos: Int, endPos: Int): CharSequence? {
return this.filter(insertText, 0, insertText.length, dest, startPos, endPos)
}
private fun ByteLimitInputFilter.testPartialSource(source: String, startPos: Int, endPos: Int, dest: SpannedString, insertPos: Int): CharSequence? {
return this.filter(source, startPos, endPos, dest, insertPos, insertPos)
}
}

View file

@ -0,0 +1,34 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util
import androidx.sqlite.db.SupportSQLiteDatabase
import androidx.sqlite.db.SupportSQLiteOpenHelper
import androidx.sqlite.db.framework.FrameworkSQLiteOpenHelperFactory
import androidx.test.core.app.ApplicationProvider
/**
* Helper to create an in-memory database used for testing SQLite stuff.
*/
object InMemorySqliteOpenHelper {
fun create(
onCreate: (db: SupportSQLiteDatabase) -> Unit,
onUpgrade: (db: SupportSQLiteDatabase, oldVersion: Int, newVersion: Int) -> Unit = { _, _, _ -> }
): SupportSQLiteOpenHelper {
val configuration = SupportSQLiteOpenHelper.Configuration(
context = ApplicationProvider.getApplicationContext(),
name = "test",
callback = object : SupportSQLiteOpenHelper.Callback(1) {
override fun onCreate(db: SupportSQLiteDatabase) = onCreate(db)
override fun onUpgrade(db: SupportSQLiteDatabase, oldVersion: Int, newVersion: Int) = onUpgrade(db, oldVersion, newVersion)
},
useNoBackupDirectory = false,
allowDataLossOnRecovery = true
)
return FrameworkSQLiteOpenHelperFactory().create(configuration)
}
}

View file

@ -0,0 +1,38 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util
import org.junit.Assert.assertEquals
import org.junit.Test
import kotlin.random.Random
class InputStreamExtensionTests {
@Test
fun `when I call readLength, it returns the correct length`() {
for (i in 1..10) {
val bytes = ByteArray(Random.nextInt(from = 512, until = 8092))
val length = bytes.inputStream().readLength()
assertEquals(bytes.size.toLong(), length)
}
}
@Test
fun `when I call readAtMostNBytes, I only read that many bytes`() {
val bytes = ByteArray(100)
val inputStream = bytes.inputStream()
val readBytes = inputStream.readAtMostNBytes(50)
assertEquals(50, readBytes.size)
}
@Test
fun `when I call readAtMostNBytes, it will return at most the length of the stream`() {
val bytes = ByteArray(100)
val inputStream = bytes.inputStream()
val readBytes = inputStream.readAtMostNBytes(200)
assertEquals(100, readBytes.size)
}
}

View file

@ -0,0 +1,47 @@
package org.signal.core.util
import org.junit.Assert.assertEquals
import org.junit.Test
class ListUtilTest {
@Test
fun chunk_oneChunk() {
val input = listOf("A", "B", "C")
var output = ListUtil.chunk(input, 3)
assertEquals(1, output.size)
assertEquals(input, output[0])
output = ListUtil.chunk(input, 4)
assertEquals(1, output.size)
assertEquals(input, output[0])
output = ListUtil.chunk(input, 100)
assertEquals(1, output.size)
assertEquals(input, output[0])
}
@Test
fun chunk_multipleChunks() {
val input: List<String> = listOf("A", "B", "C", "D", "E")
var output = ListUtil.chunk(input, 4)
assertEquals(2, output.size)
assertEquals(listOf("A", "B", "C", "D"), output[0])
assertEquals(listOf("E"), output[1])
output = ListUtil.chunk(input, 2)
assertEquals(3, output.size)
assertEquals(listOf("A", "B"), output[0])
assertEquals(listOf("C", "D"), output[1])
assertEquals(listOf("E"), output[2])
output = ListUtil.chunk(input, 1)
assertEquals(5, output.size)
assertEquals(listOf("A"), output[0])
assertEquals(listOf("B"), output[1])
assertEquals(listOf("C"), output[2])
assertEquals(listOf("D"), output[3])
assertEquals(listOf("E"), output[4])
}
}

View file

@ -0,0 +1,124 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util
import android.app.Application
import androidx.sqlite.db.SupportSQLiteOpenHelper
import assertk.assertThat
import assertk.assertions.isEqualTo
import assertk.assertions.isNotNull
import org.junit.After
import org.junit.Assert.assertArrayEquals
import org.junit.Before
import org.junit.Test
import org.junit.runner.RunWith
import org.robolectric.RobolectricTestRunner
import org.robolectric.annotation.Config
@RunWith(RobolectricTestRunner::class)
@Config(manifest = Config.NONE, application = Application::class)
class SQLiteDatabaseExtensionsTest {
lateinit var db: SupportSQLiteOpenHelper
companion object {
const val TABLE_NAME = "test"
const val ID = "_id"
const val STRING_COLUMN = "string_column"
const val LONG_COLUMN = "long_column"
const val DOUBLE_COLUMN = "double_column"
const val BLOB_COLUMN = "blob_column"
}
@Before
fun setup() {
db = InMemorySqliteOpenHelper.create(
onCreate = { db ->
db.execSQL("CREATE TABLE $TABLE_NAME ($ID INTEGER PRIMARY KEY AUTOINCREMENT, $STRING_COLUMN TEXT, $LONG_COLUMN INTEGER, $DOUBLE_COLUMN DOUBLE, $BLOB_COLUMN BLOB)")
}
)
db.writableDatabase.insertInto(TABLE_NAME)
.values(
STRING_COLUMN to "asdf",
LONG_COLUMN to 1,
DOUBLE_COLUMN to 0.5f,
BLOB_COLUMN to byteArrayOf(1, 2, 3)
)
.run()
}
@After
fun cleanUp() {
db.close()
}
@Test
fun `update - content values work`() {
val updateCount: Int = db.writableDatabase
.update("test")
.values(
STRING_COLUMN to "asdf2",
LONG_COLUMN to 2,
DOUBLE_COLUMN to 1.5f,
BLOB_COLUMN to byteArrayOf(4, 5, 6)
)
.where("$ID = ?", 1)
.run()
val record = readRecord(1)
assertThat(updateCount).isEqualTo(1)
assertThat(record).isNotNull()
assertThat(record!!.id).isEqualTo(1)
assertThat(record.stringColumn).isEqualTo("asdf2")
assertThat(record.longColumn).isEqualTo(2)
assertThat(record.doubleColumn).isEqualTo(1.5f)
assertArrayEquals(record.blobColumn, byteArrayOf(4, 5, 6))
}
@Test
fun `update - querying by blob works`() {
val updateCount: Int = db.writableDatabase
.update("test")
.values(
STRING_COLUMN to "asdf2"
)
.where("$BLOB_COLUMN = ?", byteArrayOf(1, 2, 3))
.run()
val record = readRecord(1)
assertThat(updateCount).isEqualTo(1)
assertThat(record).isNotNull()
assertThat(record!!.stringColumn).isEqualTo("asdf2")
}
private fun readRecord(id: Long): TestRecord? {
return db.readableDatabase
.select()
.from(TABLE_NAME)
.where("$ID = ?", id)
.run()
.readToSingleObject {
TestRecord(
id = it.requireLong(ID),
stringColumn = it.requireString(STRING_COLUMN),
longColumn = it.requireLong(LONG_COLUMN),
doubleColumn = it.requireFloat(DOUBLE_COLUMN),
blobColumn = it.requireBlob(BLOB_COLUMN)
)
}
}
class TestRecord(
val id: Long,
val stringColumn: String?,
val longColumn: Long,
val doubleColumn: Float,
val blobColumn: ByteArray?
)
}

View file

@ -0,0 +1,349 @@
package org.signal.core.util;
import android.app.Application;
import android.content.ContentValues;
import androidx.annotation.NonNull;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.robolectric.RobolectricTestRunner;
import org.robolectric.annotation.Config;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
@RunWith(RobolectricTestRunner.class)
@Config(manifest = Config.NONE, application = Application.class)
public final class SqlUtilTest {
@Test
public void buildTrueUpdateQuery_simple() {
String selection = "_id = ?";
String[] args = new String[]{"1"};
ContentValues values = new ContentValues();
values.put("a", 2);
SqlUtil.Query updateQuery = SqlUtil.buildTrueUpdateQuery(selection, args, values);
assertEquals("(_id = ?) AND (a != ? OR a IS NULL)", updateQuery.getWhere());
assertArrayEquals(new String[] { "1", "2" }, updateQuery.getWhereArgs());
}
@Test
public void buildTrueUpdateQuery_complexSelection() {
String selection = "_id = ? AND (foo = ? OR bar != ?)";
String[] args = new String[]{"1", "2", "3"};
ContentValues values = new ContentValues();
values.put("a", 4);
SqlUtil.Query updateQuery = SqlUtil.buildTrueUpdateQuery(selection, args, values);
assertEquals("(_id = ? AND (foo = ? OR bar != ?)) AND (a != ? OR a IS NULL)", updateQuery.getWhere());
assertArrayEquals(new String[] { "1", "2", "3", "4" }, updateQuery.getWhereArgs());
}
@Test
public void buildTrueUpdateQuery_multipleContentValues() {
String selection = "_id = ?";
String[] args = new String[]{"1"};
ContentValues values = new ContentValues();
values.put("a", 2);
values.put("b", 3);
values.put("c", 4);
SqlUtil.Query updateQuery = SqlUtil.buildTrueUpdateQuery(selection, args, values);
assertEquals("(_id = ?) AND (a != ? OR a IS NULL OR b != ? OR b IS NULL OR c != ? OR c IS NULL)", updateQuery.getWhere());
assertArrayEquals(new String[] { "1", "2", "3", "4"}, updateQuery.getWhereArgs());
}
@Test
public void buildTrueUpdateQuery_nullContentValue() {
String selection = "_id = ?";
String[] args = new String[]{"1"};
ContentValues values = new ContentValues();
values.put("a", (String) null);
SqlUtil.Query updateQuery = SqlUtil.buildTrueUpdateQuery(selection, args, values);
assertEquals("(_id = ?) AND (a NOT NULL)", updateQuery.getWhere());
assertArrayEquals(new String[] { "1" }, updateQuery.getWhereArgs());
}
@Test
public void buildTrueUpdateQuery_complexContentValue() {
String selection = "_id = ?";
String[] args = new String[]{"1"};
ContentValues values = new ContentValues();
values.put("a", (String) null);
values.put("b", 2);
values.put("c", 3);
values.put("d", (String) null);
values.put("e", (String) null);
SqlUtil.Query updateQuery = SqlUtil.buildTrueUpdateQuery(selection, args, values);
assertEquals("(_id = ?) AND (a NOT NULL OR b != ? OR b IS NULL OR c != ? OR c IS NULL OR d NOT NULL OR e NOT NULL)", updateQuery.getWhere());
assertArrayEquals(new String[] { "1", "2", "3" }, updateQuery.getWhereArgs());
}
@Test
public void buildTrueUpdateQuery_blobComplex() {
String selection = "_id = ?";
String[] args = new String[]{"1"};
ContentValues values = new ContentValues();
values.put("a", hexToBytes("FF"));
values.put("b", 2);
values.putNull("c");
SqlUtil.Query updateQuery = SqlUtil.buildTrueUpdateQuery(selection, args, values);
assertEquals("(_id = ?) AND (hex(a) != ? OR a IS NULL OR b != ? OR b IS NULL OR c NOT NULL)", updateQuery.getWhere());
assertArrayEquals(new String[] { "1", "FF", "2" }, updateQuery.getWhereArgs());
}
@Test
public void buildCollectionQuery_single() {
List<SqlUtil.Query> updateQuery = SqlUtil.buildCollectionQuery("a", Arrays.asList(1));
assertEquals(1, updateQuery.size());
assertEquals("a IN (?)", updateQuery.get(0).getWhere());
assertArrayEquals(new String[] { "1" }, updateQuery.get(0).getWhereArgs());
}
@Test
public void buildCollectionQuery_single_withPrefix() {
List<SqlUtil.Query> updateQuery = SqlUtil.buildCollectionQuery("a", Arrays.asList(1), "b = 1 AND");
assertEquals(1, updateQuery.size());
assertEquals("b = 1 AND a IN (?)", updateQuery.get(0).getWhere());
assertArrayEquals(new String[] { "1" }, updateQuery.get(0).getWhereArgs());
}
@Test
public void buildCollectionQuery_multiple() {
List<SqlUtil.Query> updateQuery = SqlUtil.buildCollectionQuery("a", Arrays.asList(1, 2, 3));
assertEquals(1, updateQuery.size());
assertEquals("a IN (?, ?, ?)", updateQuery.get(0).getWhere());
assertArrayEquals(new String[] { "1", "2", "3" }, updateQuery.get(0).getWhereArgs());
}
@Test
public void buildCollectionQuery_multiple_twoBatches() {
List<SqlUtil.Query> updateQuery = SqlUtil.buildCollectionQuery("a", Arrays.asList(1, 2, 3), "", 2);
assertEquals(2, updateQuery.size());
assertEquals("a IN (?, ?)", updateQuery.get(0).getWhere());
assertArrayEquals(new String[] { "1", "2" }, updateQuery.get(0).getWhereArgs());
assertEquals("a IN (?)", updateQuery.get(1).getWhere());
assertArrayEquals(new String[] { "3" }, updateQuery.get(1).getWhereArgs());
}
@Test
public void buildCollectionQuery_multipleRecipientIds() {
List<SqlUtil.Query> updateQuery = SqlUtil.buildCollectionQuery("a", Arrays.asList(new TestId(1), new TestId(2), new TestId(3)));
assertEquals(1, updateQuery.size());
assertEquals("a IN (?, ?, ?)", updateQuery.get(0).getWhere());
assertArrayEquals(new String[] { "1", "2", "3" }, updateQuery.get(0).getWhereArgs());
}
public void buildCollectionQuery_none() {
List<SqlUtil.Query> results = SqlUtil.buildCollectionQuery("a", Collections.emptyList());
assertTrue(results.isEmpty());
}
@Test
public void buildFastCollectionQuery_single() {
SqlUtil.Query updateQuery = SqlUtil.buildFastCollectionQuery("a", Arrays.asList(1));
assertEquals("a IN (SELECT e.value FROM json_each(?) e)", updateQuery.getWhere());
assertArrayEquals(new String[] { "[\"1\"]" }, updateQuery.getWhereArgs());
}
@Test
public void buildFastCollectionQuery_multiple() {
SqlUtil.Query updateQuery = SqlUtil.buildFastCollectionQuery("a", Arrays.asList(1, 2, 3));
assertEquals("a IN (SELECT e.value FROM json_each(?) e)", updateQuery.getWhere());
assertArrayEquals(new String[] { "[\"1\",\"2\",\"3\"]" }, updateQuery.getWhereArgs());
}
@Test
public void buildCustomCollectionQuery_single_singleBatch() {
List<String[]> args = new ArrayList<>();
args.add(SqlUtil.buildArgs(1, 2));
List<SqlUtil.Query> queries = SqlUtil.buildCustomCollectionQuery("a = ? AND b = ?", args);
assertEquals(1, queries.size());
assertEquals("(a = ? AND b = ?)", queries.get(0).getWhere());
assertArrayEquals(new String[] { "1", "2" }, queries.get(0).getWhereArgs());
}
@Test
public void buildCustomCollectionQuery_multiple_singleBatch() {
List<String[]> args = new ArrayList<>();
args.add(SqlUtil.buildArgs(1, 2));
args.add(SqlUtil.buildArgs(3, 4));
args.add(SqlUtil.buildArgs(5, 6));
List<SqlUtil.Query> queries = SqlUtil.buildCustomCollectionQuery("a = ? AND b = ?", args);
assertEquals(1, queries.size());
assertEquals("(a = ? AND b = ?) OR (a = ? AND b = ?) OR (a = ? AND b = ?)", queries.get(0).getWhere());
assertArrayEquals(new String[] { "1", "2", "3", "4", "5", "6" }, queries.get(0).getWhereArgs());
}
@Test
public void buildCustomCollectionQuery_twoBatches() {
List<String[]> args = new ArrayList<>();
args.add(SqlUtil.buildArgs(1, 2));
args.add(SqlUtil.buildArgs(3, 4));
args.add(SqlUtil.buildArgs(5, 6));
List<SqlUtil.Query> queries = SqlUtil.buildCustomCollectionQuery("a = ? AND b = ?", args, 4);
assertEquals(2, queries.size());
assertEquals("(a = ? AND b = ?) OR (a = ? AND b = ?)", queries.get(0).getWhere());
assertArrayEquals(new String[] { "1", "2", "3", "4" }, queries.get(0).getWhereArgs());
assertEquals("(a = ? AND b = ?)", queries.get(1).getWhere());
assertArrayEquals(new String[] { "5", "6" }, queries.get(1).getWhereArgs());
}
@Test
public void buildBulkInsert_single_singleBatch() {
List<ContentValues> contentValues = new ArrayList<>();
ContentValues cv1 = new ContentValues();
cv1.put("a", 1);
cv1.put("b", 2);
contentValues.add(cv1);
List<SqlUtil.Query> output = SqlUtil.buildBulkInsert("mytable", new String[] { "a", "b"}, contentValues, null);
assertEquals(1, output.size());
assertEquals("INSERT INTO mytable (a, b) VALUES (?, ?)", output.get(0).getWhere());
assertArrayEquals(new String[] { "1", "2" }, output.get(0).getWhereArgs());
}
@Test
public void buildBulkInsert_single_singleBatch_containsNulls() {
List<ContentValues> contentValues = new ArrayList<>();
ContentValues cv1 = new ContentValues();
cv1.put("a", 1);
cv1.put("b", 2);
cv1.put("c", (String) null);
contentValues.add(cv1);
List<SqlUtil.Query> output = SqlUtil.buildBulkInsert("mytable", new String[] { "a", "b", "c"}, contentValues, null);
assertEquals(1, output.size());
assertEquals("INSERT INTO mytable (a, b, c) VALUES (?, ?, null)", output.get(0).getWhere());
assertArrayEquals(new String[] { "1", "2" }, output.get(0).getWhereArgs());
}
@Test
public void buildBulkInsert_multiple_singleBatch() {
List<ContentValues> contentValues = new ArrayList<>();
ContentValues cv1 = new ContentValues();
cv1.put("a", 1);
cv1.put("b", 2);
ContentValues cv2 = new ContentValues();
cv2.put("a", 3);
cv2.put("b", 4);
contentValues.add(cv1);
contentValues.add(cv2);
List<SqlUtil.Query> output = SqlUtil.buildBulkInsert("mytable", new String[] { "a", "b"}, contentValues, null);
assertEquals(1, output.size());
assertEquals("INSERT INTO mytable (a, b) VALUES (?, ?), (?, ?)", output.get(0).getWhere());
assertArrayEquals(new String[] { "1", "2", "3", "4" }, output.get(0).getWhereArgs());
}
@Test
public void buildBulkInsert_twoBatches() {
List<ContentValues> contentValues = new ArrayList<>();
ContentValues cv1 = new ContentValues();
cv1.put("a", 1);
cv1.put("b", 2);
ContentValues cv2 = new ContentValues();
cv2.put("a", 3);
cv2.put("b", 4);
ContentValues cv3 = new ContentValues();
cv3.put("a", 5);
cv3.put("b", 6);
contentValues.add(cv1);
contentValues.add(cv2);
contentValues.add(cv3);
List<SqlUtil.Query> output = SqlUtil.buildBulkInsert("mytable", new String[] { "a", "b"}, contentValues, 4, null);
assertEquals(2, output.size());
assertEquals("INSERT INTO mytable (a, b) VALUES (?, ?), (?, ?)", output.get(0).getWhere());
assertArrayEquals(new String[] { "1", "2", "3", "4" }, output.get(0).getWhereArgs());
assertEquals("INSERT INTO mytable (a, b) VALUES (?, ?)", output.get(1).getWhere());
assertArrayEquals(new String[] { "5", "6" }, output.get(1).getWhereArgs());
}
@Test
public void aggregateQueries() {
SqlUtil.Query q1 = SqlUtil.buildQuery("a = ?", 1);
SqlUtil.Query q2 = SqlUtil.buildQuery("b = ?", 2);
SqlUtil.Query q3 = q1.and(q2);
assertEquals("(a = ?) AND (b = ?)", q3.getWhere());
assertArrayEquals(new String[]{"1", "2"}, q3.getWhereArgs());
}
private static byte[] hexToBytes(String hex) {
try {
return Hex.fromStringCondensed(hex);
} catch (IOException e) {
throw new AssertionError(e);
}
}
private static class TestId implements DatabaseId {
private final long id;
private TestId(long id) {
this.id = id;
}
@Override
public @NonNull String serialize() {
return String.valueOf(id);
}
}
}

View file

@ -0,0 +1,80 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util
import okio.utf8Size
import org.junit.Assert.assertTrue
import org.junit.Test
class StringExtensionsTest {
@Test
fun `splitByByteLength - fuzzing`() {
val characterSet = "日月木山川田水火金土空海花風雨雪星森犬猫鳥魚虫人子女男友学校車電話本書時分先生愛夢楽音話語映画新古長短高低東西南北春夏秋冬雨雲星夜朝昼電気手足目耳口心頭体家国町村道橋山川本店仕事時間会話思考知識感情自動車飛行機船馬牛羊豚鶏鳥猫犬虎龍"
for (stringSize in 2100..2500) {
for (byteLimit in 2000..2500) {
val builder = StringBuilder()
repeat(stringSize) {
builder.append(characterSet.random())
}
val (trimmed, _) = builder.toString().splitByByteLength(byteLimit)
assertTrue(trimmed.utf8Size() <= byteLimit)
}
}
}
@Test
fun `splitByByteLength - long string`() {
val myString = """
すべての人間は生まれながらにして自由であり尊厳と権利において平等である彼らは理性と良心を授けられており互いに兄弟愛の精神をもって行動しなければならない
一方現代社会において技術の進歩は人々の生活を大きく変えつつある情報の流通は瞬時に行われ距離や時間の壁を越えて人々がつながる世界が現実となっているしかしこれに伴い新たな課題も生じており個人のプライバシーや倫理の問題が議論の中心となっている
経済の発展と共に都市化が進む一方で自然環境の破壊も深刻な問題となっている持続可能な社会を目指すためには私たち一人ひとりが責任を持ち資源を大切にする意識を持つことが重要である
文化や伝統は社会の根底を支えるものであり時代が変わってもその価値は変わらない古くから伝わる知恵や習慣は現代においても新たな意味を持ち続けるだろう
すべての人間は生まれながらにして自由であり尊厳と権利において平等である彼らは理性と良心を授けられており互いに兄弟愛の精神をもって行動しなければならない
一方現代社会において技術の進歩は人々の生活を大きく変えつつある情報の流通は瞬時に行われ距離や時間の壁を越えて人々がつながる世界が現実となっているしかしこれに伴い新たな課題も生じており個人のプライバシーや倫理の問題が議論の中心となっている
経済の発展と共に都市化が進む一方で自然環境の破壊も深刻な問題となっている持続可能な社会を目指すためには私たち一人ひとりが責任を持ち資源を大切にする意識を持つことが重要である
文化や伝統は社会の根底を支えるものであり時代が変わってもその価値は変わらない古くから伝わる知恵や習慣は現代においても新たな意味を持ち続けるだろう
すべての人間は生まれながらにして自由であり尊厳と権利において平等である彼らは理性と良心を授けられており互いに兄弟愛の精神をもって行動しなければならない
一方現代社会において技術の進歩は人々の生活を大きく変えつつある情報の流通は瞬時に行われ距離や時間の壁を越えて人々がつながる世界が現実となっているしかしこれに伴い新たな課題も生じており個人のプライバシーや倫理の問題が議論の中心となっている
経済の発展と共に都市化が進む一方で自然環境の破壊も深刻な問題となっている持続可能な社会を目指すためには私たち一人ひとりが責任を持ち資源を大切にする意識を持つことが重要である
文化や伝統は社会の根底を支えるものであり時代が変わってもその価値は変わらない古くから伝わる知恵や習慣は現代においても新たな意味を持ち続けるだろう
すべての人間は生まれながらにして自由であり尊厳と権利において平等である彼らは理性と良心を授けられており互いに兄弟愛の精神をもって行動しなければならない
一方現代社会において技術の進歩は人々の生活を大きく変えつつある情報の流通は瞬時に行われ距離や時間の壁を越えて人々がつながる世界が現実となっているしかしこれに伴い新たな課題も生じており個人のプライバシーや倫理の問題が議論の中心となっている
経済の発展と共に都市化が進む一方で自然環境の破壊も深刻な問題となっている持続可能な社会を目指すためには私たち一人ひとりが責任を持ち資源を大切にする意識を持つことが重要である
文化や伝統は社会の根底を支えるものであり時代が変わってもその価値は変わらない古くから伝わる知恵や習慣は現代においても新たな意味を持ち続けるだろう
すべての人間は生まれながらにして自由であり尊厳と権利において平等である彼らは理性と良心を授けられており互いに兄弟愛の精神をもって行動しなければならない
一方現代社会において技術の進歩は人々の生活を大きく変えつつある情報の流通は瞬時に行われ距離や時間の壁を越えて人々がつながる世界が現実となっているしかしこれに伴い新たな課題も生じており個人のプライバシーや倫理の問題が議論の中心となっている
経済の発展と共に都市化が進む一方で自然環境の破壊も深刻な問題となっている持続可能な社会を目指すためには私たち一人ひとりが責任を持ち資源を大切にする意識を持つことが重要である
文化や伝統は社会の根底を支えるものであり時代が変わってもその価値は変わらない古くから伝わる知恵や習慣は現代においても新たな意味を持ち続けるだろう
すべての人間は生まれながらにして自由であり尊厳と権利において平等である彼らは理性と良心を授けられており互いに兄弟愛の精神をもって行動しなければならない
一方現代社会において技術の進歩は人々の生活を大きく変えつつある情報の流通は瞬時に行われ距離や時間の壁を越えて人々がつながる世界が現実となっているしかしこれに伴い新たな課題も生じており個人のプライバシーや倫理の問題が議論の中心となっている
経済の発展と共に都市化が進む一方で自然環境の破壊も深刻な問題となっている持続可能な社会を目指すためには私たち一人ひとりが責任を持ち資源を大切にする意識を持つことが重要である
文化や伝統は社会の根底を支えるものであり時代が変わってもその価値は変わらない古くから伝わる知恵や習慣は現代においても新たな意味を持ち続けるだろう
""".trimIndent()
val (trimmed, _) = myString.splitByByteLength(2048)
assertTrue(trimmed.utf8Size() <= 2048)
}
}

View file

@ -0,0 +1,40 @@
package org.signal.core.util
import org.junit.Assert.assertEquals
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
@Suppress("ClassName")
@RunWith(Parameterized::class)
class StringExtensions_asListContains(
private val model: String,
private val serializedList: String,
private val expected: Boolean
) {
@Test
fun testModelInList() {
val actual = serializedList.asListContains(model)
assertEquals(expected, actual)
}
companion object {
@JvmStatic
@Parameterized.Parameters(name = "{index}: modelInList(model={0}, list={1})={2}")
fun data(): List<Array<Any>> {
return listOf<Array<Any>>(
arrayOf("a", "a", true),
arrayOf("a", "a,b", true),
arrayOf("a", "c,a,b", true),
arrayOf("ab", "a*", true),
arrayOf("ab", "c,a*,b", true),
arrayOf("abc", "c,ab*,b", true),
arrayOf("a", "b", false),
arrayOf("a", "abc", false),
arrayOf("b", "a*", false)
).toList()
}
}
}

View file

@ -0,0 +1,42 @@
package org.signal.core.util
import okio.utf8Size
import org.junit.Assert.assertEquals
import org.junit.Assert.assertTrue
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
@Suppress("ClassName")
@RunWith(Parameterized::class)
class StringExtensions_splitByByteLength(
private val testInput: String,
private val byteLength: Int,
private val expected: Pair<String, String?>
) {
@Test
fun testModelInList() {
val actual = testInput.splitByByteLength(byteLength)
assertEquals(expected, actual)
assertTrue(actual.first.utf8Size() <= byteLength)
}
companion object {
@JvmStatic
@Parameterized.Parameters(name = "{index}: splitByByteLength(input={0}, byteLength={1})")
fun data(): List<Array<Any>> {
return listOf<Array<Any>>(
arrayOf("1234567890", 0, "" to "1234567890"),
arrayOf("1234567890", 3, "123" to "4567890"),
arrayOf("1234567890", 10, "1234567890" to null),
arrayOf("1234567890", 15, "1234567890" to null),
arrayOf("大いなる力には大いなる責任が伴う", 0, "" to "大いなる力には大いなる責任が伴う"),
arrayOf("大いなる力には大いなる責任が伴う", 8, "大い" to "なる力には大いなる責任が伴う"),
arrayOf("大いなる力には大いなる責任が伴う", 47, "大いなる力には大いなる責任が伴" to ""),
arrayOf("大いなる力には大いなる責任が伴う", 48, "大いなる力には大いなる責任が伴う" to null),
arrayOf("大いなる力には大いなる責任が伴う", 100, "大いなる力には大いなる責任が伴う" to null)
).toList()
}
}
}

View file

@ -0,0 +1,57 @@
package org.signal.core.util;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import java.util.Arrays;
import java.util.Collection;
import java.util.Objects;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertSame;
@RunWith(Parameterized.class)
public final class StringUtilTest_abbreviateInMiddle {
@Parameterized.Parameter(0)
public CharSequence input;
@Parameterized.Parameter(1)
public int maxChars;
@Parameterized.Parameter(2)
public CharSequence expected;
@Parameterized.Parameters
public static Collection<Object[]> data() {
return Arrays.asList(new Object[][]{
{null, 0, null},
{null, 1, null},
{"", 0, ""},
{"", 1, ""},
{"0123456789", 10, "0123456789"},
{"0123456789", 11, "0123456789"},
{"0123456789", 9, "0123…6789"},
{"0123456789", 8, "012…6789"},
{"0123456789", 7, "012…789"},
{"0123456789", 6, "01…789"},
{"0123456789", 5, "01…89"},
{"0123456789", 4, "0…89"},
{"0123456789", 3, "0…9"},
});
}
@Test
public void abbreviateInMiddle() {
CharSequence output = StringUtil.abbreviateInMiddle(input, maxChars);
assertEquals(expected, output);
if (Objects.equals(input, output)) {
assertSame(output, input);
} else {
assertNotNull(output);
assertEquals(maxChars, output.length());
}
}
}

View file

@ -0,0 +1,54 @@
@file:Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN")
package org.signal.core.util
import android.app.Application
import org.junit.Assert.assertEquals
import org.junit.Test
import org.junit.runner.RunWith
import org.robolectric.ParameterizedRobolectricTestRunner
import org.robolectric.ParameterizedRobolectricTestRunner.Parameter
import org.robolectric.ParameterizedRobolectricTestRunner.Parameters
import org.robolectric.annotation.Config
import java.lang.Boolean as JavaBoolean
@Suppress("ClassName")
@RunWith(value = ParameterizedRobolectricTestRunner::class)
@Config(manifest = Config.NONE, application = Application::class)
class StringUtilTest_endsWith {
@Parameter(0)
lateinit var text: CharSequence
@Parameter(1)
lateinit var substring: CharSequence
@Parameter(2)
lateinit var expected: JavaBoolean
companion object {
@JvmStatic
@Parameters
fun data(): Collection<Array<Any>> {
return listOf(
arrayOf("Text", "xt", true),
arrayOf("Text", "", true),
arrayOf("Text", "XT", false),
arrayOf("Text…", "xt…", true),
arrayOf("", "Te", false),
arrayOf("Text", "Text", true),
arrayOf("Text", "2Text", false),
arrayOf("\uD83D\uDC64Text", "Te", false),
arrayOf("Text text text\uD83D\uDC64", "\uD83D\uDC64", true),
arrayOf("Text\uD83D\uDC64Text", "\uD83D\uDC64Text", true)
)
}
}
@Test
fun replace() {
val result = StringUtil.endsWith(text, substring)
assertEquals(expected, result)
}
}

View file

@ -0,0 +1,54 @@
package org.signal.core.util;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import java.util.Arrays;
import java.util.Collection;
import static org.junit.Assert.assertEquals;
@RunWith(Parameterized.class)
public final class StringUtilTest_hasMixedTextDirection {
private final CharSequence input;
private final boolean expected;
@Parameterized.Parameters
public static Collection<Object[]> data() {
return Arrays.asList(new Object[][]{
{ "", false },
{ null, false },
{ "A", false},
{ "A.", false},
{ "'A'", false},
{ "A,", false},
{ "ة", false}, // Arabic
{ "", false}, // Arabic
{ "ی", false}, // Kurdish
{ "ی", false }, // Farsi
{ "و", false }, // Urdu
{ "ת", false }, // Hebrew
{ "ש", false }, // Yiddish
{ "", true }, // Arabic-ASCII
{ "A.ة", true }, // Arabic-ASCII
{ "یA", true }, // Kurdish-ASCII
{ "", true }, // Farsi-ASCII
{ "وA", true }, // Urdu-ASCII
{ "", true }, // Hebrew-ASCII
{ "שA", true }, // Yiddish-ASCII
});
}
public StringUtilTest_hasMixedTextDirection(CharSequence input, boolean expected) {
this.input = input;
this.expected = expected;
}
@Test
public void trim() {
boolean output = BidiUtil.hasMixedTextDirection(input);
assertEquals(expected, output);
}
}

View file

@ -0,0 +1,51 @@
package org.signal.core.util
import android.app.Application
import org.junit.Assert.assertEquals
import org.junit.Test
import org.junit.runner.RunWith
import org.robolectric.ParameterizedRobolectricTestRunner
import org.robolectric.ParameterizedRobolectricTestRunner.Parameter
import org.robolectric.ParameterizedRobolectricTestRunner.Parameters
import org.robolectric.annotation.Config
@Suppress("ClassName")
@RunWith(value = ParameterizedRobolectricTestRunner::class)
@Config(manifest = Config.NONE, application = Application::class)
class StringUtilTest_replace {
@Parameter(0)
lateinit var text: CharSequence
@Parameter(1)
lateinit var charToReplace: Character
@Parameter(2)
lateinit var replacement: String
@Parameter(3)
lateinit var expected: CharSequence
companion object {
@JvmStatic
@Parameters
fun data(): Collection<Array<Any>> {
return listOf(
arrayOf("Replace\nme", '\n', " ", "Replace me"),
arrayOf("Replace me", '\n', " ", "Replace me"),
arrayOf("\nReplace me", '\n', " ", " Replace me"),
arrayOf("Replace me\n", '\n', " ", "Replace me "),
arrayOf("Replace\n\nme", '\n', " ", "Replace me"),
arrayOf("Replace\nme\n", '\n', " ", "Replace me "),
arrayOf("\n\nReplace\n\nme\n", '\n', " ", " Replace me ")
)
}
}
@Test
fun replace() {
val result = StringUtil.replace(text, charToReplace.charValue(), replacement)
assertEquals(expected.toString(), result.toString())
}
}

View file

@ -0,0 +1,54 @@
@file:Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN")
package org.signal.core.util
import android.app.Application
import org.junit.Assert.assertEquals
import org.junit.Test
import org.junit.runner.RunWith
import org.robolectric.ParameterizedRobolectricTestRunner
import org.robolectric.ParameterizedRobolectricTestRunner.Parameter
import org.robolectric.ParameterizedRobolectricTestRunner.Parameters
import org.robolectric.annotation.Config
import java.lang.Boolean as JavaBoolean
@Suppress("ClassName")
@RunWith(value = ParameterizedRobolectricTestRunner::class)
@Config(manifest = Config.NONE, application = Application::class)
class StringUtilTest_startsWith {
@Parameter(0)
lateinit var text: CharSequence
@Parameter(1)
lateinit var substring: CharSequence
@Parameter(2)
lateinit var expected: JavaBoolean
companion object {
@JvmStatic
@Parameters
fun data(): Collection<Array<Any>> {
return listOf(
arrayOf("Text", "Te", true),
arrayOf("Text", "", true),
arrayOf("Text", "te", false),
arrayOf("…Text", "…Te", true),
arrayOf("", "Te", false),
arrayOf("Text", "Text", true),
arrayOf("Text", "Text2", false),
arrayOf("\uD83D\uDC64Text", "Te", false),
arrayOf("Text text text\uD83D\uDC64", "\uD83D\uDC64", false),
arrayOf("\uD83D\uDC64Text", "\uD83D\uDC64Te", true)
)
}
}
@Test
fun replace() {
val result = StringUtil.startsWith(text, substring)
assertEquals(expected, result)
}
}

View file

@ -0,0 +1,59 @@
package org.signal.core.util;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import java.util.Arrays;
import java.util.Collection;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.assertSame;
@RunWith(Parameterized.class)
public final class StringUtilTest_trim {
private final CharSequence input;
private final CharSequence expected;
private final boolean changed;
@Parameterized.Parameters
public static Collection<Object[]> data() {
return Arrays.asList(new Object[][]{
{ "", "", false },
{ " ", "", true },
{ " ", "", true },
{ "\n", "", true},
{ "\n\n\n", "", true },
{ "A", "A", false },
{ "A ", "A", true },
{ " A", "A", true },
{ " A ", "A", true },
{ "\nA\n", "A", true },
{ "A\n\n", "A", true },
{ "A\n\nB", "A\n\nB", false },
{ "A\n\nB ", "A\n\nB", true },
{ "A B", "A B", false },
});
}
public StringUtilTest_trim(CharSequence input, CharSequence expected, boolean changed) {
this.input = input;
this.expected = expected;
this.changed = changed;
}
@Test
public void trim() {
CharSequence output = StringUtil.trim(input);
assertEquals(expected, output);
if (changed) {
assertNotSame(output, input);
} else {
assertSame(output, input);
}
}
}

View file

@ -0,0 +1,228 @@
package org.signal.core.util;
import android.app.Application;
import android.os.Build;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.robolectric.RobolectricTestRunner;
import org.robolectric.annotation.Config;
import static org.junit.Assert.assertEquals;
import static org.junit.Assume.assumeTrue;
@RunWith(RobolectricTestRunner.class)
@Config(manifest = Config.NONE, application = Application.class)
public final class StringUtilTest_trimToFit {
@Test
public void testShortStringIsNotTrimmed() {
assertEquals("Test string", StringUtil.trimToFit("Test string", 32));
assertEquals("", StringUtil.trimToFit("", 32));
assertEquals("aaaBBBCCC", StringUtil.trimToFit("aaaBBBCCC", 9));
}
@Test
public void testNull() {
assertEquals("", StringUtil.trimToFit(null, 0));
assertEquals("", StringUtil.trimToFit(null, 1));
assertEquals("", StringUtil.trimToFit(null, 10));
}
@Test
public void testStringIsTrimmed() {
assertEquals("Test stri", StringUtil.trimToFit("Test string", 9));
assertEquals("aaaBBBCC", StringUtil.trimToFit("aaaBBBCCC", 8));
}
@Test
public void testStringWithControlCharsIsTrimmed() {
assertEquals("Test string\nwrap\r\nhere",
StringUtil.trimToFit("Test string\nwrap\r\nhere\tindent\n\n", 22));
}
@Test
public void testAccentedCharactersAreTrimmedCorrectly() {
assertEquals("", StringUtil.trimToFit("âëȋõṷ", 1));
assertEquals("â", StringUtil.trimToFit("âëȋõṷ", 2));
assertEquals("â", StringUtil.trimToFit("âëȋõṷ", 3));
assertEquals("âë", StringUtil.trimToFit("âëȋõṷ", 4));
assertEquals("The last characters take more than a byte in utf8 â",
StringUtil.trimToFit("The last characters take more than a byte in utf8 âëȋõṷ", 53));
assertEquals("un quinzième jour en jaune apr", StringUtil.trimToFit("un quinzième jour en jaune après son épopée de 2019", 32));
assertEquals("una vez se organizaron detrás l", StringUtil.trimToFit("una vez se organizaron detrás la ventaja nunca pasó de los 3 minutos.", 32));
}
@Test
public void testCombinedAccentsAreTrimmedAsACharacter() {
final String a = "a\u0302";
final String e = "e\u0308";
final String i = "i\u0311";
final String o = "o\u0303";
final String u = "u\u032d";
assertEquals("", StringUtil.trimToFit(a + e + i + o + u, 1));
assertEquals("", StringUtil.trimToFit(a + e + i + o + u, 2));
assertEquals(a, StringUtil.trimToFit(a + e + i + o + u, 3));
assertEquals(a, StringUtil.trimToFit(a + e + i + o + u, 4));
assertEquals(a, StringUtil.trimToFit(a + e + i + o + u, 5));
assertEquals(a + e, StringUtil.trimToFit(a + e + i + o + u, 6));
assertEquals("The last characters take more than a byte in utf8 " + a,
StringUtil.trimToFit("The last characters take more than a byte in utf8 " + a + e + i + o + u, 53));
assertEquals("un quinzie\u0300me jour en jaune apr", StringUtil.trimToFit("un quinzie\u0300me jour en jaune apre\u0300s son e\u0301pope\u0301e de 2019", 32));
assertEquals("una vez se organizaron detra\u0301s ", StringUtil.trimToFit("una vez se organizaron detra\u0301s la ventaja nunca paso\u0301 de los 3 minutos.", 32));
}
@Test
public void testCJKCharactersAreTrimmedCorrectly() {
final String shin = "\u4fe1";
final String signal = shin + "\u53f7";
final String _private = "\u79c1\u4eba";
final String messenger = "\u4fe1\u4f7f";
assertEquals("", StringUtil.trimToFit(signal, 1));
assertEquals("", StringUtil.trimToFit(signal, 2));
assertEquals(shin, StringUtil.trimToFit(signal, 3));
assertEquals(shin, StringUtil.trimToFit(signal, 4));
assertEquals(shin, StringUtil.trimToFit(signal, 5));
assertEquals(signal, StringUtil.trimToFit(signal, 6));
assertEquals(String.format("Signal %s Pr", signal),
StringUtil.trimToFit(String.format("Signal %s Private %s Messenger %s", signal, _private, messenger),
16));
}
@Test
public void testSurrogatePairsAreTrimmedCorrectly() {
final String sword = "\uD841\uDF4F";
assertEquals("", StringUtil.trimToFit(sword, 1));
assertEquals("", StringUtil.trimToFit(sword, 2));
assertEquals("", StringUtil.trimToFit(sword, 3));
assertEquals(sword, StringUtil.trimToFit(sword, 4));
final String so = "\ud869\uddf1";
final String go = "\ud869\ude1a";
assertEquals("", StringUtil.trimToFit(so + go, 1));
assertEquals("", StringUtil.trimToFit(so + go, 2));
assertEquals("", StringUtil.trimToFit(so + go, 3));
assertEquals(so, StringUtil.trimToFit(so + go, 4));
assertEquals(so, StringUtil.trimToFit(so + go, 5));
assertEquals(so, StringUtil.trimToFit(so + go, 6));
assertEquals(so, StringUtil.trimToFit(so + go, 7));
assertEquals(so + go, StringUtil.trimToFit(so + go, 8));
final String gClef = "\uD834\uDD1E";
final String fClef = "\uD834\uDD22";
assertEquals("", StringUtil.trimToFit(gClef + " " + fClef, 1));
assertEquals("", StringUtil.trimToFit(gClef + " " + fClef, 2));
assertEquals("", StringUtil.trimToFit(gClef + " " + fClef, 3));
assertEquals(gClef, StringUtil.trimToFit(gClef + " " + fClef, 4));
assertEquals(gClef + " ", StringUtil.trimToFit(gClef + " " + fClef, 5));
assertEquals(gClef + " ", StringUtil.trimToFit(gClef + " " + fClef, 6));
assertEquals(gClef + " ", StringUtil.trimToFit(gClef + " " + fClef, 7));
assertEquals(gClef + " ", StringUtil.trimToFit(gClef + " " + fClef, 8));
assertEquals(gClef + " " + fClef, StringUtil.trimToFit(gClef + " " + fClef, 9));
}
@Test
public void testSimpleEmojiTrimming() {
final String congrats = "\u3297";
assertEquals("", StringUtil.trimToFit(congrats, 1));
assertEquals("", StringUtil.trimToFit(congrats, 2));
assertEquals(congrats, StringUtil.trimToFit(congrats, 3));
final String eject = "\u23cf";
assertEquals("", StringUtil.trimToFit(eject, 1));
assertEquals("", StringUtil.trimToFit(eject, 2));
assertEquals(eject, StringUtil.trimToFit(eject, 3));
}
@Test
public void testEmojisSurrogatePairTrimming() {
final String grape = "🍇";
assertEquals("", StringUtil.trimToFit(grape, 1));
assertEquals("", StringUtil.trimToFit(grape, 2));
assertEquals("", StringUtil.trimToFit(grape, 3));
assertEquals(grape, StringUtil.trimToFit(grape, 4));
final String smile = "\uD83D\uDE42";
assertEquals("", StringUtil.trimToFit(smile, 1));
assertEquals("", StringUtil.trimToFit(smile, 2));
assertEquals("", StringUtil.trimToFit(smile, 3));
assertEquals(smile, StringUtil.trimToFit(smile, 4));
final String check = "\u2714"; // Simple emoji
assertEquals(check, StringUtil.trimToFit(check, 3));
final String secret = "\u3299"; // Simple emoji
assertEquals(secret, StringUtil.trimToFit(secret, 3));
final String phoneWithArrow = "\uD83D\uDCF2"; // Surrogate Pair emoji
assertEquals(phoneWithArrow, StringUtil.trimToFit(phoneWithArrow, 4));
assertEquals(phoneWithArrow + ":",
StringUtil.trimToFit(phoneWithArrow + ":" + secret + ", " + check, 7));
assertEquals(phoneWithArrow + ":" + secret,
StringUtil.trimToFit(phoneWithArrow + ":" + secret + ", " + check, 8));
assertEquals(phoneWithArrow + ":" + secret + ",",
StringUtil.trimToFit(phoneWithArrow + ":" + secret + ", " + check, 9));
assertEquals(phoneWithArrow + ":" + secret + ", ",
StringUtil.trimToFit(phoneWithArrow + ":" + secret + ", " + check, 10));
assertEquals(phoneWithArrow + ":" + secret + ", ",
StringUtil.trimToFit(phoneWithArrow + ":" + secret + ", " + check, 11));
assertEquals(phoneWithArrow + ":" + secret + ", ",
StringUtil.trimToFit(phoneWithArrow + ":" + secret + ", " + check, 12));
}
@Test
public void testGraphemeClusterTrimming1() {
assumeTrue(Build.VERSION.SDK_INT >= 24);
final String alphas = "AAAAABBBBBCCCCCDDDDDEEEEE";
final String wavingHand = "\uD83D\uDC4B";
final String mediumDark = "\uD83C\uDFFE";
assertEquals(alphas, StringUtil.trimToFit(alphas + wavingHand + mediumDark, 32));
assertEquals(alphas + wavingHand + mediumDark, StringUtil.trimToFit(alphas + wavingHand + mediumDark, 33));
final String pads = "abcdefghijklm";
final String frowningPerson = "\uD83D\uDE4D";
final String female = "\u200D\u2640\uFE0F";
assertEquals(pads + frowningPerson + female,
StringUtil.trimToFit(pads + frowningPerson + female, 26));
assertEquals(pads + "n",
StringUtil.trimToFit(pads + "n" + frowningPerson + female, 26));
final String pads1 = "abcdef";
final String mediumSkin = "\uD83C\uDFFD";
assertEquals(pads1 + frowningPerson + mediumSkin + female,
StringUtil.trimToFit(pads1 + frowningPerson + mediumSkin + female, 26));
assertEquals(pads1 + "g",
StringUtil.trimToFit(pads1 + "g" + frowningPerson + mediumSkin + female, 26));
}
@Test
public void testGraphemeClusterTrimming2() {
assumeTrue(Build.VERSION.SDK_INT >= 24);
final String woman = "\uD83D\uDC69";
final String mediumDarkSkin = "\uD83C\uDFFE";
final String joint = "\u200D";
final String hands = "\uD83E\uDD1D";
final String man = "\uD83D\uDC68";
final String lightSkin = "\uD83C\uDFFB";
assertEquals(woman + mediumDarkSkin + joint + hands + joint + man + lightSkin,
StringUtil.trimToFit(woman + mediumDarkSkin + joint + hands + joint + man + lightSkin, 26));
assertEquals("a",
StringUtil.trimToFit("a" + woman + mediumDarkSkin + joint + hands + joint + man + lightSkin, 26));
final String pads = "abcdefghijk";
final String wheelchair = "\uD83E\uDDBC";
assertEquals(pads + man + lightSkin + joint + wheelchair,
StringUtil.trimToFit(pads + man + lightSkin + joint + wheelchair, 26));
assertEquals(pads + "l",
StringUtil.trimToFit(pads + "l" + man + lightSkin + joint + wheelchair, 26));
final String girl = "\uD83D\uDC67";
final String boy = "\uD83D\uDC66";
assertEquals(man + mediumDarkSkin + joint + man + joint + girl + lightSkin + joint + boy,
StringUtil.trimToFit(man + mediumDarkSkin + joint + man + joint + girl + lightSkin + joint + boy, 33));
assertEquals("a",
StringUtil.trimToFit("a" + man + mediumDarkSkin + joint + man + joint + girl + lightSkin + joint + boy, 33));
}
}

View file

@ -0,0 +1,56 @@
package org.signal.core.util;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import java.util.Arrays;
import java.util.Collection;
import static org.junit.Assert.assertEquals;
@RunWith(Parameterized.class)
public final class StringUtilTest_whitespace_handling {
private final String input;
private final String expectedTrimmed;
private final boolean isVisuallyEmpty;
@Parameterized.Parameters
public static Collection<Object[]> data() {
return Arrays.asList(new Object[][]{
{ "", "", true },
{ " ", "", true },
{ "A", "A", false },
{ " B", "B", false },
{ "C ", "C", false },
/* Unicode whitespace */
{ "\u200E", "", true },
{ "\u200F", "", true },
{ "\u2007", "", true },
{ "\u200B", "", true },
{ "\u2800", "", true },
{ "\u2007\u200FA\tB\u200EC\u200E\u200F", "A\tB\u200EC", false },
});
}
public StringUtilTest_whitespace_handling(String input, String expectedTrimmed, boolean isVisuallyEmpty) {
this.input = input;
this.expectedTrimmed = expectedTrimmed;
this.isVisuallyEmpty = isVisuallyEmpty;
}
@Test
public void isVisuallyEmpty() {
assertEquals(isVisuallyEmpty, StringUtil.isVisuallyEmpty(input));
}
@Test
public void trim() {
assertEquals(expectedTrimmed, StringUtil.trimToVisualBounds(input));
}
}

View file

@ -0,0 +1,90 @@
package org.signal.core.util.concurrent
import org.junit.Assert.assertFalse
import org.junit.Assert.assertTrue
import org.junit.Test
import java.util.concurrent.Executor
class LatestPrioritizedSerialExecutorTest {
@Test
fun execute_sortsInPriorityOrder() {
val executor = TestExecutor()
val placeholder = TestRunnable()
val first = TestRunnable()
val second = TestRunnable()
val third = TestRunnable()
val subject = LatestPrioritizedSerialExecutor(executor)
subject.execute(0, placeholder) // The first thing we execute can't be sorted, so we put in this placeholder
subject.execute(1, third)
subject.execute(2, second)
subject.execute(3, first)
executor.next() // Clear the placeholder task
executor.next()
assertTrue(first.didRun)
executor.next()
assertTrue(second.didRun)
executor.next()
assertTrue(third.didRun)
}
@Test
fun execute_replacesDupes() {
val executor = TestExecutor()
val placeholder = TestRunnable()
val firstReplaced = TestRunnable()
val first = TestRunnable()
val second = TestRunnable()
val thirdReplaced = TestRunnable()
val third = TestRunnable()
val subject = LatestPrioritizedSerialExecutor(executor)
subject.execute(0, placeholder) // The first thing we execute can't be sorted, so we put in this placeholder
subject.execute(1, thirdReplaced)
subject.execute(1, third)
subject.execute(2, second)
subject.execute(3, firstReplaced)
subject.execute(3, first)
executor.next() // Clear the placeholder task
executor.next()
assertTrue(first.didRun)
executor.next()
assertTrue(second.didRun)
executor.next()
assertTrue(third.didRun)
assertFalse(firstReplaced.didRun)
assertFalse(thirdReplaced.didRun)
}
private class TestExecutor : Executor {
private val tasks = ArrayDeque<Runnable>()
override fun execute(command: Runnable) {
tasks.add(command)
}
fun next() {
tasks.removeLast().run()
}
}
class TestRunnable : Runnable {
private var _didRun = false
val didRun get() = _didRun
override fun run() {
_didRun = true
}
}
}

View file

@ -0,0 +1,29 @@
package org.signal.core.util.concurrent
import io.reactivex.rxjava3.disposables.CompositeDisposable
import io.reactivex.rxjava3.subjects.BehaviorSubject
import io.reactivex.rxjava3.subjects.PublishSubject
import org.junit.Assert.assertEquals
import org.junit.Test
class RxExtensionsTest {
@Test
fun `Given a subject, when I subscribeWithBehaviorSubject, then I expect proper disposals`() {
val subject = PublishSubject.create<Int>()
val disposables = CompositeDisposable()
val sub2 = subject.subscribeWithSubject(
BehaviorSubject.create(),
disposables
)
val obs = sub2.test()
subject.onNext(1)
obs.dispose()
subject.onNext(2)
disposables.dispose()
subject.onNext(3)
obs.assertValues(1)
assertEquals(sub2.value, 2)
}
}

View file

@ -0,0 +1,104 @@
package org.signal.core.util.money
import org.junit.Assert.assertEquals
import org.junit.Test
import java.math.BigDecimal
import java.util.Currency
class FiatMoneyTest {
@Test
fun given100USD_whenIGetDefaultPrecisionString_thenIExpect100dot00() {
// GIVEN
val fiatMoney = FiatMoney(BigDecimal.valueOf(100), Currency.getInstance("USD"))
// WHEN
val result = fiatMoney.defaultPrecisionString
// THEN
assertEquals("100.00", result)
}
@Test
fun given100USD_whenIGetMinimumUnitPrecisionString_thenIExpect10000() {
// GIVEN
val fiatMoney = FiatMoney(BigDecimal.valueOf(100), Currency.getInstance("USD"))
// WHEN
val result = fiatMoney.minimumUnitPrecisionString
// THEN
assertEquals("10000", result)
}
@Test
fun given100JPY_whenIGetDefaultPrecisionString_thenIExpect100() {
// GIVEN
val fiatMoney = FiatMoney(BigDecimal.valueOf(100), Currency.getInstance("JPY"))
// WHEN
val result = fiatMoney.defaultPrecisionString
// THEN
assertEquals("100", result)
}
@Test
fun given100JPY_whenIGetMinimumUnitPrecisionString_thenIExpect100() {
// GIVEN
val fiatMoney = FiatMoney(BigDecimal.valueOf(100), Currency.getInstance("JPY"))
// WHEN
val result = fiatMoney.minimumUnitPrecisionString
// THEN
assertEquals("100", result)
}
@Test
fun given100UGX_whenIGetDefaultPrecisionString_thenIExpect100() {
// GIVEN
val fiatMoney = FiatMoney(BigDecimal.valueOf(100), Currency.getInstance("UGX"))
// WHEN
val result = fiatMoney.defaultPrecisionString
// THEN
assertEquals("100", result)
}
@Test
fun given100UGX_whenIGetMinimumUnitPrecisionString_thenIExpect10000() {
// GIVEN
val fiatMoney = FiatMoney(BigDecimal.valueOf(100), Currency.getInstance("UGX"))
// WHEN
val result = fiatMoney.minimumUnitPrecisionString
// THEN
assertEquals("10000", result)
}
@Test
fun given100ISK_whenIGetDefaultPrecisionString_thenIExpect100() {
// GIVEN
val fiatMoney = FiatMoney(BigDecimal.valueOf(100), Currency.getInstance("ISK"))
// WHEN
val result = fiatMoney.defaultPrecisionString
// THEN
assertEquals("100", result)
}
@Test
fun given100ISK_whenIGetMinimumUnitPrecisionString_thenIExpect10000() {
// GIVEN
val fiatMoney = FiatMoney(BigDecimal.valueOf(100), Currency.getInstance("ISK"))
// WHEN
val result = fiatMoney.minimumUnitPrecisionString
// THEN
assertEquals("10000", result)
}
}