Source added

This commit is contained in:
Fr4nz D13trich 2025-11-20 09:26:33 +01:00
parent b2864b500e
commit ba28ca859e
8352 changed files with 1487182 additions and 1 deletions

1
core-util-jvm/.gitignore vendored Normal file
View file

@ -0,0 +1 @@
/build

View file

@ -0,0 +1,63 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
val signalJavaVersion: JavaVersion by rootProject.extra
val signalKotlinJvmTarget: String by rootProject.extra
plugins {
id("java-library")
id("org.jetbrains.kotlin.jvm")
id("ktlint")
id("com.squareup.wire")
}
java {
sourceCompatibility = signalJavaVersion
targetCompatibility = signalJavaVersion
}
kotlin {
jvmToolchain {
languageVersion = JavaLanguageVersion.of(signalKotlinJvmTarget)
}
}
afterEvaluate {
listOf(
"runKtlintCheckOverMainSourceSet",
"runKtlintFormatOverMainSourceSet"
).forEach { taskName ->
tasks.named(taskName) {
mustRunAfter(tasks.named("generateMainProtos"))
}
}
}
wire {
kotlin {
javaInterop = true
}
sourcePath {
srcDir("src/main/protowire")
}
}
tasks.runKtlintCheckOverMainSourceSet {
dependsOn(":core-util-jvm:generateMainProtos")
}
dependencies {
implementation(libs.kotlin.reflect)
implementation(libs.kotlinx.coroutines.core)
implementation(libs.kotlinx.coroutines.core.jvm)
implementation(libs.google.libphonenumber)
implementation(libs.rxjava3.rxjava)
implementation(libs.rxjava3.rxkotlin)
testImplementation(testLibs.junit.junit)
testImplementation(testLibs.assertk)
testImplementation(testLibs.kotlinx.coroutines.test)
}

View file

@ -0,0 +1,153 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util
import java.io.IOException
import java.io.UnsupportedEncodingException
object Base64 {
/**
* Encodes the bytes as a normal Base64 string with padding. Not URL safe. For url-safe, use [encodeUrlSafe].
*
* Note: the [offset] and [length] are there to support a legacy usecase, which is why they're not present on
* the other encode* methods.
*/
@JvmOverloads
@JvmStatic
fun encodeWithPadding(bytes: ByteArray, offset: Int = 0, length: Int = bytes.size): String {
return Base64Tools.encodeBytes(bytes, offset, length)
}
/**
* Encodes the bytes as a normal Base64 string without padding. Not URL safe. For url-safe, use [encodeUrlSafe].
*/
@JvmStatic
fun encodeWithoutPadding(bytes: ByteArray): String {
return Base64Tools.encodeBytes(bytes).stripPadding()
}
/**
* Encodes the bytes as a url-safe Base64 string with padding. It basically replaces the '+' and '/' characters in the
* normal encoding scheme with '-' and '_'.
*/
@JvmStatic
fun encodeUrlSafeWithPadding(bytes: ByteArray): String {
return Base64Tools.encodeBytes(bytes, Base64Tools.URL_SAFE or Base64Tools.DONT_GUNZIP)
}
/**
* Encodes the bytes as a url-safe Base64 string without padding. It basically replaces the '+' and '/' characters in the
* normal encoding scheme with '-' and '_'.
*/
@JvmStatic
fun encodeUrlSafeWithoutPadding(bytes: ByteArray): String {
return Base64Tools.encodeBytes(bytes, Base64Tools.URL_SAFE or Base64Tools.DONT_GUNZIP).stripPadding()
}
/**
* A very lenient decoder. Does not care about the presence of padding or whether it's url-safe or not. It'll just decode it.
*/
@Throws(IOException::class)
@JvmStatic
fun decode(value: String): ByteArray {
return if (value.contains('-') || value.contains('_')) {
Base64Tools.decode(value.withPaddingIfNeeded(), Base64Tools.URL_SAFE or Base64Tools.DONT_GUNZIP)
} else {
Base64Tools.decode(value.withPaddingIfNeeded())
}
}
@JvmStatic
fun decode(value: ByteArray): ByteArray {
// This pattern of trying US_ASCII first mimics how Base64Tools handles strings
return try {
decode(String(value, Charsets.US_ASCII))
} catch (e: UnsupportedEncodingException) {
decode(String(value, Charsets.UTF_8))
}
}
/**
* The same as [decode], except that instead of requiring you to handle an exception, this will return null
* if the input is null or cannot be decoded.
*/
@JvmStatic
fun decodeOrNull(value: String?): ByteArray? {
if (value == null) {
return null
}
return try {
decode(value)
} catch (e: IOException) {
null
}
}
/**
* The same as [decode], except that instead of requiring you to handle an exception, this will just crash on invalid base64 strings.
* Should only be used if the value is definitely a valid base64 string.
*/
@JvmStatic
fun decodeOrThrow(value: String): ByteArray {
return try {
decode(value)
} catch (e: IOException) {
throw AssertionError(e)
}
}
/**
* The same as [decode], except that instead of requiring you to handle an exception, this will just crash on invalid base64 strings.
* It also allows null inputs. If the input is null, the outpul will be null.
* Should only be used if the value is definitely a valid base64 string.
*/
@JvmStatic
fun decodeNullableOrThrow(value: String?): ByteArray? {
if (value == null) {
return null
}
return try {
decode(value)
} catch (e: IOException) {
throw AssertionError(e)
}
}
private fun String.withPaddingIfNeeded(): String {
return when (this.length % 4) {
2 -> "$this=="
3 -> "$this="
else -> this
}
}
private fun String.stripPadding(): String {
return this.replace("=", "")
}
fun String.decodeBase64OrThrow(): ByteArray {
return try {
decode(this)
} catch (e: IOException) {
throw AssertionError("Invalid Base64 string: $this", e)
}
}
fun String?.decodeBase64(): ByteArray? {
if (this == null) {
return null
}
return try {
decode(this)
} catch (e: IOException) {
return null
}
}
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,159 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util
import java.util.regex.Pattern
object BidiUtil {
private val ALL_ASCII_PATTERN: Pattern = Pattern.compile("^[\\x00-\\x7F]*$")
private object Bidi {
/** Override text direction */
val OVERRIDES: Set<Int> = SetUtil.newHashSet(
"\u202a".codePointAt(0), // LRE
"\u202b".codePointAt(0), // RLE
"\u202d".codePointAt(0), // LRO
"\u202e".codePointAt(0) // RLO
)
/** Set direction and isolate surrounding text */
val ISOLATES: Set<Int> = SetUtil.newHashSet(
"\u2066".codePointAt(0), // LRI
"\u2067".codePointAt(0), // RLI
"\u2068".codePointAt(0) // FSI
)
/** Closes things in [.OVERRIDES] */
val PDF: Int = "\u202c".codePointAt(0)
/** Closes things in [.ISOLATES] */
val PDI: Int = "\u2069".codePointAt(0)
/** Auto-detecting isolate */
val FSI: Int = "\u2068".codePointAt(0)
}
/**
* @return True if the provided text contains a mix of LTR and RTL characters, otherwise false.
*/
@JvmStatic
fun hasMixedTextDirection(text: CharSequence?): Boolean {
if (text == null) {
return false
}
var isLtr: Boolean? = null
var i = 0
val len = Character.codePointCount(text, 0, text.length)
while (i < len) {
val codePoint = Character.codePointAt(text, i)
val direction = Character.getDirectionality(codePoint)
val isLetter = Character.isLetter(codePoint)
if (isLtr != null && isLtr && direction != Character.DIRECTIONALITY_LEFT_TO_RIGHT && isLetter) {
return true
} else if (isLtr != null && !isLtr && direction != Character.DIRECTIONALITY_RIGHT_TO_LEFT && isLetter) {
return true
} else if (isLetter) {
isLtr = direction == Character.DIRECTIONALITY_LEFT_TO_RIGHT
}
i++
}
return false
}
/**
* Isolates bi-directional text from influencing surrounding text. You should use this whenever
* you're injecting user-generated text into a larger string.
*
* You'd think we'd be able to trust BidiFormatter, but unfortunately it just misses some
* corner cases, so here we are.
*
* The general idea is just to balance out the opening and closing codepoints, and then wrap the
* whole thing in FSI/PDI to isolate it.
*
* For more details, see:
* https://www.w3.org/International/questions/qa-bidi-unicode-controls
*/
@JvmStatic
fun isolateBidi(text: String?): String {
if (text == null) {
return ""
}
if (text.isEmpty()) {
return text
}
if (ALL_ASCII_PATTERN.matcher(text).matches()) {
return text
}
var overrideCount = 0
var overrideCloseCount = 0
var isolateCount = 0
var isolateCloseCount = 0
var i = 0
val len = text.codePointCount(0, text.length)
while (i < len) {
val codePoint = text.codePointAt(i)
if (Bidi.OVERRIDES.contains(codePoint)) {
overrideCount++
} else if (codePoint == Bidi.PDF) {
overrideCloseCount++
} else if (Bidi.ISOLATES.contains(codePoint)) {
isolateCount++
} else if (codePoint == Bidi.PDI) {
isolateCloseCount++
}
i++
}
val suffix = StringBuilder()
while (overrideCount > overrideCloseCount) {
suffix.appendCodePoint(Bidi.PDF)
overrideCloseCount++
}
while (isolateCount > isolateCloseCount) {
suffix.appendCodePoint(Bidi.FSI)
isolateCloseCount++
}
val out = StringBuilder()
return out.appendCodePoint(Bidi.FSI)
.append(text)
.append(suffix)
.appendCodePoint(Bidi.PDI)
.toString()
}
@JvmStatic
fun stripBidiProtection(text: String?): String? {
if (text == null) return null
return text.replace("[\\u2068\\u2069\\u202c]".toRegex(), "")
}
fun stripBidiIndicator(text: String): String {
return text.replace("\u200F", "")
}
@JvmStatic
fun forceLtr(text: CharSequence): String {
return "\u202a" + text + "\u202c"
}
fun stripAllDirectionalCharacters(text: String): String {
return text.replace("[\\u200f\\u2066\\u2067\\u2068\\u2069\\u202a\\u202b\\u202c\\u202d\\u202e]".toRegex(), "")
}
}

View file

@ -0,0 +1,83 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util;
import java.util.Locale;
/**
* A set of utilities to make working with Bitmasks easier.
*/
public final class Bitmask {
/**
* Reads a bitmasked boolean from a long at the requested position.
*/
public static boolean read(long value, int position) {
return read(value, position, 1) > 0;
}
/**
* Reads a bitmasked value from a long at the requested position.
*
* @param value The value your are reading state from
* @param position The position you'd like to read from
* @param flagBitSize How many bits are in each flag
* @return The value at the requested position
*/
public static long read(long value, int position, int flagBitSize) {
checkArgument(flagBitSize >= 0, "Must have a positive bit size! size: " + flagBitSize);
int bitsToShift = position * flagBitSize;
checkArgument(bitsToShift + flagBitSize <= 64 && position >= 0, String.format(Locale.US, "Your position is out of bounds! position: %d, flagBitSize: %d", position, flagBitSize));
long shifted = value >>> bitsToShift;
long mask = twoToThe(flagBitSize) - 1;
return shifted & mask;
}
/**
* Sets the value at the specified position in a single-bit bitmasked long.
*/
public static long update(long existing, int position, boolean value) {
return update(existing, position, 1, value ? 1 : 0);
}
/**
* Updates the value in a bitmasked long.
*
* @param existing The existing state of the bitmask
* @param position The position you'd like to update
* @param flagBitSize How many bits are in each flag
* @param value The value you'd like to set at the specified position
* @return The updated bitmask
*/
public static long update(long existing, int position, int flagBitSize, long value) {
checkArgument(flagBitSize >= 0, "Must have a positive bit size! size: " + flagBitSize);
checkArgument(value >= 0, "Value must be positive! value: " + value);
checkArgument(value < twoToThe(flagBitSize), String.format(Locale.US, "Value is larger than you can hold for the given bitsize! value: %d, flagBitSize: %d", value, flagBitSize));
int bitsToShift = position * flagBitSize;
checkArgument(bitsToShift + flagBitSize <= 64 && position >= 0, String.format(Locale.US, "Your position is out of bounds! position: %d, flagBitSize: %d", position, flagBitSize));
long clearMask = ~((twoToThe(flagBitSize) - 1) << bitsToShift);
long cleared = existing & clearMask;
long shiftedValue = value << bitsToShift;
return cleared | shiftedValue;
}
/** Simple method to do 2^n. Giving it a name just so it's clear what's happening. */
private static long twoToThe(long n) {
return 1 << n;
}
private static void checkArgument(boolean state, String message) {
if (!state) {
throw new IllegalArgumentException(message);
}
}
}

View file

@ -0,0 +1,143 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util
import java.text.NumberFormat
import kotlin.math.min
inline val Long.bytes: ByteSize
get() = ByteSize(this)
inline val Int.bytes: ByteSize
get() = ByteSize(this.toLong())
inline val Long.kibiBytes: ByteSize
get() = (this * 1024).bytes
inline val Int.kibiBytes: ByteSize
get() = (this.toLong() * 1024L).bytes
inline val Long.mebiBytes: ByteSize
get() = (this * 1024L).kibiBytes
inline val Int.mebiBytes: ByteSize
get() = (this.toLong() * 1024L).kibiBytes
inline val Long.gibiBytes: ByteSize
get() = (this * 1024L).mebiBytes
inline val Int.gibiBytes: ByteSize
get() = (this.toLong() * 1024L).mebiBytes
inline val Long.tebiBytes: ByteSize
get() = (this * 1024L).gibiBytes
inline val Int.tebiBytes: ByteSize
get() = (this.toLong() * 1024L).gibiBytes
class ByteSize(val bytes: Long) {
val inWholeBytes: Long
get() = bytes
val inWholeKibiBytes: Long
get() = bytes / 1024L
val inWholeMebiBytes: Long
get() = inWholeKibiBytes / 1024L
val inWholeGibiBytes: Long
get() = inWholeMebiBytes / 1024L
val inWholeTebiBytes: Long
get() = inWholeGibiBytes / 1024L
val inKibiBytes: Float
get() = bytes / 1024f
val inMebiBytes: Float
get() = inKibiBytes / 1024f
val inGibiBytes: Float
get() = inMebiBytes / 1024f
val inTebiBytes: Float
get() = inGibiBytes / 1024f
fun getLargestNonZeroValue(): Pair<Float, Size> {
return when {
inWholeTebiBytes > 0L -> inTebiBytes to Size.TEBIBYTE
inWholeGibiBytes > 0L -> inGibiBytes to Size.GIBIBYTE
inWholeMebiBytes > 0L -> inMebiBytes to Size.MEBIBYTE
inWholeKibiBytes > 0L -> inKibiBytes to Size.KIBIBYTE
else -> inWholeBytes.toFloat() to Size.BYTE
}
}
@JvmOverloads
fun toUnitString(maxPlaces: Int = 2, spaced: Boolean = true): String {
val (size, unit) = getLargestNonZeroValue()
val formatter = NumberFormat.getInstance().apply {
minimumFractionDigits = 0
maximumFractionDigits = when (unit) {
Size.BYTE,
Size.KIBIBYTE -> 0
Size.MEBIBYTE -> min(1, maxPlaces)
Size.GIBIBYTE,
Size.TEBIBYTE -> min(2, maxPlaces)
}
}
return BidiUtil.forceLtr("${formatter.format(size)}${if (spaced) " " else ""}${unit.label}")
}
operator fun compareTo(other: ByteSize): Int {
return bytes.compareTo(other.bytes)
}
operator fun plus(other: ByteSize): ByteSize {
return ByteSize(this.inWholeBytes + other.inWholeBytes)
}
fun percentageOf(other: ByteSize): Float {
return this.inWholeBytes.toFloat() / other.inWholeBytes.toFloat()
}
operator fun minus(other: ByteSize): ByteSize {
return ByteSize(this.inWholeBytes - other.inWholeBytes)
}
operator fun times(other: Long): ByteSize {
return ByteSize(this.inWholeBytes * other)
}
override fun toString(): String {
return "ByteSize(${toUnitString(maxPlaces = 4, spaced = false)})"
}
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (javaClass != other?.javaClass) return false
other as ByteSize
return bytes == other.bytes
}
override fun hashCode(): Int {
return bytes.hashCode()
}
enum class Size(val label: String) {
BYTE("B"),
KIBIBYTE("KB"),
MEBIBYTE("MB"),
GIBIBYTE("GB"),
TEBIBYTE("TB")
}
}

View file

@ -0,0 +1,40 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util;
import java.security.InvalidKeyException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;
public final class CryptoUtil {
private static final String HMAC_SHA256 = "HmacSHA256";
private CryptoUtil() {
}
public static byte[] hmacSha256(byte[] key, byte[] data) {
try {
Mac mac = Mac.getInstance(HMAC_SHA256);
mac.init(new SecretKeySpec(key, HMAC_SHA256));
return mac.doFinal(data);
} catch (NoSuchAlgorithmException | InvalidKeyException e) {
throw new AssertionError(e);
}
}
public static byte[] sha256(byte[] data) {
try {
MessageDigest digest = MessageDigest.getInstance("SHA-256");
return digest.digest(data);
} catch (NoSuchAlgorithmException e) {
throw new AssertionError(e);
}
}
}

View file

@ -0,0 +1,18 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util
import java.util.Locale
/**
* Rounds a number to the specified number of places. e.g.
*
* 1.123456f.roundedString(2) = 1.12
* 1.123456f.roundedString(5) = 1.12346
*/
fun Double.roundedString(places: Int): String {
return String.format(Locale.US, "%.${places}f", this)
}

View file

@ -0,0 +1,14 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util
import kotlin.time.Duration
import kotlin.time.DurationUnit
fun Duration.inRoundedMilliseconds(places: Int = 2) = this.toDouble(DurationUnit.MILLISECONDS).roundedString(places)
fun Duration.inRoundedMinutes(places: Int = 2) = this.toDouble(DurationUnit.MINUTES).roundedString(places)
fun Duration.inRoundedHours(places: Int = 2) = this.toDouble(DurationUnit.HOURS).roundedString(places)
fun Duration.inRoundedDays(places: Int = 2) = this.toDouble(DurationUnit.DAYS).roundedString(places)

View file

@ -0,0 +1,331 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util
import com.google.i18n.phonenumbers.NumberParseException
import com.google.i18n.phonenumbers.PhoneNumberUtil
import com.google.i18n.phonenumbers.Phonenumber.PhoneNumber
import com.google.i18n.phonenumbers.ShortNumberInfo
import org.signal.core.util.logging.Log
import java.util.Locale
import java.util.Optional
import java.util.regex.Matcher
import java.util.regex.Pattern
/**
* Contains a bunch of utility functions to parse and format phone numbers.
*/
object E164Util {
private val TAG = Log.tag(E164Util::class)
private const val COUNTRY_CODE_BR = "55"
private const val COUNTRY_CODE_US = "1"
private val US_NO_AREACODE: Pattern = Pattern.compile("^(\\d{7})$")
private val BR_NO_AREACODE: Pattern = Pattern.compile("^(9?\\d{8})$")
private const val COUNTRY_CODE_US_INT = 1
private const val COUNTRY_CODE_UK_INT = 44
/** A set of country codes representing countries where we'd like to use the (555) 555-5555 number format for pretty printing. */
private val NATIONAL_FORMAT_COUNTRY_CODES = setOf(COUNTRY_CODE_US_INT, COUNTRY_CODE_UK_INT)
private val INVALID_CHARACTERS_REGEX = "[a-zA-Z]".toRegex()
/**
* Creates a formatter based on the provided local number. This is largely an improvement in performance/convenience
* over parsing out the various number attributes themselves and caching them manually.
*
* It is assumed that this number is properly formatted. If it is not, this may throw a [NumberParseException].
*
* @throws NumberParseException
*/
fun createFormatterForE164(localNumber: String): Formatter {
val phoneNumber = PhoneNumberUtil.getInstance().parse(localNumber, null)
val regionCode = PhoneNumberUtil.getInstance().getRegionCodeForNumber(phoneNumber) ?: PhoneNumberUtil.getInstance().getRegionCodeForCountryCode(phoneNumber.countryCode)
val areaCode = parseAreaCode(localNumber, phoneNumber.countryCode)
return Formatter(localNumber = phoneNumber, localAreaCode = areaCode, localRegionCode = regionCode)
}
/**
* Creates a formatter based on the provided region code. This is largely an improvement in performance/convenience
* over parsing out the various number attributes themselves and caching them manually.
*/
fun createFormatterForRegionCode(regionCode: String): Formatter {
return Formatter(localNumber = null, localAreaCode = null, localRegionCode = regionCode)
}
/**
* The same as [formatAsE164WithCountryCode], but if we determine the number to be invalid,
* we will do some cleanup to *roughly* format it as E164.
*
* IMPORTANT: Do not use this for actual number storage! There is no guarantee that this
* will be a properly-formatted E164 number. It should only be used in situations where a
* value is needed for user display.
*/
@JvmStatic
fun formatAsE164WithCountryCodeForDisplay(countryCode: String, input: String?): String {
val input = input ?: ""
val result: String? = formatAsE164WithCountryCode(countryCode, input)
if (result != null) {
return result
}
val cleanCountryCode = countryCode
.numbersOnly()
.replace("^0*".toRegex(), "")
val cleanNumber = input.numbersOnly()
return "+$cleanCountryCode$cleanNumber"
}
/**
* Returns whether or not an input number is valid for registration. Besides checking to ensure that libphonenumber thinks it's a possible number at all,
* we also have a few country-specific checks, as well as some of our own length and formatting checks.
*/
@JvmStatic
fun isValidNumberForRegistration(countryCode: String, input: String): Boolean {
if (!PhoneNumberUtil.getInstance().isPossibleNumber(input, countryCode)) {
Log.w(TAG, "Failed isPossibleNumber()")
return false
}
if (COUNTRY_CODE_US == countryCode && !Pattern.matches("^\\+1[0-9]{10}$", input)) {
Log.w(TAG, "Failed US number format check")
return false
}
if (COUNTRY_CODE_BR == countryCode && !Pattern.matches("^\\+55[0-9]{2}9?[0-9]{8}$", input)) {
Log.w(TAG, "Failed Brazil number format check")
return false
}
return input.matches("^\\+[1-9][0-9]{6,14}$".toRegex())
}
/**
* Given a regionCode, this will attempt to provide the display name for that region.
*/
@JvmStatic
fun getRegionDisplayName(regionCode: String?): Optional<String> {
if (regionCode == null || regionCode == "ZZ" || regionCode == PhoneNumberUtil.REGION_CODE_FOR_NON_GEO_ENTITY) {
return Optional.empty()
}
val displayCountry: String? = Locale("", regionCode).getDisplayCountry(Locale.getDefault()).nullIfBlank()
return Optional.ofNullable(displayCountry)
}
/**
* Identical to [formatAsE164WithRegionCode], except rather than supply the region code, you supply the
* country code (i.e. "1" for the US). This will convert the country code to a region code on your behalf.
* See [formatAsE164WithRegionCode] for behavior.
*/
private fun formatAsE164WithCountryCode(countryCode: String, input: String): String? {
val regionCode = try {
val countryCodeInt = countryCode.toInt()
PhoneNumberUtil.getInstance().getRegionCodeForCountryCode(countryCodeInt)
} catch (e: NumberFormatException) {
return null
}
return formatAsE164WithRegionCode(
localNumber = null,
localAreaCode = null,
regionCode = regionCode,
input = input
)
}
/**
* Formats the number as an E164, or null if the number cannot be reasonably interpreted as a phone number.
* This does not check if the number is *valid* for a given region. Instead, it's very lenient and just
* does it's best to interpret the input string as a number that could be put into the E164 format.
*
* Note that shortcodes will not have leading '+' signs.
*
* In other words, if this method returns null, you likely do not have anything that could be considered
* a phone number.
*/
private fun formatAsE164WithRegionCode(localNumber: PhoneNumber?, localAreaCode: String?, regionCode: String, input: String): String? {
try {
val correctedInput = input.e164CharsOnly().stripLeadingZerosFromInput()
if (correctedInput.trimStart('0').length < 3) {
return null
}
val withAreaCodeRules: String = applyAreaCodeRules(localNumber, localAreaCode, correctedInput)
val parsedNumber: PhoneNumber = PhoneNumberUtil.getInstance().parse(withAreaCodeRules, regionCode)
val isShortCode = ShortNumberInfo.getInstance().isValidShortNumberForRegion(parsedNumber, regionCode) || withAreaCodeRules.trimStart('+').length <= 6
if (isShortCode) {
return correctedInput.numbersOnly().stripLeadingZerosFromE164()
}
return PhoneNumberUtil.getInstance().format(parsedNumber, PhoneNumberUtil.PhoneNumberFormat.E164).stripLeadingZerosFromE164()
} catch (e: NumberParseException) {
return null
}
}
/**
* Strictly checks if a given number is a valid short code for a given region. Short code length varies by region and some
* require specific prefixes.
*
* This will check the input with and without a leading '+' sign.
*
* If the number cannot be parsed or is otherwise invalid, false is returned.
*/
private fun isValidShortNumber(regionCode: String, input: String): Boolean {
try {
val correctedInput = input.e164CharsOnly().stripLeadingZerosFromInput()
val correctedWithoutLeading = correctedInput.trimStart('+', '0')
var parsedNumber: PhoneNumber = PhoneNumberUtil.getInstance().parse(correctedInput, regionCode)
val isShortCode = ShortNumberInfo.getInstance().isValidShortNumberForRegion(parsedNumber, regionCode) || correctedWithoutLeading.length <= 6
if (isShortCode) {
return true
}
if (correctedInput != correctedWithoutLeading) {
parsedNumber = PhoneNumberUtil.getInstance().parse(correctedInput.trimStart('+', '0'), regionCode)
return ShortNumberInfo.getInstance().isValidShortNumberForRegion(parsedNumber, regionCode) || correctedInput.length <= 6
}
return false
} catch (_: NumberParseException) {
return false
}
}
/**
* Attempts to parse the area code out of an e164-formatted number provided that it's in one of the supported countries.
*/
private fun parseAreaCode(e164Number: String, countryCode: Int): String? {
when (countryCode) {
1 -> return e164Number.substring(2, 5)
55 -> return e164Number.substring(3, 5)
}
return null
}
/**
* Given an input number, this will attempt to add in an area code for certain locales if we have one in the local number.
* For example, in the US, if your local number is (610) 555-5555, and we're given a `testNumber` of 123-4567, we could
* assume that the full number would be (610) 123-4567.
*/
private fun applyAreaCodeRules(localNumber: PhoneNumber?, localAreaCode: String?, testNumber: String): String {
if (localNumber === null || localAreaCode == null) {
return testNumber
}
val matcher: Matcher? = when (localNumber.countryCode) {
1 -> US_NO_AREACODE.matcher(testNumber)
55 -> BR_NO_AREACODE.matcher(testNumber)
else -> null
}
if (matcher != null && matcher.matches()) {
return localAreaCode + matcher.group()
}
return testNumber
}
private fun String.numbersOnly(): String {
return this.filter { it.isDigit() }
}
private fun String.e164CharsOnly(): String {
return this.filter { it.isDigit() || it == '+' }
}
/**
* Strips out bad leading zeros from input strings that can confuse libphonenumber.
*/
private fun String.stripLeadingZerosFromInput(): String {
return if (this.startsWith("+0")) {
"+" + this.substring(1).trimStart('0')
} else {
this
}
}
/**
* Strips out leading zeros from a string after it's been e164-formatted by libphonenumber.
*/
private fun String.stripLeadingZerosFromE164(): String {
return if (this.startsWith("0")) {
this.trimStart('0')
} else if (this.startsWith("+0")) {
"+" + this.substring(1).trimStart('0')
} else {
this
}
}
class Formatter(
val localNumber: PhoneNumber?,
val localAreaCode: String?,
val localRegionCode: String
) {
/**
* Formats the number as an E164, or null if the number cannot be reasonably interpreted as a phone number.
* This does not check if the number is *valid* for a given region. Instead, it's very lenient and just
* does it's best to interpret the input string as a number that could be put into the E164 format.
*
* Note that shortcodes will not have leading '+' signs.
*
* In other words, if this method returns null, you likely do not have anything that could be considered
* a phone number.
*/
fun formatAsE164(input: String): String? {
if (INVALID_CHARACTERS_REGEX.containsMatchIn(input)) {
return null
}
val formatted = formatAsE164WithRegionCode(
localNumber = localNumber,
localAreaCode = localAreaCode,
regionCode = localRegionCode,
input = input
)
return if (formatted == null && input.startsWith("+")) {
formatAsE164(input.substring(1))
} else {
formatted
}
}
fun isValidShortNumber(input: String): Boolean {
return isValidShortNumber(localRegionCode, input)
}
/**
* Formats the number for human-readable display. e.g. "(555) 555-5555"
*/
fun prettyPrint(input: String): String {
val raw = try {
val parsedNumber: PhoneNumber = PhoneNumberUtil.getInstance().parse(input, localRegionCode)
return if (localNumber != null && localNumber.countryCode == parsedNumber.countryCode && NATIONAL_FORMAT_COUNTRY_CODES.contains(localNumber.countryCode)) {
PhoneNumberUtil.getInstance().format(parsedNumber, PhoneNumberUtil.PhoneNumberFormat.NATIONAL)
} else {
PhoneNumberUtil.getInstance().format(parsedNumber, PhoneNumberUtil.PhoneNumberFormat.INTERNATIONAL)
}
} catch (e: NumberParseException) {
Log.w(TAG, "Failed to format number: $e")
input
}
return BidiUtil.forceLtr(BidiUtil.isolateBidi(raw))
}
}
}

View file

@ -0,0 +1,16 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util
/**
* Rounds a number to the specified number of places. e.g.
*
* 1.123456f.roundedString(2) = 1.12
* 1.123456f.roundedString(5) = 1.12346
*/
fun Float.roundedString(places: Int): String {
return String.format("%.${places}f", this)
}

View file

@ -0,0 +1,37 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.channelFlow
import kotlinx.coroutines.flow.conflate
import kotlinx.coroutines.flow.filterNot
import kotlinx.coroutines.flow.onEach
import kotlin.time.Duration
/**
* Throttles the flow so that at most one value is emitted every [timeout]. The latest value is always emitted.
*
* You can think of this like debouncing, but with "checkpoints" so that even if you have a constant stream of values,
* you'll still get an emission every [timeout] (unlike debouncing, which will only emit once the stream settles down).
*
* You can specify an optional [emitImmediately] function that will indicate whether an emission should skip throttling and
* be emitted immediately. This lambda should be stateless, as it may be called multiple times for each item.
*/
fun <T> Flow<T>.throttleLatest(timeout: Duration, emitImmediately: (T) -> Boolean = { false }): Flow<T> {
val rootFlow = this
return channelFlow {
rootFlow
.onEach { if (emitImmediately(it)) send(it) }
.filterNot { emitImmediately(it) }
.conflate()
.collect {
send(it)
delay(timeout)
}
}
}

View file

@ -0,0 +1,134 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util;
import java.io.IOException;
/**
* Utility for generating hex dumps.
*/
public class Hex {
private final static int HEX_DIGITS_START = 10;
private final static int ASCII_TEXT_START = HEX_DIGITS_START + (16*2 + (16/2));
final static String EOL = System.lineSeparator();
private final static char[] HEX_DIGITS = {
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f'
};
public static String toString(byte[] bytes) {
return toString(bytes, 0, bytes.length);
}
public static String toString(byte[] bytes, int offset, int length) {
StringBuffer buf = new StringBuffer();
for (int i = 0; i < length; i++) {
appendHexChar(buf, bytes[offset + i]);
buf.append(' ');
}
return buf.toString();
}
public static String toStringCondensed(byte[] bytes) {
StringBuffer buf = new StringBuffer();
for (int i=0;i<bytes.length;i++) {
appendHexChar(buf, bytes[i]);
}
return buf.toString();
}
public static byte[] fromStringCondensed(String encoded) throws IOException {
final char[] data = encoded.toCharArray();
final int len = data.length;
if ((len & 0x01) != 0) {
throw new IOException("Odd number of characters.");
}
final byte[] out = new byte[len >> 1];
// two characters form the hex value.
for (int i = 0, j = 0; j < len; i++) {
int f = Character.digit(data[j], 16) << 4;
j++;
f = f | Character.digit(data[j], 16);
j++;
out[i] = (byte) (f & 0xFF);
}
return out;
}
public static byte[] fromStringOrThrow(String encoded) {
try {
return fromStringCondensed(encoded);
} catch (IOException e) {
throw new AssertionError(e);
}
}
public static String dump(byte[] bytes) {
return dump(bytes, 0, bytes.length);
}
public static String dump(byte[] bytes, int offset, int length) {
StringBuffer buf = new StringBuffer();
int lines = ((length - 1) / 16) + 1;
int lineOffset;
int lineLength;
for (int i = 0; i < lines; i++) {
lineOffset = (i * 16) + offset;
lineLength = Math.min(16, (length - (i * 16)));
appendDumpLine(buf, i, bytes, lineOffset, lineLength);
buf.append(EOL);
}
return buf.toString();
}
private static void appendDumpLine(StringBuffer buf, int line, byte[] bytes, int lineOffset, int lineLength) {
buf.append(HEX_DIGITS[(line >> 28) & 0xf]);
buf.append(HEX_DIGITS[(line >> 24) & 0xf]);
buf.append(HEX_DIGITS[(line >> 20) & 0xf]);
buf.append(HEX_DIGITS[(line >> 16) & 0xf]);
buf.append(HEX_DIGITS[(line >> 12) & 0xf]);
buf.append(HEX_DIGITS[(line >> 8) & 0xf]);
buf.append(HEX_DIGITS[(line >> 4) & 0xf]);
buf.append(HEX_DIGITS[(line ) & 0xf]);
buf.append(": ");
for (int i = 0; i < 16; i++) {
int idx = i + lineOffset;
if (i < lineLength) {
int b = bytes[idx];
appendHexChar(buf, b);
} else {
buf.append(" ");
}
if ((i % 2) == 1) {
buf.append(' ');
}
}
for (int i = 0; i < 16 && i < lineLength; i++) {
int idx = i + lineOffset;
int b = bytes[idx];
if (b >= 0x20 && b <= 0x7e) {
buf.append((char)b);
} else {
buf.append('.');
}
}
}
private static void appendHexChar(StringBuffer buf, int b) {
buf.append(HEX_DIGITS[(b >> 4) & 0xf]);
buf.append(HEX_DIGITS[b & 0xf]);
}
}

View file

@ -0,0 +1,150 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util
import org.signal.core.util.stream.LimitedInputStream
import java.io.ByteArrayOutputStream
import java.io.IOException
import java.io.InputStream
import java.io.OutputStream
import kotlin.math.min
/**
* Reads a 32-bit variable-length integer from the stream.
*
* The format uses one byte for each 7 bits of the integer, with the most significant bit (MSB) of each byte indicating whether more bytes need to be read.
* If the MSB is 0, it indicates the final byte. The actual integer value is constructed from the remaining 7 bits of each byte.
*/
fun InputStream.readVarInt32(): Int {
var result = 0
// We read 7 bits of the integer at a time, up to the full size of an integer (32 bits).
for (shift in 0 until 32 step 7) {
// Despite returning an int, the range of the returned value is 0..255, so it's just a byte.
// I believe it's an int just so it can return -1 when the stream ends.
val byte: Int = read()
if (byte < 0) {
return -1
}
val lowestSevenBits = byte and 0x7F
val shiftedBits = lowestSevenBits shl shift
result = result or shiftedBits
// If the MSB is 0, that means the varint is finished, and we have our full result
if (byte and 0x80 == 0) {
return result
}
}
throw IOException("Malformed varint!")
}
/**
* Reads the entire stream into a [ByteArray].
*/
@Throws(IOException::class)
fun InputStream.readFully(autoClose: Boolean = true): ByteArray {
return StreamUtil.readFully(this, Integer.MAX_VALUE, autoClose)
}
/**
* Fills reads data from the stream into the [buffer] until it is full.
* Throws an [IOException] if the stream doesn't have enough data to fill the buffer.
*/
@Throws(IOException::class)
fun InputStream.readFully(buffer: ByteArray) {
return StreamUtil.readFully(this, buffer)
}
/**
* Reads the specified number of bytes from the stream and returns it as a [ByteArray].
* Throws an [IOException] if the stream doesn't have that many bytes.
*/
@Throws(IOException::class)
fun InputStream.readNBytesOrThrow(length: Int): ByteArray {
val buffer = ByteArray(length)
this.readFully(buffer)
return buffer
}
/**
* Read at most [byteLimit] bytes from the stream.
*/
fun InputStream.readAtMostNBytes(byteLimit: Int): ByteArray {
val buffer = ByteArrayOutputStream()
val readBuffer = ByteArray(4096)
var remaining = byteLimit
while (remaining > 0) {
val bytesToRead = min(remaining, readBuffer.size)
val read = this.read(readBuffer, 0, bytesToRead)
if (read == -1) {
break
}
buffer.write(readBuffer, 0, read)
remaining -= read
}
return buffer.toByteArray()
}
@Throws(IOException::class)
fun InputStream.readLength(): Long {
val buffer = ByteArray(4096)
var count = 0L
while (this.read(buffer).also { if (it > 0) count += it } != -1) {
// do nothing, all work is in the while condition
}
return count
}
/**
* Reads the contents of the stream and discards them.
*/
@Throws(IOException::class)
fun InputStream.drain() {
this.readLength()
}
/**
* Returns a [LimitedInputStream] that will limit the number of bytes read from this stream to [limit].
*/
fun InputStream.limit(limit: Long): LimitedInputStream {
return LimitedInputStream(this, limit)
}
/**
* Copies the contents of this stream to the [outputStream].
*
* @param closeInputStream If true, the input stream will be closed after the copy is complete.
*/
fun InputStream.copyTo(outputStream: OutputStream, closeInputStream: Boolean = true, closeOutputStream: Boolean = true): Long {
return StreamUtil.copy(this, outputStream, closeInputStream, closeOutputStream)
}
/**
* Returns true if every byte in this stream matches the predicate, otherwise false.
*/
fun InputStream.allMatch(predicate: (Byte) -> Boolean): Boolean {
val buffer = ByteArray(4096)
var readCount: Int
while (this.read(buffer).also { readCount = it } != -1) {
for (i in 0 until readCount) {
if (!predicate(buffer[i])) {
return false
}
}
}
return true
}

View file

@ -0,0 +1,18 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util
import java.nio.ByteBuffer
/**
* Converts the integer into [ByteArray].
*/
fun Int.toByteArray(): ByteArray {
return ByteBuffer
.allocate(Int.SIZE_BYTES)
.putInt(this)
.array()
}

View file

@ -0,0 +1,18 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util
import java.nio.ByteBuffer
/**
* Converts the long into [ByteArray].
*/
fun Long.toByteArray(): ByteArray {
return ByteBuffer
.allocate(Long.SIZE_BYTES)
.putLong(this)
.array()
}

View file

@ -0,0 +1,28 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util
import java.util.Optional
fun <E> Optional<E>.or(other: Optional<E>): Optional<E> {
return if (this.isPresent) {
this
} else {
other
}
}
fun <E> Optional<E>.isAbsent(): Boolean {
return !isPresent
}
fun <E : Any> E?.toOptional(): Optional<E> {
return Optional.ofNullable(this)
}
fun <E> Optional<E>.orNull(): E? {
return orElse(null)
}

View file

@ -0,0 +1,32 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util
import java.io.OutputStream
/**
* Writes a 32-bit variable-length integer to the stream.
*
* The format uses one byte for each 7 bits of the integer, with the most significant bit (MSB) of each byte indicating whether more bytes need to be read.
*/
fun OutputStream.writeVarInt32(value: Int) {
var remaining = value
while (true) {
// We write 7 bits of the integer at a time
val lowestSevenBits = remaining and 0x7F
remaining = remaining ushr 7
if (remaining == 0) {
// If there are no more bits to write, we're done
write(lowestSevenBits)
return
} else {
// Otherwise, we need to write the next 7 bits, and set the MSB to 1 to indicate that there are more bits to come
write(lowestSevenBits or 0x80)
}
}
}

View file

@ -0,0 +1,111 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
@file:JvmName("ProtoUtil")
package org.signal.core.util
import com.squareup.wire.FieldEncoding
import com.squareup.wire.Message
import com.squareup.wire.ProtoAdapter
import com.squareup.wire.ProtoReader
import com.squareup.wire.ProtoWriter
import okio.Buffer
import okio.ByteString
import org.signal.core.util.logging.Log
import java.io.IOException
import java.util.LinkedList
private const val TAG = "ProtoExtension"
fun ByteString?.isEmpty(): Boolean {
return this == null || this.size == 0
}
fun ByteString?.isNotEmpty(): Boolean {
return this != null && this.size > 0
}
fun ByteString?.isNullOrEmpty(): Boolean {
return this == null || this.size == 0
}
fun ByteString.nullIfEmpty(): ByteString? {
return if (this.isEmpty()) {
null
} else {
this
}
}
/**
* Performs the common pattern of attempting to decode a serialized proto and returning null if it fails to decode.
*/
fun <E> ProtoAdapter<E>.decodeOrNull(serialized: ByteArray): E? {
return try {
this.decode(serialized)
} catch (e: IOException) {
null
}
}
/**
* True if there are unknown fields anywhere inside the proto or its nested protos.
*/
fun Message<*, *>.hasUnknownFields(): Boolean {
val allProtos = this.getInnerProtos()
allProtos.add(this)
for (proto in allProtos) {
val unknownFields = proto.unknownFields
if (unknownFields.size > 0) {
return true
}
}
return false
}
fun Message<*, *>.getUnknownEnumValue(tag: Int): Int {
val reader = ProtoReader(Buffer().write(this.unknownFields))
reader.forEachTag { unknownTag ->
if (unknownTag == tag) {
return ProtoAdapter.INT32.decode(reader)
}
}
throw AssertionError("Tag $tag not found in unknown fields")
}
fun writeUnknownEnumValue(tag: Int, enumValue: Int): ByteString {
val buffer = Buffer()
val writer = ProtoWriter(buffer)
@Suppress("UNCHECKED_CAST")
(FieldEncoding.VARINT.rawProtoAdapter() as ProtoAdapter<Any>).encodeWithTag(writer, tag, enumValue.toLong())
return buffer.readByteString()
}
/**
* Recursively retrieves all inner complex proto types inside a given proto.
*/
private fun Message<*, *>.getInnerProtos(): MutableList<Message<*, *>> {
val innerProtos: MutableList<Message<*, *>> = LinkedList()
try {
val fields = this.javaClass.declaredFields
for (field in fields) {
if (Message::class.java.isAssignableFrom(field.type)) {
field.isAccessible = true
val inner = field[this] as? Message<*, *>
if (inner != null) {
innerProtos.add(inner)
innerProtos.addAll(inner.getInnerProtos())
}
}
}
} catch (e: IllegalAccessException) {
Log.w(TAG, "Failed to get inner protos!", e)
}
return innerProtos
}

View file

@ -0,0 +1,60 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util
import kotlin.reflect.KProperty
/**
* Identical to Kotlin's built-in [lazy] delegate, but with a `reset` method that allows the value to be reset to it's default state (and therefore recomputed
* upon next access).
*/
fun <T> resettableLazy(initializer: () -> T): ResettableLazy<T> {
return ResettableLazy(initializer)
}
/**
* @see resettableLazy
*/
class ResettableLazy<T>(
val initializer: () -> T
) {
// We need to distinguish between a lazy value of null and a lazy value that has not been initialized yet
@Volatile
private var value: Any? = UNINITIALIZED
operator fun getValue(thisRef: Any?, property: KProperty<*>): T {
if (value === UNINITIALIZED) {
synchronized(this) {
if (value === UNINITIALIZED) {
value = initializer()
}
}
}
@Suppress("UNCHECKED_CAST")
return value as T
}
fun reset() {
value = UNINITIALIZED
}
fun isInitialized(): Boolean {
return value !== UNINITIALIZED
}
override fun toString(): String {
return if (isInitialized()) {
value.toString()
} else {
"Lazy value not initialized yet."
}
}
companion object {
private val UNINITIALIZED = Any()
}
}

View file

@ -0,0 +1,39 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.Set;
public final class SetUtil {
private SetUtil() {}
public static <E> Set<E> intersection(Collection<E> a, Collection<E> b) {
Set<E> intersection = new LinkedHashSet<>(a);
intersection.retainAll(b);
return intersection;
}
public static <E> Set<E> difference(Collection<E> a, Collection<E> b) {
Set<E> difference = new LinkedHashSet<>(a);
difference.removeAll(b);
return difference;
}
public static <E> Set<E> union(Set<E> a, Set<E> b) {
Set<E> result = new LinkedHashSet<>(a);
result.addAll(b);
return result;
}
@SafeVarargs
public static <E> HashSet<E> newHashSet(E... elements) {
return new HashSet<>(Arrays.asList(elements));
}
}

View file

@ -0,0 +1,89 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util
import org.signal.core.util.logging.Log
import kotlin.time.Duration.Companion.nanoseconds
import kotlin.time.DurationUnit
import kotlin.time.measureTimedValue
/**
* Simple utility to easily track the time a multi-step operation takes via splits.
*
* e.g.
*
* ```kotlin
* val stopwatch = Stopwatch("my-event")
* stopwatch.split("split-1")
* stopwatch.split("split-2")
* stopwatch.split("split-3")
* stopwatch.stop(TAG)
* ```
*/
class Stopwatch @JvmOverloads constructor(private val title: String, private val decimalPlaces: Int = 0) {
private val startTimeNanos: Long = System.nanoTime()
private val splits: MutableList<Split> = mutableListOf()
/**
* Create a new split between now and the last event.
*/
fun split(label: String) {
val now = System.nanoTime()
val previousTime = if (splits.isEmpty()) {
startTimeNanos
} else {
splits.last().nanoTime
}
splits += Split(
nanoTime = now,
durationNanos = now - previousTime,
label = label
)
}
/**
* Stops the stopwatch and logs the results with the provided tag.
*/
fun stop(tag: String) {
Log.d(tag, stopAndGetLogString())
}
/**
* Similar to [stop], but instead of logging directly, this will return the log string.
*/
fun stopAndGetLogString(): String {
val now = System.nanoTime()
splits += Split(
nanoTime = now,
durationNanos = now - startTimeNanos,
label = "total"
)
val splitString = splits
.joinToString(separator = ", ", transform = { it.displayString(decimalPlaces) })
return "[$title] $splitString"
}
private data class Split(val nanoTime: Long, val durationNanos: Long, val label: String) {
fun displayString(decimalPlaces: Int): String {
val timeMs: String = durationNanos.nanoseconds.toDouble(DurationUnit.MILLISECONDS).roundedString(decimalPlaces)
return "$label: $timeMs"
}
}
}
/**
* Logs how long it takes to perform the operation.
*/
inline fun <T> logTime(tag: String, label: String, decimalPlaces: Int = 0, block: () -> T): T {
val result = measureTimedValue(block)
Log.d(tag, "$label: ${result.duration.toDouble(DurationUnit.MILLISECONDS).roundedString(decimalPlaces)}")
return result.value
}

View file

@ -0,0 +1,124 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util;
import org.signal.core.util.logging.Log;
import java.io.ByteArrayOutputStream;
import java.io.Closeable;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
/**
* Utility methods for input and output streams.
*/
public final class StreamUtil {
private static final String TAG = Log.tag(StreamUtil.class);
private StreamUtil() {}
public static void close(Closeable closeable) {
if (closeable == null) return;
try {
closeable.close();
} catch (IOException e) {
Log.w(TAG, e);
}
}
public static long getStreamLength(InputStream in) throws IOException {
byte[] buffer = new byte[4096];
int totalSize = 0;
int read;
while ((read = in.read(buffer)) != -1) {
totalSize += read;
}
return totalSize;
}
public static void readFully(InputStream in, byte[] buffer) throws IOException {
readFully(in, buffer, buffer.length);
}
public static void readFully(InputStream in, byte[] buffer, int len) throws IOException {
int offset = 0;
for (;;) {
int read = in.read(buffer, offset, len - offset);
if (read == -1) throw new EOFException("Stream ended early, offset: " + offset + " len: " + len);
if (read + offset < len) offset += read;
else return;
}
}
public static byte[] readFully(InputStream in) throws IOException {
return readFully(in, Integer.MAX_VALUE);
}
public static byte[] readFully(InputStream in, int maxBytes) throws IOException {
return readFully(in, maxBytes, true);
}
public static byte[] readFully(InputStream in, int maxBytes, boolean closeWhenDone) throws IOException {
ByteArrayOutputStream bout = new ByteArrayOutputStream();
byte[] buffer = new byte[4096];
int totalRead = 0;
int read;
while ((read = in.read(buffer)) != -1) {
bout.write(buffer, 0, read);
totalRead += read;
if (totalRead > maxBytes) {
throw new IOException("Stream size limit exceeded");
}
}
if (closeWhenDone) {
in.close();
}
return bout.toByteArray();
}
public static String readFullyAsString(InputStream in) throws IOException {
return new String(readFully(in));
}
public static long copy(InputStream in, OutputStream out) throws IOException {
return copy(in, out, true, true);
}
public static long copy(InputStream in, OutputStream out, boolean closeInputStream, boolean closeOutputStream) throws IOException {
byte[] buffer = new byte[64 * 1024];
int read;
long total = 0;
while ((read = in.read(buffer)) != -1) {
out.write(buffer, 0, read);
total += read;
}
if (closeInputStream) {
in.close();
}
out.flush();
if (closeOutputStream) {
out.close();
}
return total;
}
}

View file

@ -0,0 +1,121 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util
import okio.utf8Size
import org.signal.core.util.logging.Log
import java.net.URLEncoder
import java.nio.ByteBuffer
import java.nio.CharBuffer
import java.nio.charset.StandardCharsets
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.contract
private const val TAG: String = "StringExtensions"
/**
* Treats the string as a serialized list of tokens and tells you if an item is present in the list.
* In addition to exact matches, this handles wildcards at the end of an item.
*
* e.g. a,b,c*,d
*/
fun String.asListContains(item: String): Boolean {
val items: List<String> = this
.split(",")
.map { it.trim() }
.filter { it.isNotEmpty() }
.toList()
val exactMatches = items.filter { it.last() != '*' }
val prefixMatches = items.filter { it.last() == '*' }
return exactMatches.contains(item) ||
prefixMatches
.map { it.substring(0, it.length - 1) }
.any { item.startsWith(it) }
}
fun String?.emptyIfNull(): String {
return this ?: ""
}
/**
* Turns a multi-line string into a single-line string stripped of indentation, separated by spaces instead of newlines.
*
* e.g.
*
* a
* b
* c
*
* turns into
*
* a b c
*/
fun String.toSingleLine(): String {
return this.trimIndent().split("\n").joinToString(separator = " ")
}
fun String?.nullIfEmpty(): String? {
return this?.ifEmpty {
null
}
}
fun String?.nullIfBlank(): String? {
return this?.ifBlank {
null
}
}
@OptIn(ExperimentalContracts::class)
fun CharSequence?.isNotNullOrBlank(): Boolean {
contract {
returns(true) implies (this@isNotNullOrBlank != null)
}
return !this.isNullOrBlank()
}
/**
* Encode this string in a url-safe way with UTF-8 encoding.
*/
fun String.urlEncode(): String {
return URLEncoder.encode(this, StandardCharsets.UTF_8.name())
}
/**
* Splits a string into two parts, such that the first part will be at most [byteLength] bytes long.
* The first item of the pair will be the shortened string, and the second item will be the remainder.
* Appending the two parts together will give you back the original string.
*
* If the input string is already less than [byteLength] bytes, the second item will be null.
*/
fun String.splitByByteLength(byteLength: Int): Pair<String, String?> {
if (this.utf8Size() <= byteLength) {
return this to null
}
val charBuffer = CharBuffer.wrap(this)
val encoder = Charsets.UTF_8.newEncoder()
val outputBuffer = ByteBuffer.allocate(byteLength)
encoder.encode(charBuffer, outputBuffer, true)
charBuffer.flip()
var firstPart = charBuffer.toString()
// Unfortunately some Android implementations will cause the charBuffer to go a step beyond what it should.
// It's always extremely close (in testing, only ever off by 1), but as a workaround, we chop off characters
// at the end until it fits. Bummer.
while (firstPart.utf8Size() > byteLength) {
Log.w(TAG, "Had to chop off a character to make it fit under the byte limit.")
firstPart = firstPart.substring(0, firstPart.length - 1)
}
val remainder = this.substring(firstPart.length)
return firstPart to remainder
}

View file

@ -0,0 +1,99 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util.concurrent;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
/**
* A future that allows you to have multiple ways to compute a result. If one fails, the calculation
* will fall back to the next in the list.
*
* You will only see a failure if the last attempt in the list fails.
*/
public final class CascadingFuture<T> implements ListenableFuture<T> {
private static final String TAG = CascadingFuture.class.getSimpleName();
private SettableFuture<T> result;
public CascadingFuture(List<Callable<ListenableFuture<T>>> callables, ExceptionChecker exceptionChecker) {
if (callables.isEmpty()) {
throw new IllegalArgumentException("Must have at least one callable!");
}
this.result = new SettableFuture<>();
doNext(new ArrayList<>(callables), exceptionChecker);
}
@Override
public boolean cancel(boolean mayInterruptIfRunning) {
return result.cancel(mayInterruptIfRunning);
}
@Override
public boolean isCancelled() {
return result.isCancelled();
}
@Override
public boolean isDone() {
return result.isDone();
}
@Override
public T get() throws ExecutionException, InterruptedException {
return result.get();
}
@Override
public T get(long timeout, TimeUnit unit) throws ExecutionException, InterruptedException, TimeoutException {
return result.get(timeout, unit);
}
@Override
public void addListener(Listener<T> listener) {
result.addListener(listener);
}
private void doNext(List<Callable<ListenableFuture<T>>> callables, ExceptionChecker exceptionChecker) {
Callable<ListenableFuture<T>> callable = callables.remove(0);
try {
ListenableFuture<T> future = callable.call();
future.addListener(new ListenableFuture.Listener<T>() {
@Override
public void onSuccess(T value) {
result.set(value);
}
@Override
public void onFailure(ExecutionException e) {
if (callables.isEmpty() || !exceptionChecker.shouldContinue(e)) {
result.setException(e.getCause());
} else if (!result.isCancelled()) {
doNext(callables, exceptionChecker);
}
}
});
} catch (Exception e) {
if (callables.isEmpty() || !exceptionChecker.shouldContinue(e)) {
result.setException(e.getCause());
} else if (!result.isCancelled()) {
doNext(callables, exceptionChecker);
}
}
}
public interface ExceptionChecker {
boolean shouldContinue(Exception e);
}
}

View file

@ -0,0 +1,78 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util.concurrent;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
/**
* Lets you perform a simple transform on the result of a future that maps it to a different value.
*/
class FutureMapTransformer<Input, Output> implements ListenableFuture<Output> {
private final ListenableFuture<Input> future;
private final FutureTransformers.Transformer<Input, Output> transformer;
FutureMapTransformer(ListenableFuture<Input> future, FutureTransformers.Transformer<Input, Output> transformer) {
this.future = future;
this.transformer = transformer;
}
@Override
public void addListener(Listener<Output> listener) {
future.addListener(new Listener<Input>() {
@Override
public void onSuccess(Input result) {
try {
listener.onSuccess(transformer.transform(result));
} catch (Exception e) {
listener.onFailure(new ExecutionException(e));
}
}
@Override
public void onFailure(ExecutionException e) {
listener.onFailure(e);
}
});
}
@Override
public boolean cancel(boolean b) {
return future.cancel(b);
}
@Override
public boolean isCancelled() {
return future.isCancelled();
}
@Override
public boolean isDone() {
return future.isDone();
}
@Override
public Output get() throws InterruptedException, ExecutionException {
Input input = future.get();
try {
return transformer.transform(input);
} catch (Exception e) {
throw new ExecutionException(e);
}
}
@Override
public Output get(long l, TimeUnit timeUnit) throws InterruptedException, ExecutionException, TimeoutException {
Input input = future.get(l, timeUnit);
try {
return transformer.transform(input);
} catch (Exception e) {
throw new ExecutionException(e);
}
}
}

View file

@ -0,0 +1,17 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util.concurrent;
public final class FutureTransformers {
public static <Input, Output> ListenableFuture<Output> map(ListenableFuture<Input> future, Transformer<Input, Output> transformer) {
return new FutureMapTransformer<>(future, transformer);
}
public interface Transformer<Input, Output> {
Output transform(Input a) throws Exception;
}
}

View file

@ -0,0 +1,30 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
@file:JvmName("JvmRxExtensions")
package org.signal.core.util.concurrent
import io.reactivex.rxjava3.core.Single
/**
* Throw an [InterruptedException] if a [Single.blockingGet] call is interrupted. This can
* happen when being called by code already within an Rx chain that is disposed.
*
* [Single.blockingGet] is considered harmful and should not be used.
*/
@Throws(InterruptedException::class)
fun <T : Any> Single<T>.safeBlockingGet(): T {
try {
return blockingGet()
} catch (e: RuntimeException) {
val cause = e.cause
if (cause is InterruptedException) {
throw cause
} else {
throw e
}
}
}

View file

@ -0,0 +1,18 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util.concurrent;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
public interface ListenableFuture<T> extends Future<T> {
void addListener(Listener<T> listener);
public interface Listener<T> {
public void onSuccess(T result);
public void onFailure(ExecutionException e);
}
}

View file

@ -0,0 +1,146 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util.concurrent;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
public class SettableFuture<T> implements ListenableFuture<T> {
private final List<Listener<T>> listeners = new LinkedList<>();
private boolean completed;
private boolean canceled;
private volatile T result;
private volatile Throwable exception;
public SettableFuture() { }
public SettableFuture(T value) {
this.result = value;
this.completed = true;
}
public SettableFuture(Throwable throwable) {
this.exception = throwable;
this.completed = true;
}
@Override
public synchronized boolean cancel(boolean mayInterruptIfRunning) {
if (!completed && !canceled) {
canceled = true;
return true;
}
return false;
}
@Override
public synchronized boolean isCancelled() {
return canceled;
}
@Override
public synchronized boolean isDone() {
return completed;
}
public boolean set(T result) {
synchronized (this) {
if (completed || canceled) return false;
this.result = result;
this.completed = true;
notifyAll();
}
notifyAllListeners();
return true;
}
public boolean setException(Throwable throwable) {
synchronized (this) {
if (completed || canceled) return false;
this.exception = throwable;
this.completed = true;
notifyAll();
}
notifyAllListeners();
return true;
}
public void deferTo(ListenableFuture<T> other) {
other.addListener(new Listener<T>() {
@Override
public void onSuccess(T result) {
SettableFuture.this.set(result);
}
@Override
public void onFailure(ExecutionException e) {
SettableFuture.this.setException(e.getCause());
}
});
}
@Override
public synchronized T get() throws InterruptedException, ExecutionException {
while (!completed) wait();
if (exception != null) throw new ExecutionException(exception);
else return result;
}
@Override
public synchronized T get(long timeout, TimeUnit unit)
throws InterruptedException, ExecutionException, TimeoutException
{
long startTime = System.currentTimeMillis();
while (!completed && System.currentTimeMillis() - startTime < unit.toMillis(timeout)) {
wait(unit.toMillis(timeout));
}
if (!completed) throw new TimeoutException();
else return get();
}
@Override
public void addListener(Listener<T> listener) {
synchronized (this) {
listeners.add(listener);
if (!completed) return;
}
notifyListener(listener);
}
private void notifyAllListeners() {
List<Listener<T>> localListeners;
synchronized (this) {
localListeners = new LinkedList<>(listeners);
}
for (Listener<T> listener : localListeners) {
notifyListener(listener);
}
}
private void notifyListener(Listener<T> listener) {
if (exception != null) listener.onFailure(new ExecutionException(exception));
else listener.onSuccess(result);
}
}

View file

@ -0,0 +1,48 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util.logging
/**
* A way to treat N loggers as one. Wraps a bunch of other loggers and forwards the method calls to
* all of them.
*/
internal class CompoundLogger(private val loggers: List<Log.Logger>) : Log.Logger() {
override fun v(tag: String, message: String?, t: Throwable?, keepLonger: Boolean) {
for (logger in loggers) {
logger.v(tag, message, t, keepLonger)
}
}
override fun d(tag: String, message: String?, t: Throwable?, keepLonger: Boolean) {
for (logger in loggers) {
logger.d(tag, message, t, keepLonger)
}
}
override fun i(tag: String, message: String?, t: Throwable?, keepLonger: Boolean) {
for (logger in loggers) {
logger.i(tag, message, t, keepLonger)
}
}
override fun w(tag: String, message: String?, t: Throwable?, keepLonger: Boolean) {
for (logger in loggers) {
logger.w(tag, message, t, keepLonger)
}
}
override fun e(tag: String, message: String?, t: Throwable?, keepLonger: Boolean) {
for (logger in loggers) {
logger.e(tag, message, t, keepLonger)
}
}
override fun flush() {
for (logger in loggers) {
logger.flush()
}
}
}

View file

@ -0,0 +1,175 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util.logging
import kotlin.reflect.KClass
object Log {
private val NOOP_LOGGER: Logger = NoopLogger()
private var internalCheck: InternalCheck? = null
private var logger: Logger = NoopLogger()
/**
* @param internalCheck A checker that will indicate if this is an internal user
* @param loggers A list of loggers that will be given every log statement.
*/
@JvmStatic
fun initialize(internalCheck: InternalCheck?, vararg loggers: Logger) {
Log.internalCheck = internalCheck
logger = CompoundLogger(loggers.toList())
}
@JvmStatic
fun initialize(vararg loggers: Logger) {
initialize({ false }, *loggers)
}
@JvmStatic
fun v(tag: String, message: String) = v(tag, message, null)
@JvmStatic
fun v(tag: String, t: Throwable?) = v(tag, null, t)
@JvmStatic
fun v(tag: String, message: String?, t: Throwable?) = v(tag, message, t, false)
@JvmStatic
fun v(tag: String, message: String?, keepLonger: Boolean) = v(tag, message, null, keepLonger)
@JvmStatic
fun v(tag: String, message: String?, t: Throwable?, keepLonger: Boolean) = logger.v(tag, message, t, keepLonger)
@JvmStatic
fun d(tag: String, message: String) = d(tag, message, null)
@JvmStatic
fun d(tag: String, t: Throwable?) = d(tag, null, t)
@JvmStatic
fun d(tag: String, message: String?, t: Throwable? = null) = d(tag, message, t, false)
@JvmStatic
fun d(tag: String, message: String?, keepLonger: Boolean) = d(tag, message, null, keepLonger)
@JvmStatic
fun d(tag: String, message: String?, t: Throwable?, keepLonger: Boolean) = logger.d(tag, message, t, keepLonger)
@JvmStatic
fun i(tag: String, message: String) = i(tag, message, null)
@JvmStatic
fun i(tag: String, t: Throwable?) = i(tag, null, t)
@JvmStatic
fun i(tag: String, message: String?, t: Throwable? = null) = i(tag, message, t, false)
@JvmStatic
fun i(tag: String, message: String?, keepLonger: Boolean) = i(tag, message, null, keepLonger)
@JvmStatic
fun i(tag: String, message: String?, t: Throwable?, keepLonger: Boolean) = logger.i(tag, message, t, keepLonger)
@JvmStatic
fun w(tag: String, message: String) = w(tag, message, null)
@JvmStatic
fun w(tag: String, t: Throwable?) = w(tag, null, t)
@JvmStatic
fun w(tag: String, message: String?, t: Throwable? = null) = w(tag, message, t, false)
@JvmStatic
fun w(tag: String, message: String?, keepLonger: Boolean) = logger.w(tag, message, keepLonger)
@JvmStatic
fun w(tag: String, message: String?, t: Throwable?, keepLonger: Boolean) {
logger.w(tag, message, t, keepLonger)
}
@JvmStatic
fun e(tag: String, message: String) = e(tag, message, null)
@JvmStatic
fun e(tag: String, t: Throwable?) = e(tag, null, t)
@JvmStatic
fun e(tag: String, message: String?, t: Throwable? = null) = e(tag, message, t, false)
@JvmStatic
fun e(tag: String, message: String?, keepLonger: Boolean) = e(tag, message, null, keepLonger)
@JvmStatic
fun e(tag: String, message: String?, t: Throwable?, keepLonger: Boolean) = logger.e(tag, message, t, keepLonger)
@JvmStatic
fun tag(clazz: KClass<*>): String {
return tag(clazz.java)
}
@JvmStatic
fun tag(clazz: Class<*>): String {
val simpleName = clazz.simpleName
return if (simpleName.length > 23) {
simpleName.substring(0, 23)
} else {
simpleName
}
}
/**
* Important: This is not something that can be used to log PII. Instead, it's intended use is for
* logs that might be too verbose or otherwise unnecessary for public users.
*
* @return The normal logger if this is an internal user, or a no-op logger if it isn't.
*/
@JvmStatic
fun internal(): Logger {
return if (internalCheck!!.isInternal()) {
logger
} else {
NOOP_LOGGER
}
}
@JvmStatic
fun blockUntilAllWritesFinished() {
logger.flush()
}
abstract class Logger {
abstract fun v(tag: String, message: String?, t: Throwable?, keepLonger: Boolean)
abstract fun d(tag: String, message: String?, t: Throwable?, keepLonger: Boolean)
abstract fun i(tag: String, message: String?, t: Throwable?, keepLonger: Boolean)
abstract fun w(tag: String, message: String?, t: Throwable?, keepLonger: Boolean)
abstract fun e(tag: String, message: String?, t: Throwable?, keepLonger: Boolean)
abstract fun flush()
fun v(tag: String, message: String?) = v(tag, message, null)
fun v(tag: String, message: String?, t: Throwable?) = v(tag, message, t, false)
fun v(tag: String, message: String?, keepLonger: Boolean) = v(tag, message, null, keepLonger)
fun d(tag: String, message: String?) = d(tag, message, null)
fun d(tag: String, message: String?, t: Throwable?) = d(tag, message, t, false)
fun d(tag: String, message: String?, keepLonger: Boolean) = d(tag, message, null, keepLonger)
fun i(tag: String, message: String?) = i(tag, message, null)
fun i(tag: String, message: String?, t: Throwable?) = i(tag, message, t, false)
fun i(tag: String, message: String?, keepLonger: Boolean) = i(tag, message, null, keepLonger)
fun w(tag: String, message: String?) = w(tag, message, null)
fun w(tag: String, message: String?, t: Throwable?) = w(tag, message, t, false)
fun w(tag: String, message: String?, keepLonger: Boolean) = w(tag, message, null, keepLonger)
fun e(tag: String, message: String?) = e(tag, message, null)
fun e(tag: String, message: String?, t: Throwable?) = e(tag, message, t, false)
fun e(tag: String, message: String?, keepLonger: Boolean) = e(tag, message, null, keepLonger)
}
fun interface InternalCheck {
fun isInternal(): Boolean
}
}

View file

@ -0,0 +1,46 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util.logging
/**
* Convenience method to replace `.also { Log.v(TAG, "message") }`
*/
fun <T> T.logV(tag: String, message: String, throwable: Throwable? = null): T {
Log.v(tag, message, throwable)
return this
}
/**
* Convenience method to replace `.also { Log.d(TAG, "message") }`
*/
fun <T> T.logD(tag: String, message: String, throwable: Throwable? = null): T {
Log.d(tag, message, throwable)
return this
}
/**
* Convenience method to replace `.also { Log.i(TAG, "message") }`
*/
fun <T> T.logI(tag: String, message: String, throwable: Throwable? = null): T {
Log.i(tag, message, throwable)
return this
}
/**
* Convenience method to replace `.also { Log.w(TAG, "message") }`
*/
fun <T> T.logW(tag: String, message: String, throwable: Throwable? = null, keepLonger: Boolean = false): T {
Log.w(tag, message, throwable, keepLonger)
return this
}
/**
* Convenience method to replace `.also { Log.e(TAG, "message") }`
*/
fun <T> T.logE(tag: String, message: String, throwable: Throwable? = null): T {
Log.e(tag, message, throwable)
return this
}

View file

@ -0,0 +1,18 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util.logging
/**
* A logger that does nothing.
*/
internal class NoopLogger : Log.Logger() {
override fun v(tag: String, message: String?, t: Throwable?, keepLonger: Boolean) = Unit
override fun d(tag: String, message: String?, t: Throwable?, keepLonger: Boolean) = Unit
override fun i(tag: String, message: String?, t: Throwable?, keepLonger: Boolean) = Unit
override fun w(tag: String, message: String?, t: Throwable?, keepLonger: Boolean) = Unit
override fun e(tag: String, message: String?, t: Throwable?, keepLonger: Boolean) = Unit
override fun flush() = Unit
}

View file

@ -0,0 +1,271 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util.logging
import org.signal.core.util.CryptoUtil
import org.signal.core.util.Hex
import org.signal.core.util.isNotNullOrBlank
import java.util.regex.Matcher
import java.util.regex.Pattern
/** Given a [Matcher], update the [StringBuilder] with the scrubbed output you want for a given match. */
private typealias MatchProcessor = (Matcher, StringBuilder) -> Unit
/**
* Scrub data for possibly sensitive information.
*/
object Scrubber {
/**
* The middle group will be censored.
* Supposedly, the shortest international phone numbers in use contain seven digits.
* Handles URL encoded +, %2B
*/
private val E164_PATTERN = Pattern.compile("(KEEP_E164::)?(\\+|%2B)(\\d{7,15})")
private val E164_ZERO_PATTERN = Pattern.compile("\\b(KEEP_E164::)?0(\\d{10})\\b")
/** The second group will be censored.*/
private val CRUDE_EMAIL_PATTERN = Pattern.compile("\\b([^\\s/,()])([^\\s/,()]*@[^\\s]+\\.[^\\s]+)")
private const val EMAIL_CENSOR = "...@..."
/** The middle group will be censored. */
private val GROUP_ID_V1_PATTERN = Pattern.compile("(__textsecure_group__!)([^\\s]+)([^\\s]{3})")
/** The middle group will be censored. */
private val GROUP_ID_V2_PATTERN = Pattern.compile("(__signal_group__v2__!)([^\\s]+)([^\\s]{3})")
/** The middle group will be censored. */
private val UUID_PATTERN = Pattern.compile("(JOB::)?([0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{9})([0-9a-f]{3})", Pattern.CASE_INSENSITIVE)
private const val UUID_CENSOR = "********-****-****-****-*********"
private val PNI_PATTERN = Pattern.compile("PNI:([0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{9}[0-9a-f]{3})", Pattern.CASE_INSENSITIVE)
/**
* The entire string is censored. Note: left as concatenated strings because kotlin string literals leave trailing newlines, and removing them breaks
* syntax highlighting.
*/
private val IPV4_PATTERN = Pattern.compile(
"\\b" +
"(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\\." +
"(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\\." +
"(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\\." +
"(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)" +
"\\b"
)
private const val IPV4_CENSOR = "...ipv4..."
/** The entire string is censored. */
private val IPV6_PATTERN = Pattern.compile("([0-9a-fA-F]{0,4}:){3,7}([0-9a-fA-F]){0,4}")
private const val IPV6_CENSOR = "...ipv6..."
/** The domain name and path except for TLD will be censored. */
private val URL_PATTERN = Pattern.compile("([a-z0-9]+\\.)+([a-z0-9\\-]*[a-z\\-][a-z0-9\\-]*)(/[/a-z0-9\\-_.~:@?&=#%+\\[\\]!$()*,;]*)?", Pattern.CASE_INSENSITIVE)
private const val URL_CENSOR = "***"
private val TOP_100_TLDS: Set<String> = setOf(
"com", "net", "org", "jp", "de", "uk", "fr", "br", "it", "ru", "es", "me", "gov", "pl", "ca", "au", "cn", "co", "in",
"nl", "edu", "info", "eu", "ch", "id", "at", "kr", "cz", "mx", "be", "tv", "se", "tr", "tw", "al", "ua", "ir", "vn",
"cl", "sk", "ly", "cc", "to", "no", "fi", "us", "pt", "dk", "ar", "hu", "tk", "gr", "il", "news", "ro", "my", "biz",
"ie", "za", "nz", "sg", "ee", "th", "io", "xyz", "pe", "bg", "hk", "lt", "link", "ph", "club", "si", "site",
"mobi", "by", "cat", "wiki", "la", "ga", "xxx", "cf", "hr", "ng", "jobs", "online", "kz", "ug", "gq", "ae", "is",
"lv", "pro", "fm", "tips", "ms", "sa", "app"
)
/** Base16 Call Link Key Pattern */
private val CALL_LINK_PATTERN = Pattern.compile("([bBcCdDfFgGhHkKmMnNpPqQrRsStTxXzZ]{4})(-[bBcCdDfFgGhHkKmMnNpPqQrRsStTxXzZ]{4}){7}")
private const val CALL_LINK_CENSOR_SUFFIX = "-XXXX-XXXX-XXXX-XXXX-XXXX-XXXX-XXXX"
private val CALL_LINK_ROOM_ID_PATTERN = Pattern.compile("([^/])([0-9a-f]{61})([0-9a-f]{3})")
@JvmStatic
@Volatile
var identifierHmacKeyProvider: () -> ByteArray? = { null }
@JvmStatic
@Volatile
private var identifierHmacKey: ByteArray? = null
@JvmStatic
fun scrub(input: CharSequence): CharSequence {
return input
.scrubE164()
.scrubE164Zero()
.scrubEmail()
.scrubGroupsV1()
.scrubGroupsV2()
.scrubPnis()
.scrubUuids()
.scrubUrls()
.scrubIpv4()
.scrubIpv6()
.scrubCallLinkKeys()
.scrubCallLinkRoomIds()
}
private fun CharSequence.scrubE164(): CharSequence {
return scrub(this, E164_PATTERN) { matcher, output ->
if (matcher.group(1) != null && matcher.group(1)!!.isNotEmpty()) {
output
.append("KEEP_E164::")
.append((matcher.group(2) + matcher.group(3)).censorMiddle(2, 2))
} else {
output
.append("E164:")
.append(hash(matcher.group(3)))
}
}
}
private fun CharSequence.scrubE164Zero(): CharSequence {
return scrub(this, E164_ZERO_PATTERN) { matcher, output ->
if (matcher.group(1) != null && matcher.group(1)!!.isNotEmpty()) {
output
.append("KEEP_E164::")
.append(("0" + matcher.group(2)).censorMiddle(2, 2))
} else {
output
.append("E164:")
.append(hash(matcher.group(2)))
}
}
}
private fun CharSequence.scrubEmail(): CharSequence {
return scrub(this, CRUDE_EMAIL_PATTERN) { matcher, output ->
output
.append(matcher.group(1))
.append(EMAIL_CENSOR)
}
}
private fun CharSequence.scrubGroupsV1(): CharSequence {
return scrub(this, GROUP_ID_V1_PATTERN) { matcher, output ->
output
.append("GV1::***")
.append(matcher.group(3))
}
}
private fun CharSequence.scrubGroupsV2(): CharSequence {
return scrub(this, GROUP_ID_V2_PATTERN) { matcher, output ->
output
.append("GV2::***")
.append(matcher.group(3))
}
}
private fun CharSequence.scrubPnis(): CharSequence {
return scrub(this, PNI_PATTERN) { matcher, output ->
output
.append("PNI:")
.append(hash(matcher.group(1)))
}
}
private fun CharSequence.scrubUuids(): CharSequence {
return scrub(this, UUID_PATTERN) { matcher, output ->
if (matcher.group(1) != null && matcher.group(1)!!.isNotEmpty()) {
output
.append(matcher.group(1))
.append(matcher.group(2))
.append(matcher.group(3))
} else {
output
.append(UUID_CENSOR)
.append(matcher.group(3))
}
}
}
private fun CharSequence.scrubUrls(): CharSequence {
return scrub(this, URL_PATTERN) { matcher, output ->
val match: String = matcher.group(0)!!
if (
(matcher.groupCount() == 2 || matcher.groupCount() == 3) &&
TOP_100_TLDS.contains(matcher.group(2)!!.lowercase()) &&
!(matcher.group(1).endsWith("signal.") && matcher.group(2) == "org" && !match.contains("cdn")) &&
!(matcher.group(1).endsWith("debuglogs.") && matcher.group(2) == "org")
) {
output
.append(URL_CENSOR)
.append(".")
.append(matcher.group(2))
.run {
if (matcher.groupCount() == 3 && matcher.group(3).isNotNullOrBlank()) {
append("/")
append(URL_CENSOR)
}
}
} else {
output.append(match)
}
}
}
private fun CharSequence.scrubIpv4(): CharSequence {
return scrub(this, IPV4_PATTERN) { _, output -> output.append(IPV4_CENSOR) }
}
private fun CharSequence.scrubIpv6(): CharSequence {
return scrub(this, IPV6_PATTERN) { _, output -> output.append(IPV6_CENSOR) }
}
private fun CharSequence.scrubCallLinkKeys(): CharSequence {
return scrub(this, CALL_LINK_PATTERN) { matcher, output ->
val match = matcher.group(1)
output
.append(match)
.append(CALL_LINK_CENSOR_SUFFIX)
}
}
private fun CharSequence.scrubCallLinkRoomIds(): CharSequence {
return scrub(this, CALL_LINK_ROOM_ID_PATTERN) { matcher, output ->
output
.append(matcher.group(1))
.append("*************************************************************")
.append(matcher.group(3))
}
}
private fun String.censorMiddle(leading: Int, trailing: Int): String {
val totalKept = leading + trailing
if (this.length < totalKept) {
return "*".repeat(this.length)
}
val middle = "*".repeat(this.length - totalKept)
return this.take(leading) + middle + this.takeLast(trailing)
}
private fun scrub(input: CharSequence, pattern: Pattern, processMatch: MatchProcessor): CharSequence {
val output = StringBuilder(input.length)
val matcher: Matcher = pattern.matcher(input)
var lastEndingPos = 0
while (matcher.find()) {
output.append(input, lastEndingPos, matcher.start())
processMatch(matcher, output)
lastEndingPos = matcher.end()
}
return if (lastEndingPos == 0) {
// there were no matches, save copying all the data
input
} else {
output.append(input, lastEndingPos, input.length)
output
}
}
private fun hash(value: String): String {
if (identifierHmacKey == null) {
identifierHmacKey = identifierHmacKeyProvider()
}
val key: ByteArray = identifierHmacKey ?: return "<redacted>"
val hash = CryptoUtil.hmacSha256(key, value.toByteArray())
return "<${Hex.toStringCondensed(hash).take(5)}>"
}
}

View file

@ -0,0 +1,154 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util.stream
import org.signal.core.util.logging.Log
import java.io.FilterInputStream
import java.io.InputStream
import java.lang.UnsupportedOperationException
import kotlin.math.min
/**
* An [InputStream] that will read from the target [InputStream] until it reaches the end, or until it has read [maxBytes] bytes.
*
* @param maxBytes The maximum number of bytes to read from the stream. If set to -1, there will be no limit.
*/
class LimitedInputStream(private val wrapped: InputStream, private val maxBytes: Long) : FilterInputStream(wrapped) {
private var totalBytesRead: Long = 0
private var lastMark = -1L
companion object {
private const val UNLIMITED = -1L
private val TAG = Log.tag(LimitedInputStream::class)
/**
* Returns a [LimitedInputStream] that doesn't limit the stream at all -- it'll allow reading the full thing.
*/
@JvmStatic
fun withoutLimits(wrapped: InputStream): LimitedInputStream {
return LimitedInputStream(wrapped = wrapped, maxBytes = UNLIMITED)
}
}
override fun read(): Int {
if (maxBytes == UNLIMITED) {
return wrapped.read()
}
if (totalBytesRead >= maxBytes) {
return -1
}
val read = wrapped.read()
if (read >= 0) {
totalBytesRead++
}
return read
}
override fun read(destination: ByteArray): Int {
return read(destination, 0, destination.size)
}
override fun read(destination: ByteArray, offset: Int, length: Int): Int {
if (maxBytes == UNLIMITED) {
return wrapped.read(destination, offset, length)
}
if (totalBytesRead >= maxBytes) {
return -1
}
val bytesRemaining: Long = maxBytes - totalBytesRead
val bytesToRead: Int = min(length, Math.toIntExact(bytesRemaining))
val bytesRead = wrapped.read(destination, offset, bytesToRead)
if (bytesRead > 0) {
totalBytesRead += bytesRead
}
return bytesRead
}
override fun skip(requestedSkipCount: Long): Long {
if (maxBytes == UNLIMITED) {
return wrapped.skip(requestedSkipCount)
}
val bytesRemaining: Long = maxBytes - totalBytesRead
val bytesToSkip: Long = min(bytesRemaining, requestedSkipCount)
val skipCount = super.skip(bytesToSkip)
totalBytesRead += skipCount
return skipCount
}
override fun available(): Int {
if (maxBytes == UNLIMITED) {
return wrapped.available()
}
val bytesRemaining = Math.toIntExact(maxBytes - totalBytesRead)
return min(bytesRemaining, wrapped.available())
}
override fun markSupported(): Boolean {
return wrapped.markSupported()
}
override fun mark(readlimit: Int) {
if (!markSupported()) {
throw UnsupportedOperationException("Mark not supported")
}
wrapped.mark(readlimit)
if (maxBytes == UNLIMITED) {
return
}
lastMark = totalBytesRead
}
override fun reset() {
if (!markSupported()) {
throw UnsupportedOperationException("Mark not supported")
}
if (lastMark == UNLIMITED) {
throw UnsupportedOperationException("Mark not set")
}
wrapped.reset()
if (maxBytes == UNLIMITED) {
return
}
totalBytesRead = lastMark
}
/**
* If the stream has been fully read, this will return a stream that contains the remaining bytes that were truncated.
* If the stream was setup with no limit, this will always return an empty stream.
*/
fun leftoverStream(): InputStream {
if (maxBytes == UNLIMITED) {
return ByteArray(0).inputStream()
}
if (totalBytesRead < maxBytes) {
Log.w(TAG, "Reading leftover stream when the stream has not been fully read! maxBytes is $maxBytes, but we've only read $totalBytesRead")
}
return wrapped
}
}

View file

@ -0,0 +1,43 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util.stream
import java.io.FilterInputStream
import java.io.InputStream
import javax.crypto.Mac
/**
* Calculates a [Mac] as data is read from the target [InputStream].
* To get the final MAC, read the [mac] property after the stream has been fully read.
*
* Example:
* ```kotlin
* val stream = MacInputStream(myStream, myMac)
* stream.readFully()
* val mac = stream.mac.doFinal()
* ```
*/
class MacInputStream(val wrapped: InputStream, val mac: Mac) : FilterInputStream(wrapped) {
override fun read(): Int {
return wrapped.read().also { byte ->
if (byte >= 0) {
mac.update(byte.toByte())
}
}
}
override fun read(destination: ByteArray): Int {
return read(destination, 0, destination.size)
}
override fun read(destination: ByteArray, offset: Int, length: Int): Int {
return wrapped.read(destination, offset, length).also { bytesRead ->
if (bytesRead > 0) {
mac.update(destination, offset, bytesRead)
}
}
}
}

View file

@ -0,0 +1,37 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util.stream
import java.io.FilterOutputStream
import java.io.OutputStream
import javax.crypto.Mac
/**
* Calculates a [Mac] as data is written to the target [OutputStream].
* To get the final MAC, read the [mac] property after the stream has been fully written.
*
* Example:
* ```kotlin
* val stream = MacOutputStream(myStream, myMac)
* // write data to stream
* val mac = stream.mac.doFinal()
* ```
*/
class MacOutputStream(val wrapped: OutputStream, val mac: Mac) : FilterOutputStream(wrapped) {
override fun write(byte: Int) {
wrapped.write(byte)
mac.update(byte.toByte())
}
override fun write(data: ByteArray) {
write(data, 0, data.size)
}
override fun write(data: ByteArray, offset: Int, length: Int) {
wrapped.write(data, offset, length)
mac.update(data, offset, length)
}
}

View file

@ -0,0 +1,19 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util.stream
import java.io.FilterOutputStream
import java.io.OutputStream
/**
* Wraps a provided [OutputStream] but ignores calls to [OutputStream.close] on it but will call [OutputStream.flush] just in case.
* Wrappers must call [OutputStream.close] on the passed in [wrap] stream directly.
*/
class NonClosingOutputStream(wrap: OutputStream) : FilterOutputStream(wrap) {
override fun close() {
flush()
}
}

View file

@ -0,0 +1,17 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util.stream
import java.io.OutputStream
/**
* An output stream that drops all data on the floor. Basically piping to /dev/null.
*/
object NullOutputStream : OutputStream() {
override fun write(b: Int) = Unit
override fun write(b: ByteArray?) = Unit
override fun write(b: ByteArray?, off: Int, len: Int) = Unit
}

View file

@ -0,0 +1,87 @@
package org.signal.core.util.stream
import org.signal.core.util.logging.Log
import java.io.FilterInputStream
import java.io.IOException
import java.io.InputStream
/**
* Input stream that reads a file that is actively being written to.
* Will read or wait to read (for the bytes to be available) until it reaches the end [bytesLength]
* A use case is streamable video where we want to play the video while the file is still downloading
*/
class TailerInputStream(private val streamFactory: StreamFactory, private val bytesLength: Long) : FilterInputStream(streamFactory.openStream()) {
private val TAG = Log.tag(TailerInputStream::class)
/** Tracks where we are in the file */
private var position: Long = 0
private var currentStream: InputStream
get() = this.`in`
set(input) {
this.`in` = input
}
override fun skip(requestedSkipCount: Long): Long {
val bytesSkipped = this.currentStream.skip(requestedSkipCount)
this.position += bytesSkipped
return bytesSkipped
}
override fun read(): Int {
val bytes = ByteArray(1)
var result = this.read(bytes)
while (result == 0) {
result = this.read(bytes)
}
if (result == -1) {
return result
}
return bytes[0].toInt() and 0xFF
}
override fun read(destination: ByteArray): Int {
return this.read(destination = destination, offset = 0, length = destination.size)
}
override fun read(destination: ByteArray, offset: Int, length: Int): Int {
// Checking if we reached the end of the file (bytesLength)
if (position >= bytesLength) {
return -1
}
var bytesRead = this.currentStream.read(destination, offset, length)
// If we haven't read any bytes, but we aren't at the end of the file,
// we close the stream, wait, and then try again
while (bytesRead < 0 && position < bytesLength) {
this.currentStream.close()
try {
Thread.sleep(100)
} catch (e: InterruptedException) {
Log.w(TAG, "Ignoring interrupted exception while waiting for input stream", e)
}
this.currentStream = streamFactory.openStream()
// After reopening the file, we skip to the position we were at last time
this.currentStream.skip(this.position)
bytesRead = this.currentStream.read(destination, offset, length)
}
// Update current position with bytes read
if (bytesRead > 0) {
position += bytesRead
}
return bytesRead
}
}
fun interface StreamFactory {
@Throws(IOException::class)
fun openStream(): InputStream
}

View file

@ -0,0 +1,148 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util.stream
import org.signal.core.util.drain
import java.io.FilterInputStream
import java.io.IOException
import java.io.InputStream
import kotlin.math.min
/**
* An input stream that will read all but the last [trimSize] bytes of the stream.
*
* Important: we have to keep a buffer of size [trimSize] to ensure that we can avoid reading it.
* That means you should avoid using this for very large values of [trimSize].
*
* @param drain If true, the stream will be drained when it reaches the end (but bytes won't be returned). This is useful for ensuring that the underlying
* stream is fully consumed.
*/
class TrimmingInputStream(
private val inputStream: InputStream,
private val trimSize: Int,
private val drain: Boolean = false
) : FilterInputStream(inputStream) {
private val trimBuffer = ByteArray(trimSize)
private var trimBufferSize: Int = 0
private var streamEnded = false
private var hasDrained = false
private var internalBuffer = ByteArray(4096)
private var internalBufferPosition: Int = 0
private var internalBufferSize: Int = 0
@Throws(IOException::class)
override fun read(): Int {
val singleByteBuffer = ByteArray(1)
val bytesRead = read(singleByteBuffer, 0, 1)
return if (bytesRead == -1) {
-1
} else {
singleByteBuffer[0].toInt() and 0xFF
}
}
@Throws(IOException::class)
override fun read(b: ByteArray): Int {
return read(b, 0, b.size)
}
/**
* The general strategy is that we do bulk reads into an internal buffer (just for perf reasons), and then when new bytes are requested,
* we fill up a buffer of size [trimSize] with the most recent bytes, and then return the oldest byte from that buffer.
*
* This ensures that the last [trimSize] bytes are never returned, while still returning the rest of the bytes.
*
* When we hit the end of the stream, we stop returning bytes.
*/
@Throws(IOException::class)
override fun read(outputBuffer: ByteArray, outputOffset: Int, readLength: Int): Int {
if (streamEnded) {
return -1
}
if (trimSize == 0) {
return super.read(outputBuffer, outputOffset, readLength)
}
var outputCount = 0
while (outputCount < readLength) {
val nextByte = readNextByte()
if (nextByte == -1) {
streamEnded = true
drainIfNecessary()
break
}
if (trimBufferSize < trimSize) {
// Still filling the buffer - can't output anything yet
trimBuffer[trimBufferSize] = nextByte.toByte()
trimBufferSize++
} else {
// Buffer is full - output the oldest byte and add the new one
outputBuffer[outputOffset + outputCount] = trimBuffer[0]
outputCount++
// Shift buffer left and add new byte at the end. In practice, this is a tiny array and copies should be fast.
System.arraycopy(trimBuffer, 1, trimBuffer, 0, trimSize - 1)
trimBuffer[trimSize - 1] = nextByte.toByte()
}
}
return if (outputCount == 0) {
drainIfNecessary()
-1
} else {
outputCount
}
}
@Throws(IOException::class)
override fun skip(skipCount: Long): Long {
if (skipCount <= 0) return 0
var totalSkipped = 0L
val buffer = ByteArray(8192)
while (totalSkipped < skipCount) {
val toRead = min((skipCount - totalSkipped).toInt(), buffer.size)
val bytesRead = read(buffer, 0, toRead)
if (bytesRead == -1) {
break
}
totalSkipped += bytesRead
}
return totalSkipped
}
private fun readNextByte(): Int {
val hitEndOfStream = if (internalBufferPosition >= internalBufferSize) {
internalBufferPosition = 0
internalBufferSize = super.read(internalBuffer, 0, internalBuffer.size)
internalBufferSize <= 0
} else {
false
}
if (hitEndOfStream) {
drainIfNecessary()
return -1
}
return internalBuffer[internalBufferPosition++].toInt() and 0xFF
}
private fun drainIfNecessary() {
if (drain && !hasDrained) {
inputStream.drain()
hasDrained = true
}
}
}

View file

@ -0,0 +1,30 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util.test
import kotlin.reflect.full.memberProperties
import kotlin.reflect.jvm.isAccessible
/**
* Returns a string containing the differences between the expected and actual objects.
* Useful for diffing complex data classes in your tests.
*/
inline fun <reified T : Any> getObjectDiff(expected: T, actual: T): String {
val builder = StringBuilder()
val properties = T::class.memberProperties
for (prop in properties) {
prop.isAccessible = true
val expectedValue = prop.get(expected)
val actualValue = prop.get(actual)
if (expectedValue != actualValue) {
builder.append("[${prop.name}] Expected: $expectedValue, Actual: $actualValue\n")
}
}
return builder.toString()
}

View file

@ -0,0 +1,36 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util
import org.junit.Assert.assertArrayEquals
import org.junit.Test
import kotlin.random.Random
class Base64Test {
@Test
fun `decode - correctly decode all strings regardless of url safety or padding`() {
val stopwatch = Stopwatch("time", 2)
for (len in 0 until 256) {
for (i in 0..2_000) {
val bytes = Random.nextBytes(len)
val padded = Base64.encodeWithPadding(bytes)
val unpadded = Base64.encodeWithoutPadding(bytes)
val urlSafePadded = Base64.encodeUrlSafeWithPadding(bytes)
val urlSafeUnpadded = Base64.encodeUrlSafeWithoutPadding(bytes)
assertArrayEquals(bytes, Base64.decode(padded))
assertArrayEquals(bytes, Base64.decode(unpadded))
assertArrayEquals(bytes, Base64.decode(urlSafePadded))
assertArrayEquals(bytes, Base64.decode(urlSafeUnpadded))
}
}
println(stopwatch.stopAndGetLogString())
}
}

View file

@ -0,0 +1,179 @@
/*
* Copyright 2025 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util
import org.junit.Assert
import org.junit.Test
class E164UtilTest {
@Test
fun `formatAsE164WithCountryCodeForDisplay - generic`() {
// UK
Assert.assertEquals("+442079460018", E164Util.formatAsE164WithCountryCodeForDisplay("44", "(020) 7946 0018"))
Assert.assertEquals("+442079460018", E164Util.formatAsE164WithCountryCodeForDisplay("44", "+442079460018"))
// CH
Assert.assertEquals("+41446681800", E164Util.formatAsE164WithCountryCodeForDisplay("41", "+41 44 668 18 00"))
Assert.assertEquals("+41446681800", E164Util.formatAsE164WithCountryCodeForDisplay("41", "+41 (044) 6681800"))
// DE
Assert.assertEquals("+4930123456", E164Util.formatAsE164WithCountryCodeForDisplay("49", "0049 030 123456"))
Assert.assertEquals("+4930123456", E164Util.formatAsE164WithCountryCodeForDisplay("49", "0049 (0)30123456"))
Assert.assertEquals("+4930123456", E164Util.formatAsE164WithCountryCodeForDisplay("49", "0049((0)30)123456"))
Assert.assertEquals("+4930123456", E164Util.formatAsE164WithCountryCodeForDisplay("49", "+49 (0) 30 1 2 3 45 6 "))
Assert.assertEquals("+4930123456", E164Util.formatAsE164WithCountryCodeForDisplay("49", "030 123456"))
// DE
Assert.assertEquals("+49171123456", E164Util.formatAsE164WithCountryCodeForDisplay("49", "0171123456"))
Assert.assertEquals("+49171123456", E164Util.formatAsE164WithCountryCodeForDisplay("49", "0171/123456"))
Assert.assertEquals("+49171123456", E164Util.formatAsE164WithCountryCodeForDisplay("49", "+490171/123456"))
Assert.assertEquals("+49171123456", E164Util.formatAsE164WithCountryCodeForDisplay("49", "00490171/123456"))
Assert.assertEquals("+49171123456", E164Util.formatAsE164WithCountryCodeForDisplay("49", "0049171/123456"))
}
@Test
fun `formatAsE164 - generic`() {
var formatter: E164Util.Formatter = E164Util.createFormatterForE164("+14152222222")
Assert.assertEquals("+14151111122", formatter.formatAsE164("(415) 111-1122"))
Assert.assertEquals("+14151111123", formatter.formatAsE164("(415) 111 1123"))
Assert.assertEquals("+14151111124", formatter.formatAsE164("415-111-1124"))
Assert.assertEquals("+14151111125", formatter.formatAsE164("415.111.1125"))
Assert.assertEquals("+14151111126", formatter.formatAsE164("+1 415.111.1126"))
Assert.assertEquals("+14151111127", formatter.formatAsE164("+1 415 111 1127"))
Assert.assertEquals("+14151111128", formatter.formatAsE164("+1 (415) 111 1128"))
Assert.assertEquals("911", formatter.formatAsE164("911"))
Assert.assertEquals("+4567890", formatter.formatAsE164("+456-7890"))
formatter = E164Util.createFormatterForE164("+442079460010")
Assert.assertEquals("+442079460018", formatter.formatAsE164("(020) 7946 0018"))
}
@Test
fun `formatAsE164 - strip leading zeros`() {
var formatter: E164Util.Formatter = E164Util.createFormatterForE164("+14152222222")
Assert.assertEquals("+15551234567", formatter.formatAsE164("+015551234567"))
Assert.assertEquals("+15551234567", formatter.formatAsE164("+0015551234567"))
Assert.assertEquals("+15551234567", formatter.formatAsE164("01115551234567"))
Assert.assertEquals("1234", formatter.formatAsE164("01234"))
Assert.assertEquals(null, formatter.formatAsE164("0"))
Assert.assertEquals(null, formatter.formatAsE164("0000000"))
Assert.assertEquals("12345", formatter.formatAsE164("012345"))
formatter = E164Util.createFormatterForE164("+49 1234 567890")
Assert.assertEquals("+491234567890", formatter.formatAsE164("+0491234567890"))
Assert.assertEquals("+4912345", formatter.formatAsE164("+04912345"))
Assert.assertEquals("+491601234567", formatter.formatAsE164("+0491601234567"))
}
@Test
fun `formatAsE164 - US mix`() {
val formatter: E164Util.Formatter = E164Util.createFormatterForE164("+16105880522")
Assert.assertEquals("+551234567890", formatter.formatAsE164("+551234567890"))
Assert.assertEquals("+11234567890", formatter.formatAsE164("(123) 456-7890"))
Assert.assertEquals("+11234567890", formatter.formatAsE164("1234567890"))
Assert.assertEquals("+16104567890", formatter.formatAsE164("456-7890"))
Assert.assertEquals("+16104567890", formatter.formatAsE164("4567890"))
Assert.assertEquals("+11234567890", formatter.formatAsE164("011 1 123 456 7890"))
Assert.assertEquals("+5511912345678", formatter.formatAsE164("0115511912345678"))
Assert.assertEquals("+16105880522", formatter.formatAsE164("+16105880522"))
}
@Test
fun `formatAsE164 - Brazil mix`() {
val formatter: E164Util.Formatter = E164Util.createFormatterForE164("+5521912345678")
Assert.assertEquals("+16105880522", formatter.formatAsE164("+16105880522"))
Assert.assertEquals("+552187654321", formatter.formatAsE164("8765 4321"))
Assert.assertEquals("+5521987654321", formatter.formatAsE164("9 8765 4321"))
Assert.assertEquals("+552287654321", formatter.formatAsE164("22 8765 4321"))
Assert.assertEquals("+5522987654321", formatter.formatAsE164("22 9 8765 4321"))
Assert.assertEquals("+551234567890", formatter.formatAsE164("+55 (123) 456-7890"))
Assert.assertEquals("+14085048577", formatter.formatAsE164("002214085048577"))
Assert.assertEquals("+5511912345678", formatter.formatAsE164("011912345678"))
Assert.assertEquals("+5511912345678", formatter.formatAsE164("02111912345678"))
Assert.assertEquals("+551234567", formatter.formatAsE164("1234567"))
Assert.assertEquals("+5521912345678", formatter.formatAsE164("+5521912345678"))
Assert.assertEquals("+552112345678", formatter.formatAsE164("+552112345678"))
}
@Test
fun `formatAsE164 - short codes`() {
var formatter: E164Util.Formatter = E164Util.createFormatterForE164("+14152222222")
Assert.assertEquals("40404", formatter.formatAsE164("+40404"))
Assert.assertEquals("404040", formatter.formatAsE164("+404040"))
Assert.assertEquals("404040", formatter.formatAsE164("404040"))
Assert.assertEquals("49173", formatter.formatAsE164("+49173"))
Assert.assertEquals("7726", formatter.formatAsE164("+7726"))
Assert.assertEquals("69987", formatter.formatAsE164("+69987"))
Assert.assertEquals("40404", formatter.formatAsE164("40404"))
Assert.assertEquals("7726", formatter.formatAsE164("7726"))
Assert.assertEquals("22000", formatter.formatAsE164("22000"))
Assert.assertEquals("265080", formatter.formatAsE164("265080"))
Assert.assertEquals("32665", formatter.formatAsE164("32665"))
Assert.assertEquals("732873", formatter.formatAsE164("732873"))
Assert.assertEquals("73822", formatter.formatAsE164("73822"))
Assert.assertEquals("83547", formatter.formatAsE164("83547"))
Assert.assertEquals("84639", formatter.formatAsE164("84639"))
Assert.assertEquals("89887", formatter.formatAsE164("89887"))
Assert.assertEquals("99000", formatter.formatAsE164("99000"))
Assert.assertEquals("911", formatter.formatAsE164("911"))
Assert.assertEquals("112", formatter.formatAsE164("112"))
Assert.assertEquals("311", formatter.formatAsE164("311"))
Assert.assertEquals("611", formatter.formatAsE164("611"))
Assert.assertEquals("988", formatter.formatAsE164("988"))
Assert.assertEquals("999", formatter.formatAsE164("999"))
Assert.assertEquals("118", formatter.formatAsE164("118"))
Assert.assertEquals("20202", formatter.formatAsE164("020202"))
Assert.assertEquals("+119990001", formatter.formatAsE164("19990001"))
formatter = E164Util.createFormatterForE164("+61 2 9876 5432")
Assert.assertEquals("19990001", formatter.formatAsE164("19990001"))
}
@Test
fun `formatAsE164 - invalid`() {
val formatter: E164Util.Formatter = E164Util.createFormatterForE164("+14152222222")
Assert.assertEquals(null, formatter.formatAsE164("junk@junk.net"))
Assert.assertEquals(null, formatter.formatAsE164("__textsecure_group__!foobar"))
Assert.assertEquals(null, formatter.formatAsE164("bonbon"))
Assert.assertEquals(null, formatter.formatAsE164("44444444441234512312312312312312312312"))
Assert.assertEquals(null, formatter.formatAsE164("144444444441234512312312312312312312312"))
Assert.assertEquals(null, formatter.formatAsE164("1"))
Assert.assertEquals(null, formatter.formatAsE164("55"))
Assert.assertEquals(null, formatter.formatAsE164("0"))
Assert.assertEquals(null, formatter.formatAsE164("000"))
Assert.assertEquals(null, formatter.formatAsE164("+1555ABC4567"))
}
@Test
fun `formatAsE164 - no local number`() {
val formatter: E164Util.Formatter = E164Util.createFormatterForRegionCode("US")
Assert.assertEquals("+14151111122", formatter.formatAsE164("(415) 111-1122"))
}
@Test
fun `isValidShortNumber - multiple regions`() {
// India
var formatter: E164Util.Formatter = E164Util.createFormatterForE164("+911234567890")
Assert.assertTrue(formatter.isValidShortNumber("543212601"))
Assert.assertTrue(formatter.isValidShortNumber("+543212601"))
Assert.assertFalse(formatter.isValidShortNumber("1234567890"))
// Australia
formatter = E164Util.createFormatterForE164("+61111111111")
Assert.assertTrue(formatter.isValidShortNumber("1258881"))
Assert.assertTrue(formatter.isValidShortNumber("+1258881"))
Assert.assertFalse(formatter.isValidShortNumber("+111111111"))
// US
formatter = E164Util.createFormatterForE164("+15555555555")
Assert.assertTrue(formatter.isValidShortNumber("125811"))
Assert.assertTrue(formatter.isValidShortNumber("+121581"))
Assert.assertFalse(formatter.isValidShortNumber("+15555555555"))
}
}

View file

@ -0,0 +1,79 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.toList
import kotlinx.coroutines.test.runTest
import org.junit.Assert.assertEquals
import org.junit.Test
import kotlin.time.Duration.Companion.milliseconds
class FlowExtensionsTests {
@Test
fun `throttleLatest - always emits first value`() = runTest {
val testFlow = flow {
delay(10)
emit(1)
}
val output = testFlow
.throttleLatest(100.milliseconds)
.toList()
assertEquals(listOf(1), output)
}
@Test
fun `throttleLatest - always emits last value`() = runTest {
val testFlow = flow {
delay(10)
emit(1)
delay(30)
emit(2)
}
val output = testFlow
.throttleLatest(20.milliseconds)
.toList()
assertEquals(listOf(1, 2), output)
}
@Test
fun `throttleLatest - skips intermediate values`() = runTest {
val testFlow = flow {
for (i in 1..30) {
emit(i)
delay(10)
}
}
val output = testFlow
.throttleLatest(50.milliseconds)
.toList()
assertEquals(listOf(1, 5, 10, 15, 20, 25, 30), output)
}
@Test
fun `throttleLatest - respects skipThrottle`() = runTest {
val testFlow = flow {
for (i in 1..30) {
emit(i)
delay(10)
}
}
val output = testFlow
.throttleLatest(50.milliseconds) { it in setOf(2, 3, 4, 26, 27, 28) }
.toList()
assertEquals(listOf(1, 2, 3, 4, 5, 10, 15, 20, 25, 26, 27, 28, 30), output)
}
}

View file

@ -0,0 +1,88 @@
/*
* Copyright 2024 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util
import org.junit.Assert.assertEquals
import org.junit.Assert.assertFalse
import org.junit.Assert.assertTrue
import org.junit.Test
import java.util.concurrent.CountDownLatch
class ResettableLazyTests {
@Test
fun `value only computed once`() {
var counter = 0
val lazy: Int by resettableLazy {
counter++
}
assertEquals(0, lazy)
assertEquals(0, lazy)
assertEquals(0, lazy)
}
@Test
fun `value recomputed after a reset`() {
var counter = 0
val _lazy = resettableLazy {
counter++
}
val lazy by _lazy
assertEquals(0, lazy)
_lazy.reset()
assertEquals(1, lazy)
_lazy.reset()
assertEquals(2, lazy)
}
@Test
fun `isInitialized - general`() {
val _lazy = resettableLazy { 1 }
val lazy: Int by _lazy
assertFalse(_lazy.isInitialized())
val x = lazy + 1
assertEquals(2, x)
assertTrue(_lazy.isInitialized())
_lazy.reset()
assertFalse(_lazy.isInitialized())
}
/**
* I've verified that without the synchronization inside of resettableLazy, this test usually fails.
*/
@Test
fun `ensure synchronization works`() {
val numRounds = 100
val numThreads = 5
for (i in 1..numRounds) {
var counter = 0
val lazy: Int by resettableLazy {
counter++
}
val latch = CountDownLatch(numThreads)
for (j in 1..numThreads) {
Thread {
val x = lazy + 1
latch.countDown()
}.start()
}
latch.await()
assertEquals(1, counter)
}
}
}

View file

@ -0,0 +1,75 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util
import org.junit.Assert.assertEquals
import org.junit.Ignore
import org.junit.Test
import java.io.ByteArrayOutputStream
import java.util.concurrent.atomic.AtomicInteger
import kotlin.random.Random
class VarInt32Tests {
/**
* Tests a random sampling of integers. The faster and more practical version of [testAll].
*/
@Test
fun testRandomSampling() {
val randomInts = (0..100_000).map { Random.nextInt() }
val bytes = ByteArrayOutputStream().use { outputStream ->
for (value in randomInts) {
outputStream.writeVarInt32(value)
}
outputStream
}.toByteArray()
bytes.inputStream().use { inputStream ->
for (value in randomInts) {
val read = inputStream.readVarInt32()
assertEquals(value, read)
}
}
}
/**
* Exhaustively checks reading and writing a varint for all possible integers.
* We can't keep everything in memory, so instead we use sequences to grab a million at a time,
* then run smaller chunks of those in parallel.
*/
@Ignore("This test is very slow (over a minute). It was run once to verify correctness, but the random sampling test should be sufficient for catching regressions.")
@Test
fun testAll() {
val counter = AtomicInteger(0)
(Int.MIN_VALUE..Int.MAX_VALUE)
.asSequence()
.chunked(1_000_000)
.forEach { bigChunk ->
bigChunk
.chunked(100_000)
.parallelStream()
.forEach { smallChunk ->
println("Chunk ${counter.addAndGet(1)}")
val bytes = ByteArrayOutputStream().use { outputStream ->
for (value in smallChunk) {
outputStream.writeVarInt32(value)
}
outputStream
}.toByteArray()
bytes.inputStream().use { inputStream ->
for (value in smallChunk) {
val read = inputStream.readVarInt32()
assertEquals(value, read)
}
}
}
}
}
}

View file

@ -0,0 +1,36 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util.logging
import org.junit.Assert.assertEquals
import org.junit.Test
class LogTest {
@Test
fun tag_short_class_name() {
assertEquals("MyClass", Log.tag(MyClass::class))
}
@Test
fun tag_23_character_class_name() {
val tag = Log.tag(TwentyThreeCharacters23::class)
assertEquals("TwentyThreeCharacters23", tag)
assertEquals(23, tag.length)
}
@Test
fun tag_24_character_class_name() {
assertEquals(24, TwentyFour24Characters24::class.simpleName!!.length)
val tag = Log.tag(TwentyFour24Characters24::class)
assertEquals("TwentyFour24Characters2", tag)
assertEquals(23, Log.tag(TwentyThreeCharacters23::class).length)
}
private inner class MyClass
private inner class TwentyThreeCharacters23
private inner class TwentyFour24Characters24
}

View file

@ -0,0 +1,272 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util.logging
import org.junit.Assert
import org.junit.BeforeClass
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.Parameterized
@RunWith(Parameterized::class)
class ScrubberTest(private val input: String, private val expected: String) {
@Test
fun scrub() {
Assert.assertEquals(expected, Scrubber.scrub(input).toString())
}
companion object {
@JvmStatic
@BeforeClass
fun setup() {
Scrubber.identifierHmacKeyProvider = { ByteArray(32) }
}
@JvmStatic
@Parameterized.Parameters
fun data(): Iterable<Array<Any>> {
return listOf(
arrayOf(
"An E164 number +15551234567",
"An E164 number E164:<9f683>"
),
arrayOf(
"A UK number +447700900000",
"A UK number E164:<cad1f>"
),
arrayOf(
"A Japanese number 08011112222",
"A Japanese number E164:<d3f26>"
),
arrayOf(
"A Japanese number (08011112222)",
"A Japanese number (E164:<d3f26>)"
),
arrayOf(
"Not a Japanese number 08011112222333344445555",
"Not a Japanese number 08011112222333344445555"
),
arrayOf(
"Not a Japanese number 1234508011112222",
"Not a Japanese number 1234508011112222"
),
arrayOf(
"An avatar filename: file:///data/user/0/org.thoughtcrime.securesms/files/avatars/%2B447700900099",
"An avatar filename: file:///data/user/0/org.thoughtcrime.securesms/files/avatars/E164:<3106a>"
),
arrayOf(
"Multiple numbers +447700900001 +447700900002",
"Multiple numbers E164:<87035> E164:<1e488>"
),
arrayOf(
"One less than shortest number +155556",
"One less than shortest number +155556"
),
arrayOf(
"Shortest number +1555567",
"Shortest number E164:<8edd2>"
),
arrayOf(
"Longest number +155556789012345",
"Longest number E164:<90596>"
),
arrayOf(
"An E164 number KEEP_E164::+15551234567",
"An E164 number KEEP_E164::+1********67"
),
arrayOf(
"A UK number KEEP_E164::+447700900000",
"A UK number KEEP_E164::+4*********00"
),
arrayOf(
"A Japanese number KEEP_E164::08011112222",
"A Japanese number KEEP_E164::08*******22"
),
arrayOf(
"A Japanese number (KEEP_E164::08011112222)",
"A Japanese number (KEEP_E164::08*******22)"
),
arrayOf(
"One more than longest number +1234567890123456",
"One more than longest number E164:<78d5b>6"
),
arrayOf(
"abc@def.com",
"a...@..."
),
arrayOf(
"An email abc@def.com",
"An email a...@..."
),
arrayOf(
"A short email a@def.com",
"A short email a...@..."
),
arrayOf(
"This is not an email Success(result=org.whispersystems.signalservice.api.archive.ArchiveMediaResponse@1ea5e6)",
"This is not an email Success(result=org.whispersystems.signalservice.api.archive.ArchiveMediaResponse@1ea5e6)"
),
arrayOf(
"A email with multiple parts before the @ d.c+b.a@mulitpart.domain.com and a multipart domain",
"A email with multiple parts before the @ d...@... and a multipart domain"
),
arrayOf(
"An avatar email filename: file:///data/user/0/org.thoughtcrime.securesms/files/avatars/abc@signal.org",
"An avatar email filename: file:///data/user/0/org.thoughtcrime.securesms/files/avatars/a...@..."
),
arrayOf(
"An email and a number abc@def.com +155556789012345",
"An email and a number a...@... E164:<90596>"
),
arrayOf(
"__textsecure_group__!000102030405060708090a0b0c0d0e0f",
"GV1::***e0f"
),
arrayOf(
"A group id __textsecure_group__!000102030405060708090a0b0c0d0e1a surrounded with text",
"A group id GV1::***e1a surrounded with text"
),
arrayOf(
"__signal_group__v2__!0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef",
"GV2::***def"
),
arrayOf(
"A group v2 id __signal_group__v2__!23456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef01 surrounded with text",
"A group v2 id GV2::***f01 surrounded with text"
),
arrayOf(
"a37cb654-c9e0-4c1e-93df-3d11ca3c97f4",
"********-****-****-****-*********7f4"
),
arrayOf(
"A UUID a37cb654-c9e0-4c1e-93df-3d11ca3c97f4 surrounded with text",
"A UUID ********-****-****-****-*********7f4 surrounded with text"
),
arrayOf(
"An ACI:a37cb654-c9e0-4c1e-93df-3d11ca3c97f4 surrounded with text",
"An ACI:********-****-****-****-*********7f4 surrounded with text"
),
arrayOf(
"A PNI:a37cb654-c9e0-4c1e-93df-3d11ca3c97f4 surrounded with text",
"A PNI:<bdf84> surrounded with text"
),
arrayOf(
"JOB::a37cb654-c9e0-4c1e-93df-3d11ca3c97f4",
"JOB::a37cb654-c9e0-4c1e-93df-3d11ca3c97f4"
),
arrayOf(
"All patterns in a row __textsecure_group__!abcdefg1234567890 +123456789012345 abc@def.com a37cb654-c9e0-4c1e-93df-3d11ca3c97f4 nl.motorsport.com 192.168.1.1 with text after",
"All patterns in a row GV1::***890 E164:<78d5b> a...@... ********-****-****-****-*********7f4 ***.com ...ipv4... with text after"
),
arrayOf(
"java.net.UnknownServiceException: CLEARTEXT communication to nl.motorsport.com not permitted by network security policy",
"java.net.UnknownServiceException: CLEARTEXT communication to ***.com not permitted by network security policy"
),
arrayOf(
"nl.motorsport.com:443",
"***.com:443"
),
arrayOf(
"Failed to resolve chat.signal.org using . Continuing.",
"Failed to resolve chat.signal.org using . Continuing."
),
arrayOf(
" Caused by: java.io.IOException: unexpected end of stream on Connection{storage.signal.org:443, proxy=DIRECT hostAddress=storage.signal.org/142.251.32.211:443 cipherSuite=TLS_AES_128_GCM_SHA256 protocol=http/1.1}",
" Caused by: java.io.IOException: unexpected end of stream on Connection{storage.signal.org:443, proxy=DIRECT hostAddress=storage.signal.org/...ipv4...:443 cipherSuite=TLS_AES_128_GCM_SHA256 protocol=http/1.1}"
),
arrayOf(
"192.168.1.1",
"...ipv4..."
),
arrayOf(
"255.255.255.255",
"...ipv4..."
),
arrayOf(
"Text before 255.255.255.255 text after",
"Text before ...ipv4... text after"
),
arrayOf(
"Not an ipv4 3.141",
"Not an ipv4 3.141"
),
arrayOf(
"A Call Link Root Key BCDF-FGHK-MNPQ-RSTX-ZRQH-BCDF-FGHM-STXZ",
"A Call Link Root Key BCDF-XXXX-XXXX-XXXX-XXXX-XXXX-XXXX-XXXX"
),
arrayOf(
"Not a Call Link Root Key (Invalid Characters) BCAF-FGHK-MNPQ-RSTX-ZRQH-BCDF-FGHM-STXZ",
"Not a Call Link Root Key (Invalid Characters) BCAF-FGHK-MNPQ-RSTX-ZRQH-BCDF-FGHM-STXZ"
),
arrayOf(
"Not a Call Link Root Key (Missing Quartet) BCAF-FGHK-MNPQ-RSTX-ZRQH-BCDF-STXZ",
"Not a Call Link Root Key (Missing Quartet) BCAF-FGHK-MNPQ-RSTX-ZRQH-BCDF-STXZ"
),
arrayOf(
"A Call Link Room ID 905db82618b907f9ceaf8f12cb65f061ffc187f7df747cb3f38d5281f7c686be",
"A Call Link Room ID *************************************************************6be"
),
arrayOf(
"Not a Call Link Room ID 905db82618b907f9ceaf8f12cb65f061ffc187f7df747cb3f38d5281f7c686b",
"Not a Call Link Room ID 905db82618b907f9ceaf8f12cb65f061ffc187f7df747cb3f38d5281f7c686b"
),
arrayOf(
"2345:0425:2CA1:0000:0000:0567:5673:23b5",
"...ipv6..."
),
arrayOf(
"2345:425:2CA1:0000:0000:567:5673:23b5",
"...ipv6..."
),
arrayOf(
"2345:0425:2CA1:0:0:0567:5673:23b5",
"...ipv6..."
),
arrayOf(
"2345:0425:2CA1::0567:5673:23b5",
"...ipv6..."
),
arrayOf(
"FF01:0:0:0:0:0:0:1",
"...ipv6..."
),
arrayOf(
"2001:db8::a3",
"...ipv6..."
),
arrayOf(
"text before 2345:0425:2CA1:0000:0000:0567:5673:23b5 text after",
"text before ...ipv6... text after"
),
arrayOf(
"Recipient::1",
"Recipient::1"
),
arrayOf(
"Recipient::123",
"Recipient::123"
),
arrayOf(
"url with text before https://example.com/v1/endpoint;asdf123%20$[]?asdf&asdf#asdf and stuff afterwards",
"url with text before https://***.com/*** and stuff afterwards"
),
arrayOf(
"https://signal.org/v1/endpoint",
"https://signal.org/v1/endpoint"
),
arrayOf(
"https://cdn3.signal.org/v1/endpoint",
"https://***.org/***"
),
arrayOf(
"https://debuglogs.org/android/7.47.2/2b5ccf4e3e58e44f12b3c92cfd5b526a2432f1dd0f81c8f89dededb176f1122d",
"https://debuglogs.org/android/7.47.2/2b5ccf4e3e58e44f12b3c92cfd5b526a2432f1dd0f81c8f89dededb176f1122d"
)
)
}
}
}

View file

@ -0,0 +1,146 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util.stream
import org.junit.Assert.assertEquals
import org.junit.Test
import org.signal.core.util.readFully
import org.signal.core.util.readNBytesOrThrow
class LimitedInputStreamTest {
@Test
fun `when I fully read the stream via a buffer, I should only get maxBytes`() {
val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = 75)
val data = inputStream.readFully()
assertEquals(75, data.size)
}
@Test
fun `when I fully read the stream via a buffer with no limit, I should get all bytes`() {
val inputStream = LimitedInputStream.withoutLimits(ByteArray(100).inputStream())
val data = inputStream.readFully()
assertEquals(100, data.size)
}
@Test
fun `when I fully read the stream one byte at a time, I should only get maxBytes`() {
val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = 75)
var count = 0
var lastRead = inputStream.read()
while (lastRead != -1) {
count++
lastRead = inputStream.read()
}
assertEquals(75, count)
}
@Test
fun `when I fully read the stream one byte at a time with no limit, I should only get maxBytes`() {
val inputStream = LimitedInputStream.withoutLimits(ByteArray(100).inputStream())
var count = 0
var lastRead = inputStream.read()
while (lastRead != -1) {
count++
lastRead = inputStream.read()
}
assertEquals(100, count)
}
@Test
fun `when I skip past the maxBytes, I should get -1`() {
val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = 75)
val skipCount = inputStream.skip(100)
val read = inputStream.read()
assertEquals(75, skipCount)
assertEquals(-1, read)
}
@Test
fun `when I skip, I should still truncate correctly afterwards`() {
val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = 75)
val skipCount = inputStream.skip(50)
val data = inputStream.readFully()
assertEquals(50, skipCount)
assertEquals(25, data.size)
}
@Test
fun `when I skip more than maxBytes, I only skip maxBytes`() {
val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = 75)
val skipCount = inputStream.skip(100)
assertEquals(75, skipCount)
}
@Test
fun `when I finish reading the stream, leftoverStream gives me the rest`() {
val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = 75)
inputStream.readFully()
val truncatedBytes = inputStream.leftoverStream().readFully()
assertEquals(25, truncatedBytes.size)
}
@Test
fun `when call leftoverStream on a stream with no limit, it returns an empty array`() {
val inputStream = LimitedInputStream.withoutLimits(ByteArray(100).inputStream())
inputStream.readFully()
val truncatedBytes = inputStream.leftoverStream().readFully()
assertEquals(0, truncatedBytes.size)
}
@Test
fun `when I call available, it should respect the maxBytes`() {
val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = 75)
val available = inputStream.available()
assertEquals(75, available)
}
@Test
fun `when I call available with no limit, it should return the full length`() {
val inputStream = LimitedInputStream.withoutLimits(ByteArray(100).inputStream())
val available = inputStream.available()
assertEquals(100, available)
}
@Test
fun `when I call available after reading some bytes, it should respect the maxBytes`() {
val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = 75)
inputStream.readNBytesOrThrow(50)
val available = inputStream.available()
assertEquals(25, available)
}
@Test
fun `when I mark and reset, it should jump back to the correct position`() {
val inputStream = LimitedInputStream(ByteArray(100).inputStream(), maxBytes = 75)
inputStream.mark(100)
inputStream.readNBytesOrThrow(10)
inputStream.reset()
val data = inputStream.readFully()
assertEquals(75, data.size)
}
}

View file

@ -0,0 +1,56 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util.stream
import org.junit.Assert.assertArrayEquals
import org.junit.Test
import org.signal.core.util.readFully
import java.io.InputStream
import javax.crypto.Mac
import javax.crypto.spec.SecretKeySpec
import kotlin.random.Random
class MacInputStreamTest {
@Test
fun `stream mac matches normal mac when reading via buffer`() {
testMacEquality { inputStream ->
inputStream.readFully()
}
}
@Test
fun `stream mac matches normal mac when reading one byte at a time`() {
testMacEquality { inputStream ->
var lastRead = inputStream.read()
while (lastRead != -1) {
lastRead = inputStream.read()
}
}
}
private fun testMacEquality(read: (InputStream) -> Unit) {
val data = Random.nextBytes(1_000)
val key = Random.nextBytes(32)
val mac1 = Mac.getInstance("HmacSHA256").apply {
init(SecretKeySpec(key, "HmacSHA256"))
}
val mac2 = Mac.getInstance("HmacSHA256").apply {
init(SecretKeySpec(key, "HmacSHA256"))
}
val expectedMac = mac1.doFinal(data)
val actualMac = MacInputStream(data.inputStream(), mac2).use { stream ->
read(stream)
stream.mac.doFinal()
}
assertArrayEquals(expectedMac, actualMac)
}
}

View file

@ -0,0 +1,56 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util.stream
import org.junit.Assert.assertArrayEquals
import org.junit.Test
import org.signal.core.util.StreamUtil
import java.io.ByteArrayOutputStream
import java.io.OutputStream
import javax.crypto.Mac
import javax.crypto.spec.SecretKeySpec
import kotlin.random.Random
class MacOutputStreamTest {
@Test
fun `stream mac matches normal mac when writing via buffer`() {
testMacEquality { data, outputStream ->
StreamUtil.copy(data.inputStream(), outputStream)
}
}
@Test
fun `stream mac matches normal mac when writing one byte at a time`() {
testMacEquality { data, outputStream ->
for (byte in data) {
outputStream.write(byte.toInt())
}
}
}
private fun testMacEquality(write: (ByteArray, OutputStream) -> Unit) {
val data = Random.nextBytes(1_000)
val key = Random.nextBytes(32)
val mac1 = Mac.getInstance("HmacSHA256").apply {
init(SecretKeySpec(key, "HmacSHA256"))
}
val mac2 = Mac.getInstance("HmacSHA256").apply {
init(SecretKeySpec(key, "HmacSHA256"))
}
val expectedMac = mac1.doFinal(data)
val actualMac = MacOutputStream(ByteArrayOutputStream(), mac2).use { stream ->
write(data, stream)
stream.mac.doFinal()
}
assertArrayEquals(expectedMac, actualMac)
}
}

View file

@ -0,0 +1,106 @@
package org.signal.core.util.stream
import org.junit.Assert.assertEquals
import org.junit.Test
import org.signal.core.util.readFully
class TailerInputStreamTest {
@Test
fun `when I provide an incomplete stream and a known bytesLength, I can read the stream until bytesLength is reached`() {
var currentBytesLength = 0
val inputStream = TailerInputStream(
streamFactory = {
currentBytesLength += 10
ByteArray(currentBytesLength).inputStream()
},
bytesLength = 50
)
val data = inputStream.readFully()
assertEquals(50, data.size)
}
@Test
fun `when I provide an incomplete stream and a known bytesLength, I can read the stream one byte at a time until bytesLength is reached`() {
var currentBytesLength = 0
val inputStream = TailerInputStream(
streamFactory = {
currentBytesLength += 10
ByteArray(currentBytesLength).inputStream()
},
bytesLength = 20
)
var count = 0
var lastRead = inputStream.read()
while (lastRead != -1) {
count++
lastRead = inputStream.read()
}
assertEquals(20, count)
}
@Test
fun `when I provide a complete stream and a known bytesLength, I can read the stream until bytesLength is reached`() {
val inputStream = TailerInputStream(
streamFactory = { ByteArray(50).inputStream() },
bytesLength = 50
)
val data = inputStream.readFully()
assertEquals(50, data.size)
}
@Test
fun `when I provide a complete stream and a known bytesLength, I can read the stream one byte at a time until bytesLength is reached`() {
val inputStream = TailerInputStream(
streamFactory = { ByteArray(20).inputStream() },
bytesLength = 20
)
var count = 0
var lastRead = inputStream.read()
while (lastRead != -1) {
count++
lastRead = inputStream.read()
}
assertEquals(20, count)
}
@Test
fun `when I skip bytes, I still read until the end of bytesLength`() {
var currentBytesLength = 0
val inputStream = TailerInputStream(
streamFactory = {
currentBytesLength += 10
ByteArray(currentBytesLength).inputStream()
},
bytesLength = 50
)
inputStream.skip(5)
val data = inputStream.readFully()
assertEquals(45, data.size)
}
@Test
fun `when I skip more bytes than available, I can still read until the end of bytesLength`() {
var currentBytesLength = 0
val inputStream = TailerInputStream(
streamFactory = {
currentBytesLength += 10
ByteArray(currentBytesLength).inputStream()
},
bytesLength = 50
)
inputStream.skip(15)
val data = inputStream.readFully()
assertEquals(40, data.size)
}
}

View file

@ -0,0 +1,140 @@
/*
* Copyright 2023 Signal Messenger, LLC
* SPDX-License-Identifier: AGPL-3.0-only
*/
package org.signal.core.util.stream
import assertk.assertThat
import assertk.assertions.isEqualTo
import org.junit.Test
import org.signal.core.util.readFully
import kotlin.math.min
import kotlin.random.Random
class TrimmingInputStreamTest {
@Test
fun `when I fully read the stream via a buffer, I should exclude the last trimSize bytes`() {
val initialData = testData(100)
val inputStream = TrimmingInputStream(initialData.inputStream(), trimSize = 25)
val data = inputStream.readFully()
assertThat(data.size).isEqualTo(75)
assertThat(data).isEqualTo(initialData.copyOfRange(0, 75))
}
@Test
fun `when I fully read the stream via a buffer, I should exclude the last trimSize bytes - many sizes`() {
for (i in 1..100) {
val arraySize = Random.nextInt(1024, 2 * 1024 * 1024)
val trimSize = min(arraySize, Random.nextInt(1024))
val initialData = testData(arraySize)
val innerStream = initialData.inputStream()
val inputStream = TrimmingInputStream(innerStream, trimSize = trimSize)
val data = inputStream.readFully()
assertThat(data.size).isEqualTo(arraySize - trimSize)
assertThat(data).isEqualTo(initialData.copyOfRange(0, arraySize - trimSize))
}
}
@Test
fun `when I fully read the stream via a buffer with drain set, I should exclude the last trimSize bytes but still drain the remaining stream - many sizes`() {
for (i in 1..100) {
val arraySize = Random.nextInt(1024, 2 * 1024 * 1024)
val trimSize = min(arraySize, Random.nextInt(1024))
val initialData = testData(arraySize)
val innerStream = initialData.inputStream()
val inputStream = TrimmingInputStream(innerStream, trimSize = trimSize, drain = true)
val data = inputStream.readFully()
assertThat(data.size).isEqualTo(arraySize - trimSize)
assertThat(data).isEqualTo(initialData.copyOfRange(0, arraySize - trimSize))
assertThat(innerStream.available()).isEqualTo(0)
}
}
@Test
fun `when I fully read the stream and the trimSize is greater than the stream length, I should get zero bytes`() {
val initialData = testData(100)
val inputStream = TrimmingInputStream(initialData.inputStream(), trimSize = 200)
val data = inputStream.readFully()
assertThat(data.size).isEqualTo(0)
}
@Test
fun `when I fully read the stream via a buffer with no trimSize, I should get all bytes`() {
val inputStream = TrimmingInputStream(ByteArray(100).inputStream(), trimSize = 0)
val data = inputStream.readFully()
assertThat(data.size).isEqualTo(100)
}
@Test
fun `when I fully read the stream one byte at a time, I should exclude the last trimSize bytes`() {
val inputStream = TrimmingInputStream(ByteArray(100).inputStream(), trimSize = 25)
var count = 0
var lastRead = inputStream.read()
while (lastRead != -1) {
count++
lastRead = inputStream.read()
}
assertThat(count).isEqualTo(75)
}
@Test
fun `when I fully read the stream one byte at a time with no trimSize, I should get all bytes`() {
val inputStream = TrimmingInputStream(ByteArray(100).inputStream(), trimSize = 0)
var count = 0
var lastRead = inputStream.read()
while (lastRead != -1) {
count++
lastRead = inputStream.read()
}
assertThat(count).isEqualTo(100)
}
@Test
fun `when I skip past the the trimSize, I should get -1`() {
val inputStream = TrimmingInputStream(ByteArray(100).inputStream(), trimSize = 25)
val skipCount = inputStream.skip(100)
val read = inputStream.read()
assertThat(skipCount).isEqualTo(75)
assertThat(read).isEqualTo(-1)
}
@Test
fun `when I skip, I should still truncate correctly afterwards`() {
val inputStream = TrimmingInputStream(ByteArray(100).inputStream(), trimSize = 25)
val skipCount = inputStream.skip(50)
val data = inputStream.readFully()
assertThat(skipCount).isEqualTo(50)
assertThat(data.size).isEqualTo(25)
}
@Test
fun `when I skip more than the remaining bytes, I still respect trimSize`() {
val initialData = testData(100)
val inputStream = TrimmingInputStream(initialData.inputStream(), trimSize = 25)
val skipCount = inputStream.skip(100)
assertThat(skipCount).isEqualTo(75)
}
private fun testData(length: Int): ByteArray {
return ByteArray(length) { (it % 0xFF).toByte() }
}
}