app: support download image and patch 2/2

This commit is contained in:
vvb2060
2025-09-29 22:30:13 +08:00
committed by John Wu
parent e70e8088ad
commit b9d21071fc
4 changed files with 578 additions and 1 deletions
@@ -0,0 +1,156 @@
package com.topjohnwu.magisk.core.tasks
import com.topjohnwu.magisk.core.di.ServiceLocator
import com.topjohnwu.magisk.core.utils.HttpFileChannel
import okio.buffer
import okio.inflate
import okio.sink
import org.apache.commons.compress.archivers.zip.ZipArchiveEntry
import org.apache.commons.compress.archivers.zip.ZipFile
import org.apache.commons.compress.archivers.zip.ZipMethod
import java.io.File
import java.io.IOException
import java.nio.channels.FileChannel
import java.nio.file.StandardOpenOption
class ExtractImage(
private val url: String,
private val console: MutableList<String>,
private val logs: MutableList<String>,
) {
@Throws(IOException::class)
fun start(outFile: File) {
logs.add("Downloading from: $url")
val channel = HttpFileChannel(ServiceLocator.okhttp, url)
ZipFile.builder()
.setSeekableByteChannel(channel)
.setIgnoreLocalFileHeader(true)
.get().use { zipFile ->
val payload = zipFile.getEntry("payload.bin")
if (payload != null) {
console.add("- Processing as OTA package")
zipFile.getEntry("META-INF/com/android/metadata")?.let { entry ->
zipFile.getInputStream(entry).use {
val meta = it.bufferedReader().readText()
logs.add(meta)
console.add("- OTA metadata:")
meta.lines().forEach { line ->
if (line.startsWith("post-")) {
console.add(" ${line.substringAfter('-')}")
}
}
}
}
zipFile.getRawInputStream(payload)
extractFromOTAPackage(payload, channel, outFile)
} else {
extractFromFactoryImage(zipFile, channel, outFile)
}
}
}
@Throws(IOException::class)
private fun extractFromOTAPackage(
payload: ZipArchiveEntry,
channel: HttpFileChannel,
outFile: File,
) {
if (payload.method != ZipMethod.STORED.code) {
throw IOException("payload.bin is compressed, expected STORED method")
}
channel.slice(payload.dataOffset, payload.size).use { payloadChannel ->
Payload(payloadChannel).extract(outFile, { console.add(it) }, { logs.add(it) })
}
}
@Throws(IOException::class)
private fun extractFromFactoryImage(zipFile: ZipFile, channel: HttpFileChannel, outFile: File) {
console.add("- Processing as factory image package")
findBootImageZipEntry(zipFile)?.let { entry ->
return extractImageFile(zipFile, entry, channel, outFile)
}
val imageZipEntry = zipFile.entries.asSequence().find { entry ->
val fileName = entry.name.substringAfterLast('/')
fileName.startsWith("image-") && fileName.endsWith(".zip")
}
if (imageZipEntry != null) {
zipFile.getRawInputStream(imageZipEntry)
return extractFromInnerImageZip(imageZipEntry, channel, outFile)
}
throw IOException("inner image ZIP not found in factory image package")
}
private fun findBootImageZipEntry(zipFile: ZipFile): ZipArchiveEntry? {
return zipFile.entries.asSequence().find { it.name == "init_boot.img" }
?: zipFile.entries.asSequence().find { it.name == "boot.img" }
}
@Throws(IOException::class)
private fun extractFromInnerImageZip(
entry: ZipArchiveEntry,
channel: HttpFileChannel,
outFile: File
) {
logs.add("Found inner image ZIP: ${entry.name}")
if (entry.method != ZipMethod.STORED.code) {
throw IOException("image ZIP is compressed, expected STORED method")
}
channel.slice(entry.dataOffset, entry.size).use { innerZipChannel ->
ZipFile.builder()
.setSeekableByteChannel(innerZipChannel)
.setIgnoreLocalFileHeader(true)
.get().use { innerZipFile ->
val targetEntry = findBootImageZipEntry(innerZipFile)
?: throw IOException("boot image not found in inner image ZIP")
return extractImageFile(innerZipFile, targetEntry, innerZipChannel, outFile)
}
}
}
@Throws(IOException::class)
private fun extractImageFile(
zipFile: ZipFile,
entry: ZipArchiveEntry,
channel: HttpFileChannel,
outFile: File,
) {
console.add("- Found boot image entry: ${entry.name} (${entry.size} bytes)")
console.add("- Downloading")
zipFile.getRawInputStream(entry)
when (entry.method) {
ZipMethod.STORED.code -> {
FileChannel.open(
outFile.toPath(),
StandardOpenOption.CREATE,
StandardOpenOption.WRITE,
StandardOpenOption.READ,
StandardOpenOption.TRUNCATE_EXISTING
).use { fileChannel ->
val mapped = fileChannel.map(FileChannel.MapMode.READ_WRITE, 0, entry.size)
val sourceChannel = channel.slice(entry.dataOffset, entry.size)
sourceChannel.read(mapped)
}
}
ZipMethod.DEFLATED.code -> {
channel.streamRead(entry.dataOffset, entry.size).inflate().use { source ->
outFile.sink().buffer().use { sink ->
sink.writeAll(source)
}
}
}
else -> throw IOException("unsupported method: ${entry.method}")
}
}
}
@@ -540,7 +540,7 @@ abstract class MagiskInstallImpl protected constructor(
// Download image from url
try {
srcBoot = installDir.getChildFile("boot.img")
//todo
ExtractImage(url, console, logs).start(srcBoot)
} catch (e: IOException) {
console.add("! Error: " + e.message)
Timber.e(e)
@@ -0,0 +1,177 @@
package com.topjohnwu.magisk.core.tasks
import chromeos_update_engine.UpdateMetadata.DeltaArchiveManifest
import chromeos_update_engine.UpdateMetadata.InstallOperation
import chromeos_update_engine.UpdateMetadata.PartitionUpdate
import com.topjohnwu.magisk.core.utils.HttpFileChannel
import org.apache.commons.compress.compressors.bzip2.BZip2CompressorInputStream
import org.apache.commons.compress.compressors.xz.XZCompressorInputStream
import java.io.File
import java.io.IOException
import java.nio.ByteBuffer
import java.nio.ByteOrder
import java.nio.channels.FileChannel
import java.nio.file.StandardOpenOption
import java.security.MessageDigest
class Payload(private val channel: HttpFileChannel) {
private val manifest: DeltaArchiveManifest
private var dataBase = 0L
init {
manifest = readPayloadHeader()
}
@Throws(IOException::class)
fun extract(outputFile: File, console: (String) -> Unit, logger: (String) -> Unit) {
val partition = findPartition()
console("- Found partition ${partition.partitionName}")
val actualHash = extractPartition(outputFile, partition, console)
if (!partition.newPartitionInfo.hasHash()) {
logger("Hash verification skipped")
return
}
fun toHex(bytes: ByteArray) = bytes.joinToString("") { "%02x".format(it) }
val expectedHash = partition.newPartitionInfo.hash.toByteArray()
if (!expectedHash.contentEquals(actualHash)) {
throw IOException(
"Hash mismatch, expected ${toHex(expectedHash)}, but got ${toHex(actualHash)}"
)
}
logger("Hash verification passed")
}
@Throws(IOException::class)
private fun readPayloadHeader(): DeltaArchiveManifest {
// Read magic
val magicBuffer = ByteBuffer.allocate(4)
channel.read(magicBuffer)
magicBuffer.flip()
val magic = String(magicBuffer.array())
if (magic != "CrAU") {
throw IOException("Invalid payload: invalid magic")
}
// Read version
val versionBuffer = ByteBuffer.allocate(8).order(ByteOrder.BIG_ENDIAN)
channel.read(versionBuffer)
versionBuffer.flip()
val version = versionBuffer.long
if (version != 2L) {
throw IOException("Invalid payload: unsupported version: $version")
}
// Read manifest length
val manifestLenBuffer = ByteBuffer.allocate(8).order(ByteOrder.BIG_ENDIAN)
channel.read(manifestLenBuffer)
manifestLenBuffer.flip()
val manifestLen = manifestLenBuffer.long.toInt()
if (manifestLen == 0) {
throw IOException("Invalid payload: manifest length is zero")
}
// Read manifest signature length
val manifestSigLenBuffer = ByteBuffer.allocate(4).order(ByteOrder.BIG_ENDIAN)
channel.read(manifestSigLenBuffer)
manifestSigLenBuffer.flip()
val manifestSigLen = manifestSigLenBuffer.int
if (manifestSigLen == 0) {
throw IOException("Invalid payload: manifest signature length is zero")
}
// Read manifest
val manifestBuffer = ByteBuffer.allocate(manifestLen)
channel.read(manifestBuffer)
manifestBuffer.flip()
val manifest = DeltaArchiveManifest.parseFrom(manifestBuffer.array())
// Skip manifest signature
channel.position(channel.position() + manifestSigLen)
dataBase = channel.position()
return manifest
}
@Throws(IOException::class)
private fun findPartition(): PartitionUpdate {
return manifest.partitionsList.find { it.partitionName == "init_boot" }
?: manifest.partitionsList.find { it.partitionName == "boot" }
?: throw IOException("boot partition not found in payload")
}
@Throws(IOException::class)
private fun extractPartition(
outputFile: File,
partition: PartitionUpdate,
console: (String) -> Unit,
): ByteArray {
FileChannel.open(
outputFile.toPath(),
StandardOpenOption.CREATE,
StandardOpenOption.WRITE,
StandardOpenOption.READ,
StandardOpenOption.TRUNCATE_EXISTING
).use { outChannel ->
val size = partition.newPartitionInfo.size
outChannel.write(ByteBuffer.allocate(1), size - 1)
val count = partition.operationsCount
partition.operationsList.forEachIndexed { index, operation ->
if (index % 5 == 0 || index == count - 1) {
console("- Downloading ${index + 1}/$count")
}
processOperation(outChannel, operation)
}
val digest = MessageDigest.getInstance("SHA-256")
val buffer = outChannel.map(FileChannel.MapMode.READ_WRITE, 0, size)
digest.update(buffer)
return digest.digest()
}
}
@Throws(IOException::class)
private fun processOperation(outChannel: FileChannel, operation: InstallOperation) {
val dataType = operation.getType()
if (dataType == InstallOperation.Type.ZERO) {
return
}
val dataBuffer = ByteBuffer.allocate(operation.getDataLength().toInt())
channel.read(dataBuffer, dataBase + operation.getDataOffset())
dataBuffer.flip()
val dstExtent = operation.getDstExtents(0)
val outOffset = dstExtent.getStartBlock() * manifest.getBlockSize()
when (dataType) {
InstallOperation.Type.REPLACE -> {
outChannel.write(dataBuffer, outOffset)
}
InstallOperation.Type.REPLACE_BZ, InstallOperation.Type.REPLACE_XZ -> {
val inputStream = dataBuffer.array().inputStream()
if (dataType == InstallOperation.Type.REPLACE_BZ) {
BZip2CompressorInputStream(inputStream)
} else {
XZCompressorInputStream(inputStream)
}.use { decompressor ->
val bytes = ByteArray(8192)
var bytesRead: Int
var bytesWritten = 0
while (decompressor.read(bytes).also { bytesRead = it } != -1) {
val buffer = ByteBuffer.wrap(bytes, 0, bytesRead)
bytesWritten += outChannel.write(buffer, outOffset + bytesWritten)
}
}
}
else -> throw IOException("Unsupported operation type: $dataType")
}
}
}
@@ -0,0 +1,244 @@
package com.topjohnwu.magisk.core.utils;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.Channels;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.NonWritableChannelException;
import java.nio.channels.SeekableByteChannel;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okio.BufferedSource;
public class HttpFileChannel implements SeekableByteChannel {
private static final int RANDOM_READ_CACHE_SIZE = 16 * 1024;
private static final int SEQ_READ_CACHE_SIZE = 1024 * 1024;
private static final int SEQ_READ_THRESHOLD = 1024;
private static final int DIRECT_READ_THRESHOLD = 512 * 1024;
private final OkHttpClient client;
private final String url;
private final long startOffset;
private final long size;
private long position = 0;
private boolean open = true;
private byte[] cache = null;
private long cacheStart = -1;
public HttpFileChannel(OkHttpClient client, String url, long startOffset, long size) {
this.client = client;
this.url = url;
this.startOffset = startOffset;
this.size = size;
}
public HttpFileChannel(OkHttpClient client, String url) throws IOException {
this(client, url, 0, fetchTotalSize(client, url));
}
private static long fetchTotalSize(OkHttpClient client, String url) throws IOException {
var request = new Request.Builder().url(url).head().build();
try (var response = client.newCall(request).execute()) {
if (!response.isSuccessful()) {
throw new IOException("Failed to connect to URL: " + response);
}
var contentLength = response.header("Content-Length");
if (contentLength == null) {
throw new IOException("Could not determine file size.");
}
var acceptRanges = response.header("Accept-Ranges");
if (acceptRanges == null || !acceptRanges.equalsIgnoreCase("bytes")) {
throw new IOException("Server does not support byte ranges: " + response);
}
return Long.parseLong(contentLength);
}
}
public HttpFileChannel slice(long offset, long sliceSize) {
if (offset == 0 && sliceSize == size) {
return this;
}
if (offset < 0 || sliceSize <= 0 || offset + sliceSize >= size) {
throw new IllegalArgumentException("Invalid slice parameters");
}
return new HttpFileChannel(client, url, startOffset + offset, sliceSize);
}
@Override
public int read(ByteBuffer dst) throws IOException {
var bytesRead = read(dst, position);
position += bytesRead;
return bytesRead;
}
public int read(ByteBuffer dst, long position) throws IOException {
if (!open) throw new ClosedChannelException();
if (position < 0) {
throw new IllegalArgumentException("Position out of bounds: " + position);
}
if (position >= size) return -1;
int requestSize = dst.remaining();
if (requestSize == 0) return 0;
if (requestSize > DIRECT_READ_THRESHOLD) {
return handleLargeRead(dst, position);
}
int totalBytesRead = 0;
if (isCacheHit(position, 1)) {
int bytesFromCache = readFromCache(dst, position);
totalBytesRead += bytesFromCache;
position += bytesFromCache;
}
if (dst.hasRemaining() && position < size) {
loadCache(position, requestSize);
if (isCacheHit(position, dst.remaining())) {
totalBytesRead += readFromCache(dst, position);
} else {
totalBytesRead += readDirectly(dst, position);
}
}
return totalBytesRead;
}
private int handleLargeRead(ByteBuffer dst, long position) throws IOException {
int bytesFromCache = 0;
if (isCacheHit(position, 1)) {
bytesFromCache = readFromCache(dst, position);
position += bytesFromCache;
}
if (dst.hasRemaining() && position < size) {
int directBytesRead = readDirectly(dst, position);
return bytesFromCache + directBytesRead;
} else {
return bytesFromCache;
}
}
private void loadCache(long requestPos, int requestSize) throws IOException {
int cacheSize;
long cacheStart;
var lastCacheEnd = cache != null ? this.cacheStart + cache.length : -1;
if (requestSize > SEQ_READ_THRESHOLD || lastCacheEnd == requestPos) {
cacheSize = SEQ_READ_CACHE_SIZE;
cacheStart = requestPos;
} else {
cacheSize = RANDOM_READ_CACHE_SIZE;
cacheStart = Math.max(0, requestPos - cacheSize / 2);
}
loadCacheAt(cacheStart, cacheSize);
}
private void loadCacheAt(long cacheStart, int cacheSize) throws IOException {
long maxEnd = Math.min(cacheStart + cacheSize, size);
cacheStart = Math.max(0, maxEnd - cacheSize);
var buffer = ByteBuffer.allocate((int) (maxEnd - cacheStart));
var bytesRead = readDirectly(buffer, cacheStart);
if (bytesRead != buffer.capacity()) {
throw new IOException("Failed to fill cache.");
}
cache = buffer.array();
this.cacheStart = cacheStart;
}
private boolean isCacheHit(long pos, int bytesToRead) {
if (cache == null) return false;
long cacheEnd = cacheStart + cache.length;
long readEnd = Math.min(pos + bytesToRead, size);
return pos >= cacheStart && readEnd <= cacheEnd;
}
private int readFromCache(ByteBuffer dst, long position) {
long relativePos = position - cacheStart;
int available = (int) Math.min(dst.remaining(), cache.length - relativePos);
dst.put(cache, (int) relativePos, available);
return available;
}
private int readDirectly(ByteBuffer dst, long position) throws IOException {
try (var source = streamRead(position, dst.remaining());
var channel = Channels.newChannel(source.inputStream())) {
int totalBytesRead = 0;
while (true) {
int bytesRead = channel.read(dst);
if (bytesRead <= 0) {
break;
}
totalBytesRead += bytesRead;
}
return totalBytesRead;
}
}
public BufferedSource streamRead(long position, long length) throws IOException {
long endPosition = Math.min(position + length, size) + startOffset;
var request = new Request.Builder()
.url(url)
.header("Range", "bytes=" + (startOffset + position) + "-" + (endPosition - 1))
.build();
var response = client.newCall(request).execute();
if (response.code() != 206) {
response.close();
throw new IOException("Unexpected response code " + response.code());
}
return response.body().source();
}
@Override
public long position() {
return position;
}
@Override
public SeekableByteChannel position(long newPosition) throws IOException {
if (!open) throw new ClosedChannelException();
if (newPosition < 0) {
throw new IllegalArgumentException("Position out of bounds: " + newPosition);
}
position = newPosition;
return this;
}
@Override
public long size() {
return size;
}
@Override
public boolean isOpen() {
return open;
}
@Override
public void close() {
open = false;
cache = null;
}
@Override
public int write(ByteBuffer src) {
throw new NonWritableChannelException();
}
@Override
public SeekableByteChannel truncate(long size) {
throw new NonWritableChannelException();
}
}