diff --git a/app/core/src/main/java/com/topjohnwu/magisk/core/tasks/ExtractImage.kt b/app/core/src/main/java/com/topjohnwu/magisk/core/tasks/ExtractImage.kt new file mode 100644 index 000000000..59fc1a5ed --- /dev/null +++ b/app/core/src/main/java/com/topjohnwu/magisk/core/tasks/ExtractImage.kt @@ -0,0 +1,156 @@ +package com.topjohnwu.magisk.core.tasks + +import com.topjohnwu.magisk.core.di.ServiceLocator +import com.topjohnwu.magisk.core.utils.HttpFileChannel +import okio.buffer +import okio.inflate +import okio.sink +import org.apache.commons.compress.archivers.zip.ZipArchiveEntry +import org.apache.commons.compress.archivers.zip.ZipFile +import org.apache.commons.compress.archivers.zip.ZipMethod +import java.io.File +import java.io.IOException +import java.nio.channels.FileChannel +import java.nio.file.StandardOpenOption + +class ExtractImage( + private val url: String, + private val console: MutableList, + private val logs: MutableList, +) { + @Throws(IOException::class) + fun start(outFile: File) { + logs.add("Downloading from: $url") + + val channel = HttpFileChannel(ServiceLocator.okhttp, url) + ZipFile.builder() + .setSeekableByteChannel(channel) + .setIgnoreLocalFileHeader(true) + .get().use { zipFile -> + val payload = zipFile.getEntry("payload.bin") + if (payload != null) { + console.add("- Processing as OTA package") + + zipFile.getEntry("META-INF/com/android/metadata")?.let { entry -> + zipFile.getInputStream(entry).use { + val meta = it.bufferedReader().readText() + logs.add(meta) + + console.add("- OTA metadata:") + meta.lines().forEach { line -> + if (line.startsWith("post-")) { + console.add(" ${line.substringAfter('-')}") + } + } + } + } + zipFile.getRawInputStream(payload) + extractFromOTAPackage(payload, channel, outFile) + } else { + extractFromFactoryImage(zipFile, channel, outFile) + } + } + } + + @Throws(IOException::class) + private fun extractFromOTAPackage( + payload: ZipArchiveEntry, + channel: HttpFileChannel, + outFile: File, + ) { + if (payload.method != ZipMethod.STORED.code) { + throw IOException("payload.bin is compressed, expected STORED method") + } + + channel.slice(payload.dataOffset, payload.size).use { payloadChannel -> + Payload(payloadChannel).extract(outFile, { console.add(it) }, { logs.add(it) }) + } + } + + @Throws(IOException::class) + private fun extractFromFactoryImage(zipFile: ZipFile, channel: HttpFileChannel, outFile: File) { + console.add("- Processing as factory image package") + + findBootImageZipEntry(zipFile)?.let { entry -> + return extractImageFile(zipFile, entry, channel, outFile) + } + + val imageZipEntry = zipFile.entries.asSequence().find { entry -> + val fileName = entry.name.substringAfterLast('/') + fileName.startsWith("image-") && fileName.endsWith(".zip") + } + if (imageZipEntry != null) { + zipFile.getRawInputStream(imageZipEntry) + return extractFromInnerImageZip(imageZipEntry, channel, outFile) + } + + throw IOException("inner image ZIP not found in factory image package") + } + + private fun findBootImageZipEntry(zipFile: ZipFile): ZipArchiveEntry? { + return zipFile.entries.asSequence().find { it.name == "init_boot.img" } + ?: zipFile.entries.asSequence().find { it.name == "boot.img" } + } + + @Throws(IOException::class) + private fun extractFromInnerImageZip( + entry: ZipArchiveEntry, + channel: HttpFileChannel, + outFile: File + ) { + logs.add("Found inner image ZIP: ${entry.name}") + + if (entry.method != ZipMethod.STORED.code) { + throw IOException("image ZIP is compressed, expected STORED method") + } + + channel.slice(entry.dataOffset, entry.size).use { innerZipChannel -> + ZipFile.builder() + .setSeekableByteChannel(innerZipChannel) + .setIgnoreLocalFileHeader(true) + .get().use { innerZipFile -> + val targetEntry = findBootImageZipEntry(innerZipFile) + ?: throw IOException("boot image not found in inner image ZIP") + return extractImageFile(innerZipFile, targetEntry, innerZipChannel, outFile) + } + } + } + + @Throws(IOException::class) + private fun extractImageFile( + zipFile: ZipFile, + entry: ZipArchiveEntry, + channel: HttpFileChannel, + outFile: File, + ) { + console.add("- Found boot image entry: ${entry.name} (${entry.size} bytes)") + console.add("- Downloading") + + zipFile.getRawInputStream(entry) + when (entry.method) { + ZipMethod.STORED.code -> { + FileChannel.open( + outFile.toPath(), + StandardOpenOption.CREATE, + StandardOpenOption.WRITE, + StandardOpenOption.READ, + StandardOpenOption.TRUNCATE_EXISTING + ).use { fileChannel -> + val mapped = fileChannel.map(FileChannel.MapMode.READ_WRITE, 0, entry.size) + val sourceChannel = channel.slice(entry.dataOffset, entry.size) + sourceChannel.read(mapped) + } + } + + ZipMethod.DEFLATED.code -> { + channel.streamRead(entry.dataOffset, entry.size).inflate().use { source -> + outFile.sink().buffer().use { sink -> + sink.writeAll(source) + } + } + } + + else -> throw IOException("unsupported method: ${entry.method}") + } + } +} diff --git a/app/core/src/main/java/com/topjohnwu/magisk/core/tasks/MagiskInstaller.kt b/app/core/src/main/java/com/topjohnwu/magisk/core/tasks/MagiskInstaller.kt index 94e0034c7..23d59640b 100644 --- a/app/core/src/main/java/com/topjohnwu/magisk/core/tasks/MagiskInstaller.kt +++ b/app/core/src/main/java/com/topjohnwu/magisk/core/tasks/MagiskInstaller.kt @@ -540,7 +540,7 @@ abstract class MagiskInstallImpl protected constructor( // Download image from url try { srcBoot = installDir.getChildFile("boot.img") - //todo + ExtractImage(url, console, logs).start(srcBoot) } catch (e: IOException) { console.add("! Error: " + e.message) Timber.e(e) diff --git a/app/core/src/main/java/com/topjohnwu/magisk/core/tasks/Payload.kt b/app/core/src/main/java/com/topjohnwu/magisk/core/tasks/Payload.kt new file mode 100644 index 000000000..e6e8d4156 --- /dev/null +++ b/app/core/src/main/java/com/topjohnwu/magisk/core/tasks/Payload.kt @@ -0,0 +1,177 @@ +package com.topjohnwu.magisk.core.tasks + +import chromeos_update_engine.UpdateMetadata.DeltaArchiveManifest +import chromeos_update_engine.UpdateMetadata.InstallOperation +import chromeos_update_engine.UpdateMetadata.PartitionUpdate +import com.topjohnwu.magisk.core.utils.HttpFileChannel +import org.apache.commons.compress.compressors.bzip2.BZip2CompressorInputStream +import org.apache.commons.compress.compressors.xz.XZCompressorInputStream +import java.io.File +import java.io.IOException +import java.nio.ByteBuffer +import java.nio.ByteOrder +import java.nio.channels.FileChannel +import java.nio.file.StandardOpenOption +import java.security.MessageDigest + +class Payload(private val channel: HttpFileChannel) { + private val manifest: DeltaArchiveManifest + private var dataBase = 0L + + init { + manifest = readPayloadHeader() + } + + @Throws(IOException::class) + fun extract(outputFile: File, console: (String) -> Unit, logger: (String) -> Unit) { + val partition = findPartition() + console("- Found partition ${partition.partitionName}") + + val actualHash = extractPartition(outputFile, partition, console) + + if (!partition.newPartitionInfo.hasHash()) { + logger("Hash verification skipped") + return + } + + fun toHex(bytes: ByteArray) = bytes.joinToString("") { "%02x".format(it) } + + val expectedHash = partition.newPartitionInfo.hash.toByteArray() + if (!expectedHash.contentEquals(actualHash)) { + throw IOException( + "Hash mismatch, expected ${toHex(expectedHash)}, but got ${toHex(actualHash)}" + ) + } + logger("Hash verification passed") + } + + @Throws(IOException::class) + private fun readPayloadHeader(): DeltaArchiveManifest { + // Read magic + val magicBuffer = ByteBuffer.allocate(4) + channel.read(magicBuffer) + magicBuffer.flip() + val magic = String(magicBuffer.array()) + if (magic != "CrAU") { + throw IOException("Invalid payload: invalid magic") + } + + // Read version + val versionBuffer = ByteBuffer.allocate(8).order(ByteOrder.BIG_ENDIAN) + channel.read(versionBuffer) + versionBuffer.flip() + val version = versionBuffer.long + if (version != 2L) { + throw IOException("Invalid payload: unsupported version: $version") + } + + // Read manifest length + val manifestLenBuffer = ByteBuffer.allocate(8).order(ByteOrder.BIG_ENDIAN) + channel.read(manifestLenBuffer) + manifestLenBuffer.flip() + val manifestLen = manifestLenBuffer.long.toInt() + if (manifestLen == 0) { + throw IOException("Invalid payload: manifest length is zero") + } + + // Read manifest signature length + val manifestSigLenBuffer = ByteBuffer.allocate(4).order(ByteOrder.BIG_ENDIAN) + channel.read(manifestSigLenBuffer) + manifestSigLenBuffer.flip() + val manifestSigLen = manifestSigLenBuffer.int + if (manifestSigLen == 0) { + throw IOException("Invalid payload: manifest signature length is zero") + } + + // Read manifest + val manifestBuffer = ByteBuffer.allocate(manifestLen) + channel.read(manifestBuffer) + manifestBuffer.flip() + val manifest = DeltaArchiveManifest.parseFrom(manifestBuffer.array()) + + // Skip manifest signature + channel.position(channel.position() + manifestSigLen) + + dataBase = channel.position() + + return manifest + } + + @Throws(IOException::class) + private fun findPartition(): PartitionUpdate { + return manifest.partitionsList.find { it.partitionName == "init_boot" } + ?: manifest.partitionsList.find { it.partitionName == "boot" } + ?: throw IOException("boot partition not found in payload") + } + + @Throws(IOException::class) + private fun extractPartition( + outputFile: File, + partition: PartitionUpdate, + console: (String) -> Unit, + ): ByteArray { + FileChannel.open( + outputFile.toPath(), + StandardOpenOption.CREATE, + StandardOpenOption.WRITE, + StandardOpenOption.READ, + StandardOpenOption.TRUNCATE_EXISTING + ).use { outChannel -> + val size = partition.newPartitionInfo.size + outChannel.write(ByteBuffer.allocate(1), size - 1) + + val count = partition.operationsCount + partition.operationsList.forEachIndexed { index, operation -> + if (index % 5 == 0 || index == count - 1) { + console("- Downloading ${index + 1}/$count") + } + processOperation(outChannel, operation) + } + + val digest = MessageDigest.getInstance("SHA-256") + val buffer = outChannel.map(FileChannel.MapMode.READ_WRITE, 0, size) + digest.update(buffer) + return digest.digest() + } + } + + @Throws(IOException::class) + private fun processOperation(outChannel: FileChannel, operation: InstallOperation) { + val dataType = operation.getType() + if (dataType == InstallOperation.Type.ZERO) { + return + } + + val dataBuffer = ByteBuffer.allocate(operation.getDataLength().toInt()) + channel.read(dataBuffer, dataBase + operation.getDataOffset()) + dataBuffer.flip() + + val dstExtent = operation.getDstExtents(0) + val outOffset = dstExtent.getStartBlock() * manifest.getBlockSize() + + when (dataType) { + InstallOperation.Type.REPLACE -> { + outChannel.write(dataBuffer, outOffset) + } + + InstallOperation.Type.REPLACE_BZ, InstallOperation.Type.REPLACE_XZ -> { + val inputStream = dataBuffer.array().inputStream() + if (dataType == InstallOperation.Type.REPLACE_BZ) { + BZip2CompressorInputStream(inputStream) + } else { + XZCompressorInputStream(inputStream) + }.use { decompressor -> + val bytes = ByteArray(8192) + var bytesRead: Int + var bytesWritten = 0 + while (decompressor.read(bytes).also { bytesRead = it } != -1) { + val buffer = ByteBuffer.wrap(bytes, 0, bytesRead) + bytesWritten += outChannel.write(buffer, outOffset + bytesWritten) + } + } + } + + else -> throw IOException("Unsupported operation type: $dataType") + } + } +} diff --git a/app/core/src/main/java/com/topjohnwu/magisk/core/utils/HttpFileChannel.java b/app/core/src/main/java/com/topjohnwu/magisk/core/utils/HttpFileChannel.java new file mode 100644 index 000000000..9cc9a0a30 --- /dev/null +++ b/app/core/src/main/java/com/topjohnwu/magisk/core/utils/HttpFileChannel.java @@ -0,0 +1,244 @@ +package com.topjohnwu.magisk.core.utils; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.Channels; +import java.nio.channels.ClosedChannelException; +import java.nio.channels.NonWritableChannelException; +import java.nio.channels.SeekableByteChannel; + +import okhttp3.OkHttpClient; +import okhttp3.Request; +import okio.BufferedSource; + +public class HttpFileChannel implements SeekableByteChannel { + private static final int RANDOM_READ_CACHE_SIZE = 16 * 1024; + private static final int SEQ_READ_CACHE_SIZE = 1024 * 1024; + private static final int SEQ_READ_THRESHOLD = 1024; + private static final int DIRECT_READ_THRESHOLD = 512 * 1024; + + private final OkHttpClient client; + private final String url; + private final long startOffset; + private final long size; + + private long position = 0; + private boolean open = true; + + private byte[] cache = null; + private long cacheStart = -1; + + public HttpFileChannel(OkHttpClient client, String url, long startOffset, long size) { + this.client = client; + this.url = url; + this.startOffset = startOffset; + this.size = size; + } + + public HttpFileChannel(OkHttpClient client, String url) throws IOException { + this(client, url, 0, fetchTotalSize(client, url)); + } + + private static long fetchTotalSize(OkHttpClient client, String url) throws IOException { + var request = new Request.Builder().url(url).head().build(); + try (var response = client.newCall(request).execute()) { + if (!response.isSuccessful()) { + throw new IOException("Failed to connect to URL: " + response); + } + var contentLength = response.header("Content-Length"); + if (contentLength == null) { + throw new IOException("Could not determine file size."); + } + var acceptRanges = response.header("Accept-Ranges"); + if (acceptRanges == null || !acceptRanges.equalsIgnoreCase("bytes")) { + throw new IOException("Server does not support byte ranges: " + response); + } + return Long.parseLong(contentLength); + } + } + + public HttpFileChannel slice(long offset, long sliceSize) { + if (offset == 0 && sliceSize == size) { + return this; + } + if (offset < 0 || sliceSize <= 0 || offset + sliceSize >= size) { + throw new IllegalArgumentException("Invalid slice parameters"); + } + return new HttpFileChannel(client, url, startOffset + offset, sliceSize); + } + + @Override + public int read(ByteBuffer dst) throws IOException { + var bytesRead = read(dst, position); + position += bytesRead; + return bytesRead; + } + + public int read(ByteBuffer dst, long position) throws IOException { + if (!open) throw new ClosedChannelException(); + if (position < 0) { + throw new IllegalArgumentException("Position out of bounds: " + position); + } + if (position >= size) return -1; + + int requestSize = dst.remaining(); + if (requestSize == 0) return 0; + + if (requestSize > DIRECT_READ_THRESHOLD) { + return handleLargeRead(dst, position); + } + + int totalBytesRead = 0; + if (isCacheHit(position, 1)) { + int bytesFromCache = readFromCache(dst, position); + totalBytesRead += bytesFromCache; + position += bytesFromCache; + } + + if (dst.hasRemaining() && position < size) { + loadCache(position, requestSize); + if (isCacheHit(position, dst.remaining())) { + totalBytesRead += readFromCache(dst, position); + } else { + totalBytesRead += readDirectly(dst, position); + } + } + + return totalBytesRead; + } + + private int handleLargeRead(ByteBuffer dst, long position) throws IOException { + int bytesFromCache = 0; + if (isCacheHit(position, 1)) { + bytesFromCache = readFromCache(dst, position); + position += bytesFromCache; + } + + if (dst.hasRemaining() && position < size) { + int directBytesRead = readDirectly(dst, position); + return bytesFromCache + directBytesRead; + } else { + return bytesFromCache; + } + } + + private void loadCache(long requestPos, int requestSize) throws IOException { + int cacheSize; + long cacheStart; + + var lastCacheEnd = cache != null ? this.cacheStart + cache.length : -1; + if (requestSize > SEQ_READ_THRESHOLD || lastCacheEnd == requestPos) { + cacheSize = SEQ_READ_CACHE_SIZE; + cacheStart = requestPos; + } else { + cacheSize = RANDOM_READ_CACHE_SIZE; + cacheStart = Math.max(0, requestPos - cacheSize / 2); + } + + loadCacheAt(cacheStart, cacheSize); + } + + private void loadCacheAt(long cacheStart, int cacheSize) throws IOException { + long maxEnd = Math.min(cacheStart + cacheSize, size); + cacheStart = Math.max(0, maxEnd - cacheSize); + + var buffer = ByteBuffer.allocate((int) (maxEnd - cacheStart)); + var bytesRead = readDirectly(buffer, cacheStart); + if (bytesRead != buffer.capacity()) { + throw new IOException("Failed to fill cache."); + } + + cache = buffer.array(); + this.cacheStart = cacheStart; + + } + + private boolean isCacheHit(long pos, int bytesToRead) { + if (cache == null) return false; + long cacheEnd = cacheStart + cache.length; + long readEnd = Math.min(pos + bytesToRead, size); + return pos >= cacheStart && readEnd <= cacheEnd; + } + + private int readFromCache(ByteBuffer dst, long position) { + long relativePos = position - cacheStart; + int available = (int) Math.min(dst.remaining(), cache.length - relativePos); + + dst.put(cache, (int) relativePos, available); + + return available; + } + + private int readDirectly(ByteBuffer dst, long position) throws IOException { + try (var source = streamRead(position, dst.remaining()); + var channel = Channels.newChannel(source.inputStream())) { + int totalBytesRead = 0; + while (true) { + int bytesRead = channel.read(dst); + if (bytesRead <= 0) { + break; + } + totalBytesRead += bytesRead; + } + + return totalBytesRead; + } + } + + public BufferedSource streamRead(long position, long length) throws IOException { + long endPosition = Math.min(position + length, size) + startOffset; + + var request = new Request.Builder() + .url(url) + .header("Range", "bytes=" + (startOffset + position) + "-" + (endPosition - 1)) + .build(); + + var response = client.newCall(request).execute(); + if (response.code() != 206) { + response.close(); + throw new IOException("Unexpected response code " + response.code()); + } + return response.body().source(); + } + + @Override + public long position() { + return position; + } + + @Override + public SeekableByteChannel position(long newPosition) throws IOException { + if (!open) throw new ClosedChannelException(); + if (newPosition < 0) { + throw new IllegalArgumentException("Position out of bounds: " + newPosition); + } + position = newPosition; + return this; + } + + @Override + public long size() { + return size; + } + + @Override + public boolean isOpen() { + return open; + } + + @Override + public void close() { + open = false; + cache = null; + } + + @Override + public int write(ByteBuffer src) { + throw new NonWritableChannelException(); + } + + @Override + public SeekableByteChannel truncate(long size) { + throw new NonWritableChannelException(); + } +}