在上一篇《java动态编译（java在线执行代码后端实现原理（一））》文章中，实现了把字符串编译成字节码、然后通过反射运行代码的 demo。这一篇文章介绍如何防止死循环代码一直占用 CPU。
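思路：由于 CustomStringJavaCompiler 中重定向了 System.out 的输出位置，肯定不能有多线程并发执行的情况，否则会造成 System.out 输出内容错乱，所以我用了 Executors.newFixedThreadPool(1)，通过 Future 模式来获取结果。我自定义了一个 CustomCallable 来处理核心逻辑：在 call 方法中重新 new 了一个 Thread 来编译并执行代码，然后通过 join 等待 N 秒之后强制 stop 掉仍在运行的线程，这样就能及时 kill 掉动态运行的代码。注意：Thread.stop() 早已被标记为废弃，从 JDK 20 开始调用它会直接抛出 UnsupportedOperationException，因此这一方案只适用于 JDK 19 及更早的版本；在新版本 JDK 上，更可靠的做法是在独立进程中运行用户代码，并在超时后结束该进程。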
CustomStringJavaCompiler：编译核心类
package compiler.mydemo; import javax.tools.Diagnostic; import javax.tools.DiagnosticCollector; import javax.tools.FileObject; import javax.tools.ForwardingJavaFileManager; import javax.tools.JavaCompiler; import javax.tools.JavaFileManager; import javax.tools.JavaFileObject; import javax.tools.SimpleJavaFileObject; import javax.tools.StandardJavaFileManager; import javax.tools.ToolProvider; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.OutputStream; import java.io.PrintStream; import java.io.UnsupportedEncodingException; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.net.URI; import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.regex.Matcher; import java.util.regex.Pattern; /** * Create by andy on 2018-12-06 21:25 */ public class CustomStringJavaCompiler { //类全名 private String fullClassName; private String sourceCode; //存放编译以后的字节码(key:类全名,value:编译以后输出的字节码) private Map<String, ByteJavaFileObject> javaFileObjectMap = new ConcurrentHashMap<>(); //获取java的编译器 private JavaCompiler compiler = ToolProvider.getSystemJavaCompiler(); //存放编译过程当中输出的信息 private DiagnosticCollector<JavaFileObject> diagnosticsCollector = new DiagnosticCollector<>(); //执行结果(控制台输出的内容) private String runResult; //编译耗时(单位ms) private long compilerTakeTime; //运行耗时(单位ms) private long runTakeTime; public CustomStringJavaCompiler(String sourceCode) { this.sourceCode = sourceCode; this.fullClassName = getFullClassName(sourceCode); } /** * 编译字符串源代码,编译失败在 diagnosticsCollector 中获取提示信息 * * @return true:编译成功 false:编译失败 */ public boolean compiler() { long startTime = System.currentTimeMillis(); //标准的内容管理器,更换成本身的实现,覆盖部分方法 StandardJavaFileManager standardFileManager = compiler.getStandardFileManager(diagnosticsCollector, null, null); JavaFileManager javaFileManager = new StringJavaFileManage(standardFileManager); //构造源代码对象 JavaFileObject javaFileObject 
= new StringJavaFileObject(fullClassName, sourceCode); //获取一个编译任务 JavaCompiler.CompilationTask task = compiler.getTask(null, javaFileManager, diagnosticsCollector, null, null, Arrays.asList(javaFileObject)); //设置编译耗时 compilerTakeTime = System.currentTimeMillis() - startTime; return task.call(); } /** * 执行main方法,重定向System.out.print */ public void runMainMethod() throws ClassNotFoundException, NoSuchMethodException, InvocationTargetException, IllegalAccessException, UnsupportedEncodingException { PrintStream out = System.out; try { long startTime = System.currentTimeMillis(); ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); PrintStream printStream = new PrintStream(outputStream); //PrintStream PrintStream = new PrintStream("/Users/andy/Desktop/tem.sql"); //输出到文件 System.setOut(printStream); //测试kill线程暂时屏蔽 StringClassLoader scl = new StringClassLoader(); Class<?> aClass = scl.findClass(fullClassName); Method main = aClass.getMethod("main", String[].class); Object[] pars = new Object[]{1}; pars[0] = new String[]{}; main.invoke(null, pars); //调用main方法 //设置运行耗时 runTakeTime = System.currentTimeMillis() - startTime; //设置打印输出的内容 runResult = new String(outputStream.toByteArray(), "utf-8"); } finally { //还原默认打印的对象 System.setOut(out); } } /** * @return 编译信息(错误 警告) */ public String getCompilerMessage() { StringBuilder sb = new StringBuilder(); List<Diagnostic<? 
extends JavaFileObject>> diagnostics = diagnosticsCollector.getDiagnostics(); for (Diagnostic diagnostic : diagnostics) { sb.append(diagnostic.toString()).append("\r\n"); } return sb.toString(); } /** * @return 控制台打印的信息 */ public String getRunResult() { return runResult; } public long getCompilerTakeTime() { return compilerTakeTime; } public long getRunTakeTime() { return runTakeTime; } /** * 获取类的全名称 * * @param sourceCode 源码 * @return 类的全名称 */ public static String getFullClassName(String sourceCode) { String className = ""; Pattern pattern = Pattern.compile("package\\s+\\S+\\s*;"); Matcher matcher = pattern.matcher(sourceCode); if (matcher.find()) { className = matcher.group().replaceFirst("package", "").replace(";", "").trim() + "."; } pattern = Pattern.compile("class\\s+\\S+\\s+\\{"); matcher = pattern.matcher(sourceCode); if (matcher.find()) { className += matcher.group().replaceFirst("class", "").replace("{", "").trim(); } return className; } /** * 自定义一个字符串的源码对象 */ private class StringJavaFileObject extends SimpleJavaFileObject { //等待编译的源码字段 private String contents; //java源代码 => StringJavaFileObject对象 的时候使用 public StringJavaFileObject(String className, String contents) { super(URI.create("string:///" + className.replaceAll("\\.", "/") + Kind.SOURCE.extension), Kind.SOURCE); this.contents = contents; } //字符串源码会调用该方法 @Override public CharSequence getCharContent(boolean ignoreEncodingErrors) throws IOException { return contents; } } /** * 自定义一个编译以后的字节码对象 */ private class ByteJavaFileObject extends SimpleJavaFileObject { //存放编译后的字节码 private ByteArrayOutputStream outPutStream; public ByteJavaFileObject(String className, Kind kind) { super(URI.create("string:///" + className.replaceAll("\\.", "/") + Kind.SOURCE.extension), kind); } //StringJavaFileManage 编译以后的字节码输出会调用该方法(把字节码输出到outputStream) @Override public OutputStream openOutputStream() { outPutStream = new ByteArrayOutputStream(); return outPutStream; } //在类加载器加载的时候须要用到 public byte[] getCompiledBytes() { return 
outPutStream.toByteArray(); } } /** * 自定义一个JavaFileManage来控制编译以后字节码的输出位置 */ private class StringJavaFileManage extends ForwardingJavaFileManager { StringJavaFileManage(JavaFileManager fileManager) { super(fileManager); } //获取输出的文件对象,它表示给定位置处指定类型的指定类。 @Override public JavaFileObject getJavaFileForOutput(Location location, String className, JavaFileObject.Kind kind, FileObject sibling) throws IOException { ByteJavaFileObject javaFileObject = new ByteJavaFileObject(className, kind); javaFileObjectMap.put(className, javaFileObject); return javaFileObject; } } /** * 自定义类加载器, 用来加载动态的字节码 */ private class StringClassLoader extends ClassLoader { @Override protected Class<?> findClass(String name) throws ClassNotFoundException { ByteJavaFileObject fileObject = javaFileObjectMap.get(name); if (fileObject != null) { byte[] bytes = fileObject.getCompiledBytes(); return defineClass(name, bytes, 0, bytes.length); } try { return ClassLoader.getSystemClassLoader().loadClass(name); } catch (Exception e) { return super.findClass(name); } } } }
CustomCallable：调用编译并运行，设置超时时间
package compiler.mydemo; import java.lang.reflect.InvocationTargetException; import java.util.concurrent.Callable; /** * Create by andy on 2018-12-07 13:10 */ public class CustomCallable implements Callable<RunInfo> { private String sourceCode; public CustomCallable(String sourceCode) { this.sourceCode = sourceCode; } //方案1 //@Override //public RunInfo call() throws Exception { // System.out.println("开始执行call" + LocalTime.now()); // RunInfo runInfo = new RunInfo(); // CustomStringJavaCompiler compiler = new CustomStringJavaCompiler(sourceCode); // if (compiler.compiler()) { // runInfo.setCompilerSuccess(true); // try { // compiler.runMainMethod(); // runInfo.setRunSuccess(true); // runInfo.setRunTakeTime(compiler.getRunTakeTime()); // runInfo.setRunMessage(compiler.getRunResult()); //获取运行的时候输出内容 // } catch (Exception e) { // e.printStackTrace(); // runInfo.setRunSuccess(false); // runInfo.setRunMessage(e.getMessage()); // } // } else { // //编译失败 // runInfo.setCompilerSuccess(false); // } // runInfo.setCompilerTakeTime(compiler.getCompilerTakeTime()); // runInfo.setCompilerMessage(compiler.getCompilerMessage()); // System.out.println("call over" + LocalTime.now()); // return runInfo; //} //方案2 @Override public RunInfo call() throws Exception { RunInfo runInfo = new RunInfo(); Thread t1 = new Thread(() -> realCall(runInfo)); t1.start(); try { t1.join(3000); //等待3秒 } catch (InterruptedException e) { e.printStackTrace(); } //无论有没有正常执行完成,强制中止t1 t1.stop(); return runInfo; } private void realCall(RunInfo runInfo) { CustomStringJavaCompiler compiler = new CustomStringJavaCompiler(sourceCode); if (compiler.compiler()) { runInfo.setCompilerSuccess(true); try { compiler.runMainMethod(); runInfo.setRunSuccess(true); runInfo.setRunTakeTime(compiler.getRunTakeTime()); runInfo.setRunMessage(compiler.getRunResult()); //获取运行的时候输出内容 } catch (InvocationTargetException e) { //反射调用异常了,是由于超时的线程被强制stop了 if ("java.lang.ThreadDeath".equalsIgnoreCase(e.getCause().toString())) { return; } } 
catch (Exception e) { e.printStackTrace(); runInfo.setRunSuccess(false); runInfo.setRunMessage(e.getMessage()); } } else { //编译失败 runInfo.setCompilerSuccess(false); } runInfo.setCompilerTakeTime(compiler.getCompilerTakeTime()); runInfo.setCompilerMessage(compiler.getCompilerMessage()); runInfo.setTimeOut(false); //走到这一步表明没有超时 } }
RunInfo：保存动态编译、运行信息的 bean
public class RunInfo { //true:表明超时 private Boolean timeOut; private Long compilerTakeTime; private String compilerMessage; private Boolean compilerSuccess; private Long runTakeTime; private String runMessage; private Boolean runSuccess; //省略get和set方法 }
CompilerUtil：把整套流程封装成一个工具类
package compiler.mydemo; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; /** * Create by andy on 2018-12-07 16:32 */ public class CompilerUtil { //这里用一个线程是由于防止System.out输出内容错乱 private static ExecutorService pool = Executors.newFixedThreadPool(1); public static RunInfo getRunInfo(String javaSourceCode) { RunInfo runInfo; CustomCallable compilerAndRun = new CustomCallable(javaSourceCode); Future<RunInfo> future = pool.submit(compilerAndRun); //方案1 try { runInfo = future.get(); } catch (Exception e) { e.printStackTrace(); //代码编译或者运行超时 runInfo = new RunInfo(); runInfo.setTimeOut(true); } //方案2:不可行的缘由:future.get超时会有问题,因为线程池只有1个线程,同时提交10个任务, 当前面几个任务执行时间很长,后面调用get就会立马失败,也就是说get的超时时间是从调用get开始算的,并非线程真正执行时间开始计算的 //try { // runInfo = future.get(5, TimeUnit.SECONDS); // return runInfo; //} catch (InterruptedException e) { // System.out.println("future在睡着时被打断"); // e.printStackTrace(); //} catch (ExecutionException e) { // System.out.println("future在尝试取得任务结果时出错"); // e.printStackTrace(); //} catch (TimeoutException e) { // System.out.println("future时间超时"); // e.printStackTrace(); // future.cancel(true); //} //runInfo = new RunInfo(); //runInfo.setTimeOut(true); return runInfo; } }
测试类：
package compiler.mydemo; /** * Create by andy on 2018-12-10 10:43 */ public class Test3 { public static void main(String[] args) throws InterruptedException { String loop = "public class HelloWorld {\n" + " public static void main(String[] args) {\n" + " while(true){\n" + //" System.out.println(\"Hello World!\");\n" + " }\n" + " \n" + " }\n" + "}"; String sleep_loop = "public class HelloWorld {\n" + " public static void main(String[] args) {\n" + " try {\n" + " Thread.sleep(6000);\n" + " } catch (InterruptedException e) {\n" + " e.printStackTrace();\n" + " }\n" + " System.out.println(\"Hello World!\");\n" + " while(true){\n" + //" System.out.println(\"Hello World!\");\n" + " }\n" + " }\n" + "}"; String ok = "public class HelloWorld {\n" + " public static void main(String[] args) {\n" + " System.out.println(\"Hello World!\");\n" + " }\n" + "}"; TestRun t = new TestRun(ok, "thread:ok"); t.start(); TestRun t1 = new TestRun(loop, "thread:loop:"); t1.start(); // TestRun t2 = new TestRun(sleep_loop, "thread:sleep_loop:"); t2.start(); } } class TestRun extends Thread { String code; TestRun(String code, String name) { this.code = code; super.setName(name); } @Override public void run() { System.out.println(CompilerUtil.getRunInfo(code)); } }