diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..35dc434
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,3 @@
+.idea
+*/target
+*.iml
diff --git a/README.md b/README.md
index 7e4a0b6..5af14c5 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,58 @@
-# unit-test-gene
+# 测试源代码生成插件
-单元测试生成工具
\ No newline at end of file
+## 说明
+本插件可以根据源文件自动生成测试代码模板,用于统一测试代码编写风格。
+
+**是什么**
+* 是一个快速生成测试代码的模板的工具。
+* 是一个规范测试代码的编写风格的工具。
+
+**不是什么**
+* 不是一个开箱即用的测试代码生成工具,需要结合需求修改生成后的代码的**方法输入参数**和**方法预想执行结果**。
+* 不是一个用于快速实现覆盖率要求指标的工具,需要结合需求修改生成后的代码来达到覆盖率指标要求。
+
+## 依赖
+源码工程依赖
+* [强制] junit4
+* [可选] mockito
+
+## 安装
+```shell
+cd unitestgen-maven-plugin
+mvn install
+```
+
+## 使用
+
+* 运行命令生成测试文件。
+
+```shell
+cd unitestgen-sample
+
+# 查看插件使用方法
+mvn github.plugin:unitestgen-maven-plugin:1.0:help
+
+# 常用生成测试代码命令
+# 生成工程源码的全部测试代码 如果存在 则追加测试类到已有文件
+mvn github.plugin:unitestgen-maven-plugin:1.0:gene
+
+# 生成工程源码的全部测试代码 使用 mockito 作为 mock 工具
+mvn github.plugin:unitestgen-maven-plugin:1.0:gene -Dmock=mockito
+
+# 生成工程源码的全部测试代码 如果存在 则替换原有文件 重新生成
+mvn github.plugin:unitestgen-maven-plugin:1.0:gene -Dmode=overwrite
+
+# 生成工程源码的全部测试代码 指定生成文件名后缀
+mvn github.plugin:unitestgen-maven-plugin:1.0:gene -Dsuffix=AutoTest
+
+# 生成工程源码的指定包下面的测试代码 如果存在 则追加测试类到已有文件
+mvn github.plugin:unitestgen-maven-plugin:1.0:gene -Dincludes="github.plugin.unitestgen.model"
+
+# 生成工程源码的包含指定类的测试代码 如果存在 则追加测试类到已有文件
+mvn github.plugin:unitestgen-maven-plugin:1.0:gene -Dincludes="github.plugin.unitestgen.model.ParseModel"
+
+# 生成工程源码的指定包内并且排除指定类的测试代码 如果存在 则追加测试类到已有文件
+mvn github.plugin:unitestgen-maven-plugin:1.0:gene -Dincludes="github.plugin.unitestgen.model" -Dexcludes="github.plugin.unitestgen.model.ParseModel2"
+```
+## 优化
+* 使用 mockito 时 可以自动生成 mock 语句
diff --git a/unitestgen-maven-plugin/pom.xml b/unitestgen-maven-plugin/pom.xml
new file mode 100644
index 0000000..08598cd
--- /dev/null
+++ b/unitestgen-maven-plugin/pom.xml
@@ -0,0 +1,69 @@
+
+
+
+ 4.0.0
+
+ maven-plugin
+
+ github.plugin
+ unitestgen-maven-plugin
+ 1.0
+
+
+ UTF-8
+ 1.8
+
+
+
+
+ org.apache.maven
+ maven-plugin-api
+ 3.9.0
+ provided
+
+
+ org.apache.maven
+ maven-project
+ 2.2.1
+ provided
+
+
+ org.apache.maven.plugin-tools
+ maven-plugin-annotations
+ 3.8.1
+ provided
+
+
+ com.github.javaparser
+ javaparser-core
+ 3.25.5
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-compiler-plugin
+ 3.11.0
+
+
+ 1.8
+ UTF-8
+
+
+
+ org.apache.maven.plugins
+ maven-resources-plugin
+ 3.3.1
+
+
+ org.apache.maven.plugins
+ maven-plugin-plugin
+ 3.9.0
+
+
+
+
\ No newline at end of file
diff --git a/unitestgen-maven-plugin/src/main/java/github/plugin/unitestgen/mojo/GeneMojo.java b/unitestgen-maven-plugin/src/main/java/github/plugin/unitestgen/mojo/GeneMojo.java
new file mode 100644
index 0000000..7eee13e
--- /dev/null
+++ b/unitestgen-maven-plugin/src/main/java/github/plugin/unitestgen/mojo/GeneMojo.java
@@ -0,0 +1,132 @@
+package github.plugin.unitestgen.mojo;
+
+import github.plugin.unitestgen.tool.GeneTool;
+import github.plugin.unitestgen.util.FileUtils;
+import github.plugin.unitestgen.util.StringUtils;
+import org.apache.maven.plugin.AbstractMojo;
+import org.apache.maven.plugin.logging.Log;
+import org.apache.maven.plugins.annotations.Mojo;
+import org.apache.maven.plugins.annotations.Parameter;
+import org.apache.maven.project.MavenProject;
+
+import java.io.File;
+import java.util.List;
+
+@Mojo(name = "gene")
+public class GeneMojo extends AbstractMojo {
+
+ public static final String DEFAULT_TEST_FILENAME_SUFFIX = "Test";
+
+ public static final String MODE_APPEND = "append";
+
+ public static final String MODE_OVERWRITE = "overwrite";
+
+ private Log log;
+
+ @Parameter(defaultValue = "${project}")
+ private MavenProject project;
+
+ @Parameter(property = "mock")
+ private String mock;
+
+ @Parameter(property = "mode", defaultValue = MODE_APPEND)
+ private String mode;
+
+ @Parameter(property = "suffix", defaultValue = DEFAULT_TEST_FILENAME_SUFFIX)
+ private String suffix;
+
+ @Parameter(property = "includes")
+ private String includes;
+
+ @Parameter(property = "excludes")
+ private String excludes;
+
+ @Override
+ public void execute() {
+ log = getLog();
+ infoParams();
+ generate();
+ }
+
+ private void generate() {
+ String srcRootPath = project.getCompileSourceRoots().get(0).toString();
+ String testRootPath = project.getTestCompileSourceRoots().get(0).toString();
+ FileUtils.walk(
+ new File(srcRootPath),
+ "java",
+ srcFile -> generateTestFile(srcFile, srcRootPath, testRootPath)
+ );
+ }
+
+ private void infoParams() {
+ log.info("mvn param mock: " + mock);
+ log.info("mvn param mode: " + mode);
+ log.info("mvn param suffix: " + suffix);
+ log.info("mvn param includes: " + includes);
+ log.info("mvn param excludes: " + excludes);
+ }
+
+ private void generateTestFile(File srcFile, String srcRootPath, String testRootPath) {
+ String srcClassFullName = calcClassFullName(srcFile, srcRootPath);
+ if (checkGenerate(srcClassFullName)) {
+ File testFile = calcTestFile(srcFile, srcRootPath, testRootPath);
+ generate(srcFile, testFile);
+ }
+ }
+
+ private boolean checkGenerate(String srcClassFullName) {
+ if (StringUtils.isBlank(includes) && StringUtils.isBlank(excludes)) {
+ return true;
+ }
+
+ if (!StringUtils.isBlank(includes)) {
+ if (StringUtils.includeStartsWith(includes, srcClassFullName)) {
+ if (!StringUtils.isBlank(excludes)) {
+ return !StringUtils.includeStartsWith(excludes, srcClassFullName);
+ }
+ return true;
+ } else {
+ return false;
+ }
+ }
+
+ if (!StringUtils.isBlank(excludes)) {
+ return !StringUtils.includeStartsWith(excludes, srcClassFullName);
+ }
+
+ return false;
+ }
+
+ private void generate(File srcFile, File testFile) {
+ if (testFile.exists()) {
+ if (MODE_OVERWRITE.equals(mode)) {
+ log.info("overwrite test file: " + testFile);
+ new GeneTool(mock, srcFile, testFile, false, log).generate();
+ } else {
+ log.info("append test file: " + testFile);
+ new GeneTool(mock, srcFile, testFile, true, log).generate();
+ }
+ } else {
+ FileUtils.createParentDir(testFile);
+ log.info("create test file: " + testFile);
+ new GeneTool(mock, srcFile, testFile, false, log).generate();
+ }
+ }
+
+ private String calcClassFullName(File srcFile, String srcRootPath) {
+ String srcRootTemp = srcRootPath.replaceAll("\\\\", ".") + ".";
+ String srcFilePath = srcFile.getAbsolutePath();
+ String srcPackageTemp = srcFilePath.replaceAll("\\\\", ".");
+ return srcPackageTemp.replaceAll(srcRootTemp, "");
+ }
+
+ private File calcTestFile(File srcFile, String srcRootPath, String testRootPath) {
+ String srcFileName = srcFile.getName();
+ List srcFileNameSplit = StringUtils.split(srcFileName, "\\.");
+ String testFileName = srcFileNameSplit.get(0) + suffix + "." + srcFileNameSplit.get(1);
+ String srcFilePath = srcFile.getAbsolutePath();
+ String testFilePath = srcFilePath.replace(srcRootPath, testRootPath).replace(srcFileName, testFileName);
+ return new File(testFilePath);
+ }
+
+}
diff --git a/unitestgen-maven-plugin/src/main/java/github/plugin/unitestgen/mojo/HelpMojo.java b/unitestgen-maven-plugin/src/main/java/github/plugin/unitestgen/mojo/HelpMojo.java
new file mode 100644
index 0000000..a076e06
--- /dev/null
+++ b/unitestgen-maven-plugin/src/main/java/github/plugin/unitestgen/mojo/HelpMojo.java
@@ -0,0 +1,40 @@
+package github.plugin.unitestgen.mojo;
+
+import org.apache.maven.plugin.AbstractMojo;
+import org.apache.maven.plugins.annotations.Mojo;
+
+@Mojo(name = "help")
+public class HelpMojo extends AbstractMojo {
+
+ @Override
+ public void execute() {
+ String help
+ = "This plugin has 2 goals:\n\n"
+ + "Command description:\n\n"
+ + " [unitestgen:help]\n"
+ + " Display help information on unitestgen-maven-plugin.\n\n"
+ + " [unitestgen:gene]\n"
+ + " Generate test class by given parameters.\n"
+ + " Parameter description:\n\n"
+ + " [mock] Create test file use mock util,\n"
+ + " required false, support value: [mockito], default value: [null].\n"
+ + " use: mvn unitestgen:gene -Dmock=mockito\n"
+ + " [mode] If test file exists, append or overwrite file by given mode,\n"
+ + " required false, support value: [append|overwrite], default value: [append].\n"
+ + " use: mvn unitestgen:gene -Dmode=overwrite\n"
+ + " [suffix] Create test file name suffix,\n"
+ + " required false, support value: [anyString], default value: [Test].\n"
+ + " use: mvn unitestgen:gene -suffix=Test\n"
+ + " [includes] Create a test file by specifying the included package name or class full name, \n"
+ + " required false, support value: [|], default value: [null].\n"
+ + " use: mvn unitestgen:gene -Dincludes=|\n"
+ + " [excludes] Create a test file by specifying the excluded package name or class full name, \n"
+ + " required false, support value: [|], default value: [null].\n"
+ + " use: mvn unitestgen:gene -Dexcludes=|\n\n"
+ + " All the above parameters can be freely combined.\n"
+ + " use: mvn unitestgen:gene -Dmock=mockito -Dcover=true -Dincludes=| -Dexcludes=|\n";
+
+ getLog().info(help);
+ }
+
+}
diff --git a/unitestgen-maven-plugin/src/main/java/github/plugin/unitestgen/tool/GeneTool.java b/unitestgen-maven-plugin/src/main/java/github/plugin/unitestgen/tool/GeneTool.java
new file mode 100644
index 0000000..f42bd47
--- /dev/null
+++ b/unitestgen-maven-plugin/src/main/java/github/plugin/unitestgen/tool/GeneTool.java
@@ -0,0 +1,804 @@
+package github.plugin.unitestgen.tool;
+
+import com.github.javaparser.StaticJavaParser;
+import com.github.javaparser.ast.CompilationUnit;
+import com.github.javaparser.ast.Node;
+import com.github.javaparser.ast.NodeList;
+import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration;
+import com.github.javaparser.ast.body.FieldDeclaration;
+import com.github.javaparser.ast.body.MethodDeclaration;
+import com.github.javaparser.ast.body.Parameter;
+import com.github.javaparser.ast.body.VariableDeclarator;
+import com.github.javaparser.ast.expr.AnnotationExpr;
+import com.github.javaparser.ast.expr.Expression;
+import com.github.javaparser.ast.expr.MemberValuePair;
+import com.github.javaparser.ast.expr.MethodCallExpr;
+import com.github.javaparser.ast.expr.NameExpr;
+import com.github.javaparser.ast.expr.VariableDeclarationExpr;
+import com.github.javaparser.ast.stmt.BlockStmt;
+import com.github.javaparser.ast.stmt.CatchClause;
+import com.github.javaparser.ast.stmt.ExpressionStmt;
+import com.github.javaparser.ast.stmt.IfStmt;
+import com.github.javaparser.ast.stmt.Statement;
+import com.github.javaparser.ast.stmt.TryStmt;
+import com.github.javaparser.ast.type.ClassOrInterfaceType;
+import com.github.javaparser.ast.type.Type;
+import com.github.javaparser.ast.type.TypeParameter;
+import github.plugin.unitestgen.util.AstUtils;
+import github.plugin.unitestgen.util.ExprUtils;
+import github.plugin.unitestgen.util.FileUtils;
+import github.plugin.unitestgen.util.NameUtils;
+import github.plugin.unitestgen.util.StringUtils;
+import org.apache.maven.plugin.logging.Log;
+
+import java.io.File;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.StringJoiner;
+
+public class GeneTool {
+
+ public final static String MOCKITO = "mockito";
+
+ private final Log log;
+
+ private final File srcFile;
+
+ private final File testFile;
+
+ private final String mock;
+
+ private final boolean append;
+
+ private final Map methodNameTimes = new HashMap<>();
+
+ private boolean privateMethodExists = false;
+
+ private boolean privateFieldExists = false;
+
+ private boolean voidMethodExists = false;
+
+ public GeneTool(String mock, File srcFile, File testFile, boolean append, Log log) {
+ this.mock = mock;
+ this.srcFile = srcFile;
+ this.testFile = testFile;
+ this.append = append;
+ this.log = log;
+ }
+
+ public void generate() {
+ CompilationUnit testUnit = createTestUnit();
+ FileUtils.output(testFile, testUnit.toString());
+ }
+
+ private CompilationUnit createTestUnit() {
+ // parse src unit
+ CompilationUnit srcUnit = AstUtils.getUnit(srcFile);
+
+ // create test unit if not exists
+ CompilationUnit testUnit = createIfNotExistsTestUnit();
+
+ // create test class if not exists
+ ClassOrInterfaceDeclaration testClass = createIfNotExistsTestClass(srcUnit, testUnit);
+
+ // create test class field if not exists
+ createIfNotExistsTestClassField(srcUnit, testClass);
+
+ // create setUp method if not exists
+ createIfNotExistsSetUpMethod(srcUnit, testClass);
+
+ // create tearDown method if not exists
+ createIfNotExistsTearDownMethod(testClass);
+
+ // create test_xxx_branch_xxx method if not exists
+ createIfNotExistsTestMethod(srcUnit, testClass);
+
+ // create reflectField field
+ if (privateFieldExists) {
+ String className = AstUtils.getClassName(srcUnit);
+ createIfNotExistsReflectField(testClass, className);
+ }
+
+ // create reflectMethod method
+ if (privateMethodExists) {
+ String className = AstUtils.getClassName(srcUnit);
+ createIfNotExistsReflectMethod(testClass, className);
+ }
+
+ // create stackTrace method
+ if (voidMethodExists) {
+ createIfNotExistsGetStackTrace(testClass);
+ }
+
+ return testUnit;
+ }
+
+ private CompilationUnit createIfNotExistsTestUnit() {
+ if (append) {
+ return AstUtils.getUnit(testFile);
+ } else {
+ return new CompilationUnit();
+ }
+ }
+
+ private ClassOrInterfaceDeclaration createIfNotExistsTestClass(
+ CompilationUnit srcUnit,
+ CompilationUnit testUnit
+ ) {
+ String testClassName = StringUtils.splitFirst(testFile.getName(), "\\.");
+ Optional optional = testUnit.getClassByName(testClassName);
+ return optional.orElseGet(() -> createTestClass(srcUnit, testUnit));
+ }
+
+ private ClassOrInterfaceDeclaration createTestClass(
+ CompilationUnit srcUnit,
+ CompilationUnit testUnit
+ ) {
+ // create package
+ String packageName = AstUtils.getPackageName(srcUnit);
+ createPackage(testUnit, packageName);
+
+ // create class
+ String className = StringUtils.splitFirst(testFile.getName(), "\\.");
+ ClassOrInterfaceDeclaration clazz = testUnit.addClass(className);
+
+ // create class annotation
+ if (MOCKITO.equals(mock)) {
+ // add import
+ String importName = "org.mockito.junit.MockitoJUnitRunner";
+ createImport(testUnit, importName);
+
+ // add annotation
+ String name = "value";
+ String value = "MockitoJUnitRunner.class";
+ NodeList memberValues = AstUtils.createMemberValues(name, value);
+ String annotationName = "org.junit.runner.RunWith";
+ AnnotationExpr annotation = AstUtils.createAnnotationExpr(clazz, annotationName, memberValues);
+ AstUtils.addAnnotation(clazz, annotation);
+ info("create method annotation: " + annotationName);
+ }
+
+ // print log
+ info("create be test class: " + clazz.getNameAsString());
+ debug(clazz);
+
+ return clazz;
+ }
+
+ private void createIfNotExistsTestClassField(
+ CompilationUnit srcUnit,
+ ClassOrInterfaceDeclaration testClass
+ ) {
+ // create test class be tested field if not exists
+ String className = AstUtils.getClassName(srcUnit);
+ String beTestFieldName = NameUtils.toCamelCase(className);
+ if (!AstUtils.checkFieldExists(testClass, beTestFieldName)) {
+ if (!AstUtils.checkUtilClass(srcUnit)) {
+ ClassOrInterfaceType beTestFieldType = new ClassOrInterfaceType(null, className);
+ if (!MOCKITO.equals(mock)) {
+ createField(srcUnit, testClass, beTestFieldType, beTestFieldName);
+ } else {
+ String annotationName = "org.mockito.InjectMocks";
+ createMockedField(srcUnit, testClass, beTestFieldType, beTestFieldName, annotationName);
+ }
+ }
+ }
+
+ // create test class inject field if not exists
+ AstUtils.consumeField(srcUnit, (injectField, injectFieldType) -> {
+ String injectFieldName = AstUtils.getName(AstUtils.getVariableDeclarator(injectField));
+ if (!AstUtils.checkFieldExists(testClass, injectFieldName)) {
+ if (!MOCKITO.equals(mock)) {
+ createField(srcUnit, testClass, injectFieldType, injectFieldName);
+ } else {
+ String annotationName = "org.mockito.Mock";
+ createMockedField(srcUnit, testClass, injectFieldType, injectFieldName, annotationName);
+ }
+ }
+ });
+ }
+
+ private void createIfNotExistsSetUpMethod(
+ CompilationUnit srcUnit,
+ ClassOrInterfaceDeclaration testClass
+ ) {
+ // check method setUp
+ String methodName = "setUp";
+ if (!AstUtils.checkMethodExists(testClass, methodName)) {
+ // create setUp method
+ String annotationName = "org.junit.Before";
+ MethodDeclaration method = createMethod(testClass, methodName, annotationName);
+
+ // create setUp method content
+ BlockStmt blockStmt = createSetUpMethodContent(srcUnit);
+ method.setBody(blockStmt);
+
+ // debug method
+ debug(method);
+ }
+ }
+
+ private BlockStmt createSetUpMethodContent(CompilationUnit srcUnit) {
+ BlockStmt blockStmt = new BlockStmt();
+ AstUtils.consumeField(srcUnit, ((declaration, type) -> {
+ if (declaration.isPublic()) {
+ // add setup content
+ String classFieldName = NameUtils.toCamelCase(AstUtils.getClassName(srcUnit));
+ String fieldName = AstUtils.getName(AstUtils.getVariableDeclarator(declaration));
+ String field = classFieldName + "." + fieldName + " = " + fieldName + ";";
+ Statement fieldStmt = StaticJavaParser.parseStatement(field);
+ blockStmt.addStatement(fieldStmt);
+ } else {
+ privateFieldExists = true;
+ // add setup content
+ String mockFieldName = AstUtils.getName(AstUtils.getVariableDeclarator(declaration));
+ NameExpr fieldName = new NameExpr("\"" + mockFieldName + "\"");
+ NameExpr fieldValue = new NameExpr(mockFieldName);
+ MethodCallExpr methodCallExpr = new MethodCallExpr("reflectField", fieldName, fieldValue);
+ blockStmt.addStatement(methodCallExpr);
+ }
+ }));
+ return blockStmt;
+ }
+
+ private void createIfNotExistsTearDownMethod(ClassOrInterfaceDeclaration testClass) {
+ // check method tearDown
+ String methodName = "tearDown";
+ if (!AstUtils.checkMethodExists(testClass, methodName)) {
+ // create method tearDown
+ String annotationName = "org.junit.After";
+ createMethod(testClass, methodName, annotationName);
+ }
+ }
+
+ private void createIfNotExistsTestMethod(
+ CompilationUnit srcUnit,
+ ClassOrInterfaceDeclaration testClass
+ ) {
+ for (MethodDeclaration srcMethod : srcUnit.findAll(MethodDeclaration.class)) {
+ srcMethod.getBody().ifPresent((blockStmt) -> {
+ List childNodes = blockStmt.getChildNodes();
+ if (childNodes.isEmpty()) {
+ createTestMethod(testClass, srcMethod, "");
+ return;
+ }
+ if (
+ childNodes.size() == 1
+ && (!(childNodes.get(0) instanceof IfStmt) && !(childNodes.get(0) instanceof TryStmt))
+ ) {
+ createTestMethod(testClass, srcMethod, "");
+ return;
+ }
+ for (Node childNode : childNodes) {
+ createTestMethod(testClass, srcMethod, "_branch", childNode, 0);
+ }
+ });
+ }
+ }
+
+ private int createTestMethod(
+ ClassOrInterfaceDeclaration testClass,
+ MethodDeclaration srcMethod,
+ String methodNameSuffix,
+ Node node,
+ int branchCnt
+ ) {
+ // if else
+ if (node instanceof IfStmt) {
+ branchCnt++;
+ // if
+ Expression condition = ((IfStmt) node).getCondition();
+ String expression = ExprUtils.expression(condition.toString());
+ String methodNameSuffixIf = methodNameSuffix + "_if_" + expression;
+ String methodNameSuffixIfFinal = ExprUtils.replaceDoubleUnderLine(methodNameSuffixIf);
+ createTestMethod(testClass, srcMethod, methodNameSuffixIfFinal);
+ String methodNameSuffixIfThen = methodNameSuffixIf + "_";
+ Statement thenStmt = ((IfStmt) node).getThenStmt();
+ for (Node thenNode : thenStmt.getChildNodes()) {
+ if (thenNode instanceof IfStmt) {
+ branchCnt = createTestMethod(testClass, srcMethod, methodNameSuffixIfThen, thenNode, branchCnt);
+ }
+ if (thenNode instanceof TryStmt) {
+ branchCnt = createTestMethod(testClass, srcMethod, methodNameSuffixIfThen, thenNode, branchCnt);
+ }
+ if (thenNode instanceof BlockStmt) {
+ branchCnt = createTestMethod(testClass, srcMethod, methodNameSuffixIfThen, thenNode, branchCnt);
+ }
+ }
+
+ // else
+ Optional elseStmtOptional = ((IfStmt) node).getElseStmt();
+ boolean hasElse = elseStmtOptional.isPresent();
+ if (hasElse) {
+ branchCnt++;
+ String methodNameSuffixElse = methodNameSuffix + "_else_" + expression;
+ String methodNameSuffixElseFinal = ExprUtils.expression(methodNameSuffixElse);
+ createTestMethod(testClass, srcMethod, methodNameSuffixElseFinal);
+ String methodNameSuffixElseThen = methodNameSuffixElse + "_";
+ Statement elseStmt = elseStmtOptional.get();
+ if (elseStmt instanceof IfStmt) {
+ branchCnt = createTestMethod(testClass, srcMethod, methodNameSuffixElseThen, elseStmt, branchCnt);
+ }
+ if (elseStmt instanceof TryStmt) {
+ branchCnt = createTestMethod(testClass, srcMethod, methodNameSuffixElseThen, elseStmt, branchCnt);
+ }
+ if (elseStmt instanceof BlockStmt) {
+ branchCnt = createTestMethod(testClass, srcMethod, methodNameSuffixElseThen, elseStmt, branchCnt);
+ }
+ }
+ }
+
+ // try catch
+ if (node instanceof TryStmt) {
+ branchCnt++;
+ // try
+ String methodNameSuffixTry = methodNameSuffix + "_try_";
+ String methodNameSuffixTryFinal = ExprUtils.expression(methodNameSuffixTry);
+ createTestMethod(testClass, srcMethod, methodNameSuffixTryFinal);
+ for (Node childNode : ((TryStmt) node).getTryBlock().getChildNodes()) {
+ if (childNode instanceof IfStmt) {
+ branchCnt = createTestMethod(testClass, srcMethod, methodNameSuffixTry, childNode, branchCnt);
+ }
+ if (childNode instanceof TryStmt) {
+ branchCnt = createTestMethod(testClass, srcMethod, methodNameSuffixTry, childNode, branchCnt);
+ }
+ if (childNode instanceof BlockStmt) {
+ branchCnt = createTestMethod(testClass, srcMethod, methodNameSuffixTry, childNode, branchCnt);
+ }
+ }
+
+ // catch
+ for (CatchClause catchClause : ((TryStmt) node).getCatchClauses()) {
+ for (Node childNode : catchClause.getChildNodes()) {
+ if (childNode instanceof IfStmt) {
+ branchCnt = createTestMethod(testClass, srcMethod, methodNameSuffixTry, childNode, branchCnt);
+ }
+ if (childNode instanceof TryStmt) {
+ branchCnt = createTestMethod(testClass, srcMethod, methodNameSuffixTry, childNode, branchCnt);
+ }
+ if (childNode instanceof BlockStmt) {
+ branchCnt++;
+ Parameter parameter = (Parameter) catchClause.getChildNodes().get(0);
+ String methodNameSuffixCatch = methodNameSuffix + "_catch_" + parameter.getTypeAsString();
+ String methodNameSuffixCatchFinal = ExprUtils.expression(methodNameSuffixCatch);
+ createTestMethod(testClass, srcMethod, methodNameSuffixCatchFinal);
+ branchCnt = createTestMethod(testClass, srcMethod, methodNameSuffixTry, childNode, branchCnt);
+ }
+ }
+ }
+ }
+
+ // block
+ if (node instanceof BlockStmt) {
+ for (Node childNode : node.getChildNodes()) {
+ branchCnt = createTestMethod(testClass, srcMethod, methodNameSuffix, childNode, branchCnt);
+ }
+ }
+
+ return branchCnt;
+ }
+
+ private void createTestMethod(
+ ClassOrInterfaceDeclaration testClass,
+ MethodDeclaration srcMethod,
+ String methodNameSuffix
+ ) {
+ // calc test method name
+ String methodName = calcTestMethodName(srcMethod.getNameAsString(), methodNameSuffix);
+ if (AstUtils.checkMethodExists(testClass, methodName)) {
+ return;
+ }
+
+ // create test method
+ String annotationName = "org.junit.Test";
+ MethodDeclaration method = createMethod(testClass, methodName, annotationName);
+
+ // create test method content
+ BlockStmt testContent = new BlockStmt();
+
+ // create given block
+ createGivenBlock(testClass, srcMethod, testContent);
+
+ // create when block
+ createWhenBlock(srcMethod, testContent);
+
+ // create assert block
+ if (AstUtils.checkVoidReturn(srcMethod)) {
+ // assert times
+ createAssertTimesBlock(testClass, testContent);
+ } else {
+ // assert result
+ createAssertResultBlock(testClass, testContent);
+ }
+
+ // add content
+ method.setBody(testContent);
+
+ // print log
+ debug(testContent);
+ }
+
+ private String calcTestMethodName(String srcMethodName, String methodNameSuffix) {
+ String base = "test_" + NameUtils.toUnderscoreCase(srcMethodName) + methodNameSuffix;
+ if (!methodNameTimes.containsKey(base)) {
+ methodNameTimes.put(base, 1);
+ return base;
+ } else {
+ Integer times = methodNameTimes.get(base) + 1;
+ methodNameTimes.put(base, times);
+ return base + times;
+ }
+ }
+
+ private void createAssertTimesBlock(ClassOrInterfaceDeclaration testClass, BlockStmt testContent) {
+ voidMethodExists = true;
+
+ CompilationUnit unit = AstUtils.getUnit(testClass);
+ String importName = "org.junit.Assert";
+ createImport(unit, importName);
+
+ TryStmt tryStmt = new TryStmt();
+ BlockStmt tryBlock = new BlockStmt();
+ tryStmt.setTryBlock(tryBlock);
+ NodeList catchClauses = new NodeList<>();
+ tryStmt.setCatchClauses(catchClauses);
+
+ Statement throwStmt = StaticJavaParser.parseStatement("throw new RuntimeException();");
+ throwStmt.setLineComment(" TODO then assert inner method run times");
+ tryBlock.addStatement(throwStmt);
+
+ BlockStmt catchStmt = new BlockStmt();
+ Parameter parameter = new Parameter(new TypeParameter("Exception"), "exception");
+ CatchClause catchClause = new CatchClause(parameter, catchStmt);
+
+ // stack content
+ String stackStr = "String stackTrace = getStackTrace(exception);";
+ Statement stackStmt = StaticJavaParser.parseStatement(stackStr);
+ catchStmt.addStatement(stackStmt);
+ // assert content
+ String assertStr = " Assert.fail(\"Should not run here.\\n\\t\" + stackTrace);";
+ Statement assertStmt = StaticJavaParser.parseStatement(assertStr);
+ catchStmt.addStatement(assertStmt);
+
+ catchClause.setBody(catchStmt);
+ catchClauses.add(catchClause);
+
+ testContent.addStatement(tryStmt);
+ }
+
+ private void createAssertResultBlock(ClassOrInterfaceDeclaration testClass, BlockStmt testContent) {
+ // import package
+ CompilationUnit unit = AstUtils.getUnit(testClass);
+ String importName = "org.junit.Assert";
+ createImport(unit, importName);
+
+ // create expect result
+ String expectContentStr = "Object expect = null;";
+ Statement expectStmt = StaticJavaParser.parseStatement(expectContentStr);
+
+ // create modify comment
+ expectStmt.setLineComment(" TODO then");
+ testContent.addStatement(expectStmt);
+
+ // create assert content
+ String assertContentStr = "Assert.assertEquals(expect, actual);";
+ Statement assertStmt = StaticJavaParser.parseStatement(assertContentStr);
+ testContent.addStatement(assertStmt);
+ }
+
+ private void createGivenBlock(
+ ClassOrInterfaceDeclaration testClass,
+ MethodDeclaration srcMethod,
+ BlockStmt testContent
+ ) {
+ CompilationUnit testUnit = AstUtils.getUnit(testClass);
+ CompilationUnit srcUnit = AstUtils.getUnit(srcMethod);
+ NodeList parameters = srcMethod.getParameters();
+ for (int i = 0, parametersSize = parameters.size(); i < parametersSize; i++) {
+ Parameter parameter = parameters.get(i);
+ Type fieldType = parameter.getType();
+
+ // create import
+ List importNames = AstUtils.getImportNames(srcUnit, fieldType);
+ for (String importName : importNames) {
+ createImport(testUnit, importName);
+ }
+
+ // create given statement
+ String fieldName = AstUtils.getName(parameter);
+ VariableDeclarationExpr variableExpr = AstUtils.createVariableDeclarationExpr(fieldType, fieldName);
+ ExpressionStmt expressionStmt = AstUtils.createExpressionStmt(variableExpr);
+
+ if (i == 0) {
+ // first line add comment
+ String comment = " TODO given";
+ variableExpr.setLineComment(comment);
+ }
+
+ testContent.addStatement(expressionStmt);
+ }
+ }
+
+ private void createWhenBlock(MethodDeclaration srcMethod, BlockStmt testContent) {
+ Statement whenStmt;
+ if (srcMethod.isPublic()) {
+ String whenField = createCallWhenMethod(srcMethod);
+ whenStmt = StaticJavaParser.parseStatement(whenField);
+ } else {
+ privateMethodExists = true;
+ NodeList parameters = srcMethod.getParameters();
+ String whenFieldTypes = createCallReflectWhenMethodTypes(parameters);
+ Statement whenFieldTypesStmt = StaticJavaParser.parseStatement(whenFieldTypes);
+ testContent.addStatement(whenFieldTypesStmt);
+
+ String whenFieldParams = createCallReflectWhenMethodParams(parameters);
+ Statement whenFieldParamsStmt = StaticJavaParser.parseStatement(whenFieldParams);
+ testContent.addStatement(whenFieldParamsStmt);
+
+ String whenField = createCallReflectWhenMethod(srcMethod);
+ whenStmt = StaticJavaParser.parseStatement(whenField);
+ }
+ whenStmt.setLineComment(" when");
+ testContent.addStatement(whenStmt);
+ }
+
+ private String createCallReflectWhenMethodTypes(NodeList parameters) {
+ String types = typeJoinStr(parameters);
+ String callWhenMethodTypes = "Class>[] types = { " + types + "};";
+ log.debug(callWhenMethodTypes);
+ return callWhenMethodTypes;
+ }
+
+ private String createCallReflectWhenMethodParams(NodeList parameters) {
+ String params = paramJoinStr(parameters);
+ String callWhenMethodParams = "Object[] params = { " + params + "};";
+ log.debug(callWhenMethodParams);
+ return callWhenMethodParams;
+ }
+
+ private String createCallReflectWhenMethod(MethodDeclaration method) {
+ String methodName = method.getNameAsString();
+ String nameMethod = "\"" + methodName + "\"";
+ String whenStr;
+ if (AstUtils.checkVoidReturn(method)) {
+ whenStr = "reflectMethod(" + nameMethod + ", " + "types" + ", " + "params" + ");";
+ } else {
+ whenStr = "Object actual = reflectMethod(" + nameMethod + ", " + "types" + ", " + "params" + ");";
+ }
+ log.debug(whenStr);
+ return whenStr;
+ }
+
+ private String createCallWhenMethod(MethodDeclaration srcMethod) {
+
+ StringBuilder builder = new StringBuilder();
+ if (!AstUtils.checkVoidReturn(srcMethod)) {
+ builder.append("Object");
+ builder.append(" actual = ");
+ }
+
+ CompilationUnit srcUnit = AstUtils.getUnit(srcMethod);
+ String classType = AstUtils.getClassName(srcUnit);
+ String classFieldName = NameUtils.toCamelCase(classType);
+ if (srcMethod.isStatic()) {
+ builder.append(AstUtils.getClassName(srcUnit));
+ } else {
+ builder.append(classFieldName);
+ }
+
+ builder.append(".");
+ String methodName = srcMethod.getNameAsString();
+ builder.append(methodName);
+ builder.append("(");
+ String params = paramJoinStr(srcMethod.getParameters());
+ builder.append(params);
+ builder.append(");");
+ String callWhenMethod = builder.toString();
+ log.debug(callWhenMethod);
+ return callWhenMethod;
+ }
+
+ private String paramJoinStr(NodeList parameters) {
+ StringJoiner joiner = new StringJoiner(", ");
+ for (Parameter param : parameters) {
+ joiner.add(param.getName().toString());
+ }
+ return joiner.toString();
+ }
+
+ private String typeJoinStr(NodeList parameters) {
+ StringJoiner joiner = new StringJoiner(", ");
+ for (Parameter param : parameters) {
+ joiner.add(AstUtils.getTypeName(param.getType()) + ".class");
+ }
+ return joiner.toString();
+ }
+
+ private void createIfNotExistsReflectField(
+ ClassOrInterfaceDeclaration testClass,
+ String srcClassName
+ ) {
+ // check method exists
+ String methodName = "reflectField";
+ if (!AstUtils.checkMethodExists(testClass, methodName)) {
+ String importName = "java.lang.reflect.Field";
+ String methodContent = getReflectField(srcClassName);
+ createMethod(testClass, methodName, methodContent, importName);
+ }
+ }
+
+ private void createIfNotExistsReflectMethod(
+ ClassOrInterfaceDeclaration clazz,
+ String srcClassName
+ ) {
+ // check method exists
+ String methodName = "reflectMethod";
+ if (!AstUtils.checkMethodExists(clazz, methodName)) {
+ String importName = "java.lang.reflect.Method";
+ String methodContent = getReflectMethod(srcClassName);
+ createMethod(clazz, methodName, methodContent, importName);
+ }
+ }
+
+ private void createIfNotExistsGetStackTrace(ClassOrInterfaceDeclaration clazz) {
+ // check method exists
+ String methodName = "getStackTrace";
+ if (!AstUtils.checkMethodExists(clazz, methodName)) {
+ String importName1 = "java.util.Arrays";
+ String importName2 = "java.util.stream.Collectors";
+ String methodContent = getGetStackTrace();
+ createMethod(clazz, methodName, methodContent, importName1, importName2);
+ }
+ }
+
+ private static String getReflectField(String className) {
+ String fieldName = NameUtils.toCamelCase(className);
+ return "private void reflectField(String fieldName, Object fieldValue) {"
+ + " try {"
+ + " Field field = " + className + ".class.getDeclaredField(fieldName);"
+ + " field.setAccessible(true);"
+ + " field.set(" + fieldName + ", fieldValue);"
+ + " } catch (Exception e) {"
+ + " throw new RuntimeException(e);"
+ + " }"
+ + "}";
+ }
+
+ private static String getReflectMethod(String className) {
+ // create reflect method body
+ String fieldName = NameUtils.toCamelCase(className);
+ return "private Object reflectMethod(String methodName, Class>[] types, Object[] params) {"
+ + " try {"
+ + " Method method = " + className + ".class.getDeclaredMethod(methodName, types);"
+ + " method.setAccessible(true);"
+ + " return method.invoke(" + fieldName + ", params);"
+ + " } catch (Exception e) {"
+ + " throw new RuntimeException(e);"
+ + " }"
+ + "}";
+ }
+
+ private static String getGetStackTrace() {
+ return "private String getStackTrace(Exception e) {"
+ + " return Arrays.stream(e.getStackTrace())"
+ + " .map(StackTraceElement::toString)"
+ + " .collect(Collectors.joining(\"\\n\\t\"));"
+ + "}";
+ }
+
+ private void createPackage(CompilationUnit unit, String packageName) {
+ // create package
+ unit.setPackageDeclaration(packageName);
+ // print log
+ info("create package: " + packageName);
+ debug("import package " + packageName + ";");
+ }
+
+ private void createImport(CompilationUnit unit, String... importNames) {
+ for (String importName : importNames) {
+ AstUtils.addImport(unit, importName);
+ info("import package: " + importName);
+ debug("import " + importName + ";");
+ }
+ }
+
+ private void createField(
+ CompilationUnit srcUnit,
+ ClassOrInterfaceDeclaration testClass,
+ ClassOrInterfaceType fieldType,
+ String fieldName
+ ) {
+ // import package
+ AstUtils.findImportDeclaration(srcUnit, AstUtils.getName(fieldType))
+ .ifPresent(declaration -> testClass.findCompilationUnit()
+ .ifPresent(unit -> createImport(unit, declaration.getNameAsString())));
+
+ // create field
+ FieldDeclaration field = testClass.addPublicField(fieldType, fieldName);
+
+ // set default value
+ String defaultValue = AstUtils.getDefaultValue(fieldType);
+ VariableDeclarator variable = AstUtils.createVariableDeclarator(fieldType, fieldName, defaultValue);
+ field.setVariable(0, variable);
+
+ // print log
+ info("create field: " + fieldName);
+ debug(field);
+ }
+
+ private void createMockedField(
+ CompilationUnit srcUnit,
+ ClassOrInterfaceDeclaration testClass,
+ ClassOrInterfaceType fieldType,
+ String fieldName,
+ String annotationName
+ ) {
+ // import package
+ AstUtils.findImportDeclaration(srcUnit, AstUtils.getName(fieldType))
+ .ifPresent(declaration -> testClass.findCompilationUnit()
+ .ifPresent(unit -> createImport(unit, declaration.getNameAsString())));
+
+ // create field
+ FieldDeclaration field = testClass.addPublicField(fieldType, fieldName);
+
+ // add mock annotation
+ AnnotationExpr annotation = AstUtils.createAnnotationExpr(testClass, annotationName);
+ AstUtils.addAnnotation(field, annotation);
+
+ // print log
+ info("create field: " + fieldName);
+ debug(field);
+ }
+
+ private MethodDeclaration createMethod(
+ ClassOrInterfaceDeclaration clazz,
+ String methodName,
+ String annotationName
+ ) {
+ // create method
+ MethodDeclaration method = AstUtils.createMethodDeclaration(clazz, methodName);
+
+ // create annotation
+ AnnotationExpr annotation = AstUtils.createAnnotationExpr(clazz, annotationName);
+ AstUtils.addAnnotation(method, annotation);
+ info("create method annotation: " + annotationName);
+
+ // print log
+ info("create method: " + methodName);
+ debug(method);
+
+ return method;
+ }
+
+ private void createMethod(
+ ClassOrInterfaceDeclaration clazz,
+ String methodName,
+ String methodContent,
+ String... importNames
+ ) {
+ // import package
+ CompilationUnit unit = AstUtils.getUnit(clazz);
+ createImport(unit, importNames);
+
+ // create method
+ MethodDeclaration method = StaticJavaParser.parseMethodDeclaration(methodContent);
+
+ // add method
+ clazz.addMember(method);
+
+ // print log
+ info("create method: " + methodName);
+ debug(method);
+ }
+
+ private void info(String info) {
+ log.info(info);
+ }
+
+ private void debug(Object object) {
+ log.debug(object.toString().replaceAll("\r\n", " ").replaceAll("\n", " "));
+ }
+
+}
diff --git a/unitestgen-maven-plugin/src/main/java/github/plugin/unitestgen/util/AstUtils.java b/unitestgen-maven-plugin/src/main/java/github/plugin/unitestgen/util/AstUtils.java
new file mode 100644
index 0000000..527a9ff
--- /dev/null
+++ b/unitestgen-maven-plugin/src/main/java/github/plugin/unitestgen/util/AstUtils.java
@@ -0,0 +1,314 @@
+package github.plugin.unitestgen.util;
+
+import com.github.javaparser.StaticJavaParser;
+import com.github.javaparser.ast.CompilationUnit;
+import com.github.javaparser.ast.ImportDeclaration;
+import com.github.javaparser.ast.Modifier;
+import com.github.javaparser.ast.Node;
+import com.github.javaparser.ast.NodeList;
+import com.github.javaparser.ast.PackageDeclaration;
+import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration;
+import com.github.javaparser.ast.body.FieldDeclaration;
+import com.github.javaparser.ast.body.MethodDeclaration;
+import com.github.javaparser.ast.body.VariableDeclarator;
+import com.github.javaparser.ast.expr.AnnotationExpr;
+import com.github.javaparser.ast.expr.MemberValuePair;
+import com.github.javaparser.ast.expr.Name;
+import com.github.javaparser.ast.expr.NameExpr;
+import com.github.javaparser.ast.expr.NormalAnnotationExpr;
+import com.github.javaparser.ast.expr.VariableDeclarationExpr;
+import com.github.javaparser.ast.nodeTypes.NodeWithAnnotations;
+import com.github.javaparser.ast.nodeTypes.NodeWithName;
+import com.github.javaparser.ast.nodeTypes.NodeWithSimpleName;
+import com.github.javaparser.ast.nodeTypes.modifiers.NodeWithStaticModifier;
+import com.github.javaparser.ast.stmt.ExpressionStmt;
+import com.github.javaparser.ast.type.ArrayType;
+import com.github.javaparser.ast.type.ClassOrInterfaceType;
+import com.github.javaparser.ast.type.PrimitiveType;
+import com.github.javaparser.ast.type.Type;
+
+import java.io.File;
+import java.io.FileNotFoundException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.function.BiConsumer;
+import java.util.stream.Collectors;
+
+public class AstUtils {
+
+ public static CompilationUnit getUnit(File file) {
+ try {
+ return StaticJavaParser.parse(file);
+ } catch (FileNotFoundException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ public static CompilationUnit getUnit(ClassOrInterfaceDeclaration clazz) {
+ return clazz.findCompilationUnit()
+ .orElseThrow(() -> new RuntimeException("test class unit not exist."));
+ }
+
+ public static CompilationUnit getUnit(MethodDeclaration method) {
+ return method.findCompilationUnit()
+ .orElseThrow(() -> new RuntimeException("src method unit not exist."));
+ }
+
+ public static PackageDeclaration getPackageDeclaration(CompilationUnit unit) {
+ return unit.getPackageDeclaration()
+ .orElseThrow(() -> new RuntimeException("package not exists."));
+ }
+
+ public static String getPackageName(CompilationUnit unit) {
+ return getName(getPackageDeclaration(unit));
+ }
+
+ public static VariableDeclarator getVariableDeclarator(FieldDeclaration field) {
+ return field.findAll(VariableDeclarator.class).get(0);
+ }
+
+ public static String getClassName(CompilationUnit unit) {
+ return unit.getPrimaryTypeName()
+ .orElseThrow(() -> new RuntimeException("class name not exists."));
+ }
+
+ public static String getName(NodeWithName extends Node> type) {
+ return type.getName().asString();
+ }
+
+ public static String getName(NodeWithSimpleName extends Node> type) {
+ return type.getName().asString();
+ }
+
+ public static String getTypeName(Type type) {
+ return StringUtils.splitFirst(type.toString(), "<");
+ }
+
+ public static List getGenericTypeNames(CompilationUnit unit, Type type) {
+ List result = new ArrayList<>();
+ if (type.isClassOrInterfaceType()) {
+ ClassOrInterfaceType classOrInterfaceType = (ClassOrInterfaceType) type;
+ classOrInterfaceType.getTypeArguments().ifPresent(types -> {
+ for (Type t : types) {
+ findImportName(unit, getTypeName(t)).ifPresent(result::add);
+ result.addAll(getGenericTypeNames(unit, t));
+ }
+ });
+ }
+ return result;
+ }
+
+ public static List getImportNames(CompilationUnit unit, Type type) {
+ List result = new ArrayList<>();
+ // get type name
+ String typeName = getTypeName(type);
+
+ // add type itself
+ findImportName(unit, typeName).ifPresent(result::add);
+
+ // add type implement type
+ if ("List".equals(typeName)) {
+ // add implement type
+ result.add("java.util.ArrayList");
+ }
+ if ("Map".equals(typeName)) {
+ // add implement type
+ result.add("java.util.HashMap");
+ }
+
+ // get generic typeNames
+ List genericTypeNames = getGenericTypeNames(unit, type);
+
+ // add generic type
+ result.addAll(genericTypeNames);
+
+ return result;
+ }
+
+ public static ClassOrInterfaceType getClassOrInterfaceType(Type type) {
+ if (type.isPrimitiveType()) {
+ return ((PrimitiveType) type).toBoxedType();
+ } else {
+ return (ClassOrInterfaceType) type;
+ }
+ }
+
+ public static Optional findImportDeclaration(
+ CompilationUnit unit,
+ String classTypeName
+ ) {
+ return unit.getImports()
+ .stream()
+ .filter(importDeclaration -> {
+ String importPackageName = getName(importDeclaration);
+ String importClassTypeName = StringUtils.splitLast(importPackageName, "\\.");
+ return classTypeName.equals(importClassTypeName);
+ })
+ .findFirst();
+ }
+
+ public static Optional findImportName(
+ CompilationUnit unit,
+ String classTypeName
+ ) {
+ return findImportDeclaration(unit, classTypeName).map(NodeWithName::getNameAsString);
+ }
+
+ private static String getDefaultValue(Type type) {
+ if (type.isArrayType()) {
+ return "new " + ((ArrayType) type).getComponentType().asString() + "[]{}";
+ } else {
+ ClassOrInterfaceType classType = getClassOrInterfaceType(type);
+ return getDefaultValue(classType);
+ }
+ }
+
+ public static String getDefaultValue(ClassOrInterfaceType fieldType) {
+ String fieldTypeName = getName(fieldType);
+ boolean genericExists = fieldType.getTypeArguments().isPresent();
+ return FieldUtils.createDefaultValue(fieldTypeName, genericExists);
+ }
+
+ public static Type toUnboxedType(Type type) {
+ if (type.isClassOrInterfaceType()) {
+ ClassOrInterfaceType classType = (ClassOrInterfaceType) type;
+ if (classType.isBoxedType()) {
+ return classType.toUnboxedType();
+ }
+ }
+ return type;
+ }
+
+ public static Optional findClassOrInterfaceType(FieldDeclaration field) {
+ if (field.getElementType().isClassOrInterfaceType()) {
+ return Optional.of(field.getElementType().asClassOrInterfaceType());
+ }
+
+ if (field.getElementType().isPrimitiveType()) {
+ return Optional.of(field.getElementType().asPrimitiveType().toBoxedType());
+ }
+
+ return Optional.empty();
+ }
+
+ public static void addImport(CompilationUnit unit, String importName) {
+ unit.addImport(importName);
+ }
+
+ public static void addAnnotation(NodeWithAnnotations extends Node> node, AnnotationExpr annotation) {
+ node.addAnnotation(annotation);
+ }
+
+ public static NodeList createMemberValues(List memberValuePairs) {
+ NodeList result = new NodeList<>();
+ result.addAll(memberValuePairs);
+ return result;
+ }
+
+ public static NodeList createMemberValues(Map nameValuePairs) {
+ List memberValuePairs = nameValuePairs.entrySet()
+ .stream()
+ .map((entry) -> {
+ String name = entry.getKey();
+ String value = entry.getValue();
+ return new MemberValuePair(name, new NameExpr(value));
+ })
+ .collect(Collectors.toList());
+ return createMemberValues(memberValuePairs);
+ }
+
+ public static NodeList createMemberValues(String name, String value) {
+ Map nameValuePairs = new HashMap<>();
+ nameValuePairs.put(name, value);
+ return createMemberValues(nameValuePairs);
+ }
+
+ public static AnnotationExpr createAnnotationExpr(
+ ClassOrInterfaceDeclaration classDeclaration,
+ String annotationFullClassName,
+ NodeList memberValuePairs
+ ) {
+ CompilationUnit unit = AstUtils.getUnit(classDeclaration);
+ unit.addImport(annotationFullClassName);
+
+ String annotationName = StringUtils.splitLast(annotationFullClassName, "\\.");
+ return new NormalAnnotationExpr(new Name(annotationName), memberValuePairs);
+ }
+
+ public static AnnotationExpr createAnnotationExpr(
+ ClassOrInterfaceDeclaration clazz,
+ String annotationFullClassName
+ ) {
+ return createAnnotationExpr(clazz, annotationFullClassName, new NodeList<>());
+ }
+
+ public static MethodDeclaration createMethodDeclaration(
+ ClassOrInterfaceDeclaration classDeclaration,
+ String methodName
+ ) {
+ return classDeclaration.addMethod(methodName, Modifier.Keyword.PUBLIC);
+ }
+
+ public static VariableDeclarator createVariableDeclarator(
+ Type fieldType,
+ String fieldName,
+ String defaultValue
+ ) {
+ VariableDeclarator variableDeclarator = new VariableDeclarator();
+ variableDeclarator.setName(fieldName);
+ variableDeclarator.setType(fieldType);
+ variableDeclarator.setInitializer(defaultValue);
+ return variableDeclarator;
+ }
+
+ public static ExpressionStmt createExpressionStmt(VariableDeclarationExpr variableExpr) {
+ return new ExpressionStmt(variableExpr);
+ }
+
+ public static VariableDeclarationExpr createVariableDeclarationExpr(
+ Type type,
+ String fieldName
+ ) {
+ VariableDeclarationExpr variableDeclarationExpr = new VariableDeclarationExpr();
+ NodeList variableDeclarators = new NodeList<>();
+ String defaultValue = AstUtils.getDefaultValue(type);
+ Type unboxedType = toUnboxedType(type);
+ VariableDeclarator variableDeclarator = createVariableDeclarator(unboxedType, fieldName, defaultValue);
+ variableDeclarators.add(variableDeclarator);
+ variableDeclarationExpr.setVariables(variableDeclarators);
+ return variableDeclarationExpr;
+ }
+
+ public static void consumeField(
+ CompilationUnit unit,
+ BiConsumer consumer
+ ) {
+ unit.findAll(FieldDeclaration.class).stream()
+ .filter(declaration -> !declaration.isFinal())
+ .forEach(
+ declaration -> AstUtils.findClassOrInterfaceType(declaration)
+ .ifPresent(fieldType -> consumer.accept(declaration, fieldType))
+ );
+ }
+
+ @SuppressWarnings("BooleanMethodIsAlwaysInverted")
+ public static boolean checkFieldExists(ClassOrInterfaceDeclaration testClass, String fieldName) {
+ return testClass.getFieldByName(fieldName).isPresent();
+ }
+
+ public static boolean checkMethodExists(ClassOrInterfaceDeclaration testClass, String methodName) {
+ return !testClass.getMethodsByName(methodName).isEmpty();
+ }
+
+ public static boolean checkVoidReturn(MethodDeclaration method) {
+ return "void".equals(method.getType().toString());
+ }
+
+ public static boolean checkUtilClass(CompilationUnit srcUnit) {
+ return srcUnit.findAll(MethodDeclaration.class).stream().anyMatch(NodeWithStaticModifier::isStatic);
+ }
+
+}
diff --git a/unitestgen-maven-plugin/src/main/java/github/plugin/unitestgen/util/ExprUtils.java b/unitestgen-maven-plugin/src/main/java/github/plugin/unitestgen/util/ExprUtils.java
new file mode 100644
index 0000000..9e9d7fd
--- /dev/null
+++ b/unitestgen-maven-plugin/src/main/java/github/plugin/unitestgen/util/ExprUtils.java
@@ -0,0 +1,69 @@
+package github.plugin.unitestgen.util;
+
+public class ExprUtils {
+
+ public static String replaceEquals(String expression) {
+ return expression.replaceAll("==", "_equals_");
+ }
+
+ public static String replaceBracket(String expression) {
+ return expression.replaceAll("\\(", "_").replaceAll("\\)", "_");
+ }
+
+ public static String replaceNot(String expression) {
+ return expression.replaceAll("!", "_not_");
+ }
+
+ public static String replaceAnd(String expression) {
+ return expression.replaceAll("&&", "_and_");
+ }
+
+ public static String replaceOr(String expression) {
+ return expression.replaceAll("\\|\\|", "_or_");
+ }
+
+ public static String replaceComma(String expression) {
+ return expression.replaceAll(",", "_and_");
+ }
+
+ public static String replaceQuote(String expression) {
+ return expression.replaceAll("\"", "");
+ }
+
+ public static String replaceDot(String expression) {
+ return expression.replaceAll("\\.", "_");
+ }
+
+ public static String replaceSpace(String expression) {
+ return expression.replaceAll(" ", "");
+ }
+
+ public static String replaceDoubleUnderLine(String expression) {
+ return expression.replaceAll("__", "_");
+ }
+
+ public static String replaceLastUnderline(String expression) {
+ int index = expression.length() - 1;
+ if (expression.charAt(index) == '_') {
+ return expression.substring(0, expression.length() - 1);
+ }
+ return expression;
+ }
+
+ public static String expression(String expression) {
+ expression = replaceSpace(expression);
+ expression = replaceQuote(expression);
+ expression = replaceEquals(expression);
+ expression = replaceNot(expression);
+ expression = replaceAnd(expression);
+ expression = replaceOr(expression);
+ expression = replaceComma(expression);
+ expression = replaceBracket(expression);
+ expression = replaceDot(expression);
+ expression = NameUtils.toUnderscoreCase(expression);
+ expression = replaceDoubleUnderLine(expression);
+ expression = replaceLastUnderline(expression);
+ return expression;
+ }
+
+}
diff --git a/unitestgen-maven-plugin/src/main/java/github/plugin/unitestgen/util/FieldUtils.java b/unitestgen-maven-plugin/src/main/java/github/plugin/unitestgen/util/FieldUtils.java
new file mode 100644
index 0000000..40876e6
--- /dev/null
+++ b/unitestgen-maven-plugin/src/main/java/github/plugin/unitestgen/util/FieldUtils.java
@@ -0,0 +1,32 @@
+package github.plugin.unitestgen.util;
+
+public class FieldUtils {
+
+ public static String createDefaultValue(String typeName, boolean genericExists) {
+ switch (typeName) {
+ case "Integer":
+ return "0";
+ case "Long":
+ return "0L";
+ case "Double":
+ return "0.0";
+ case "Character":
+ return "' '";
+ case "Boolean":
+ return "false";
+ case "String":
+ return "\"\"";
+ case "List":
+ return "new ArrayList<>()";
+ case "Map":
+ return "new HashMap<>()";
+ default:
+ if (genericExists) {
+ return "new " + typeName + "<>()";
+ } else {
+ return "new " + typeName + "()";
+ }
+ }
+ }
+
+}
diff --git a/unitestgen-maven-plugin/src/main/java/github/plugin/unitestgen/util/FileUtils.java b/unitestgen-maven-plugin/src/main/java/github/plugin/unitestgen/util/FileUtils.java
new file mode 100644
index 0000000..375610a
--- /dev/null
+++ b/unitestgen-maven-plugin/src/main/java/github/plugin/unitestgen/util/FileUtils.java
@@ -0,0 +1,42 @@
+package github.plugin.unitestgen.util;
+
+import java.io.File;
+import java.io.FileOutputStream;
+import java.nio.charset.StandardCharsets;
+import java.util.Arrays;
+import java.util.Objects;
+import java.util.function.Consumer;
+
+public class FileUtils {
+
+ public static void createParentDir(File file) {
+ File parentFile = file.getParentFile();
+ if (parentFile.mkdirs()) {
+ System.out.println("create directory: " + parentFile);
+ }
+ }
+
+ public static void walk(File rootFile, String fileExt, Consumer consumer) {
+ Arrays.stream(Objects.requireNonNull(rootFile.listFiles()))
+ .filter(file -> file.getName().endsWith("." + fileExt) || file.isDirectory())
+ .forEach(file -> {
+ if (file.isDirectory()) {
+ walk(file, fileExt, consumer);
+ } else {
+ consumer.accept(file);
+ }
+ });
+ }
+
+ public static void output(File file, String content) {
+ // create parent dir
+ createParentDir(file);
+
+ // create output file
+ try (FileOutputStream fos = new FileOutputStream(file)) {
+ fos.write(content.getBytes(StandardCharsets.UTF_8));
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ }
+}
diff --git a/unitestgen-maven-plugin/src/main/java/github/plugin/unitestgen/util/NameUtils.java b/unitestgen-maven-plugin/src/main/java/github/plugin/unitestgen/util/NameUtils.java
new file mode 100644
index 0000000..4e670bc
--- /dev/null
+++ b/unitestgen-maven-plugin/src/main/java/github/plugin/unitestgen/util/NameUtils.java
@@ -0,0 +1,39 @@
+package github.plugin.unitestgen.util;
+
+public class NameUtils {
+ public static String toUnderscoreCase(String camelCase) {
+ StringBuilder underscoreCase = new StringBuilder();
+ for (int i = 0; i < camelCase.length(); i++) {
+ char cur = camelCase.charAt(i);
+ char pre;
+ if (i != 0) {
+ pre = camelCase.charAt(i - 1);
+ } else {
+ pre = 'A';
+ }
+
+ if (Character.isUpperCase(cur)) {
+ if (Character.isUpperCase(pre) || pre == '_') {
+ underscoreCase.append(Character.toLowerCase(cur));
+ } else {
+ underscoreCase.append("_").append(Character.toLowerCase(cur));
+ }
+
+ } else {
+ underscoreCase.append(cur);
+ }
+ }
+ return underscoreCase.toString();
+ }
+
+ public static void main(String[] args) {
+ String expression = toUnderscoreCase("MODE_OVERWRITE.equals(mode)");
+ System.out.println(expression);
+ }
+
+ public static String toCamelCase(String pascalCase) {
+ return Character.toLowerCase(pascalCase.charAt(0)) +
+ pascalCase.substring(1).trim();
+ }
+
+}
diff --git a/unitestgen-maven-plugin/src/main/java/github/plugin/unitestgen/util/StringUtils.java b/unitestgen-maven-plugin/src/main/java/github/plugin/unitestgen/util/StringUtils.java
new file mode 100644
index 0000000..929fc5a
--- /dev/null
+++ b/unitestgen-maven-plugin/src/main/java/github/plugin/unitestgen/util/StringUtils.java
@@ -0,0 +1,57 @@
+package github.plugin.unitestgen.util;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.Optional;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+
+public class StringUtils {
+
+ public static Optional extract(String str, String regex) {
+ Pattern pattern = Pattern.compile(regex);
+ Matcher matcher = pattern.matcher(str);
+ if (matcher.find()) {
+ String extract = matcher.group(1);
+ return Optional.of(extract);
+ } else {
+ return Optional.empty();
+ }
+ }
+
+ public static boolean isBlank(String str) {
+ return (str == null || str.trim().isEmpty());
+ }
+
+ public static List split(String str, String regex) {
+ String[] split = str.split(regex);
+ return Arrays.asList(split);
+ }
+
+ public static String splitLast(String str, String regex) {
+ List split = split(str, regex);
+ if (split.isEmpty()) {
+ return "";
+ }
+ return split.get(split.size() - 1);
+ }
+
+ public static String splitFirst(String str, String regex) {
+ List split = split(str, regex);
+ if (split.isEmpty()) {
+ return "";
+ }
+ return split.get(0);
+ }
+
+ public static boolean includeStartsWith(String includeStr, String str) {
+ List includes = StringUtils.split(includeStr, ",");
+ for (String classFullName : includes) {
+ if (str.startsWith(classFullName)) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+}
diff --git a/unitestgen-maven-plugin/src/test/java/github/plugin/unitestgen/tool/GeneToolTest.java b/unitestgen-maven-plugin/src/test/java/github/plugin/unitestgen/tool/GeneToolTest.java
new file mode 100644
index 0000000..3c8fb4c
--- /dev/null
+++ b/unitestgen-maven-plugin/src/test/java/github/plugin/unitestgen/tool/GeneToolTest.java
@@ -0,0 +1,28 @@
+//package github.plugin.unitestgen.tool;
+//
+//import org.apache.maven.plugin.logging.Log;
+//import org.apache.maven.plugin.logging.SystemStreamLog;
+//import org.junit.Test;
+//
+//import java.io.File;
+//
+//import static org.junit.Assert.*;
+//
+//public class GeneToolTest {
+//
+// @Test
+// public void generate() {
+// String mock = "";
+//// File srcFile = new File( "E:\\workspace\\unitestgen\\unitestgen-maven-plugin\\src\\main\\java\\github\\plugin\\unitestgen\\tool\\GeneTool.java");
+//// File testFile = new File("E:\\workspace\\unitestgen\\unitestgen-maven-plugin\\src\\main\\java\\github\\plugin\\unitestgen\\tool\\GeneToolAutoTest.java");
+//
+//// File srcFile = new File( "E:\\workspace\\unitestgen\\unitestgen-maven-plugin\\src\\main\\java\\github\\plugin\\unitestgen\\mojo\\GeneMojo.java");
+//// File testFile = new File("E:\\workspace\\unitestgen\\unitestgen-maven-plugin\\src\\main\\java\\github\\plugin\\unitestgen\\mojo\\GeneMojoAutoTest.java");
+//
+// File srcFile = new File( "E:\\workspace\\unitestgen\\unitestgen-sample\\src\\main\\java\\github\\plugin\\unitestgen\\repository\\ParseRepository.java");
+// File testFile = new File("E:\\workspace\\unitestgen\\unitestgen-sample\\src\\test\\java\\github\\plugin\\unitestgen\\repository\\ParseRepositoryAutoTest.java");
+//
+// Log log = new SystemStreamLog();
+// new GeneTool(mock, srcFile, testFile, false, log).generate();
+// }
+//}
diff --git a/unitestgen-sample/pom.xml b/unitestgen-sample/pom.xml
new file mode 100644
index 0000000..dc1c2b9
--- /dev/null
+++ b/unitestgen-sample/pom.xml
@@ -0,0 +1,40 @@
+
+
+ 4.0.0
+
+
+ org.springframework.boot
+ spring-boot-starter-parent
+ 2.7.10
+
+
+
+ github.plugin
+ unitestgen-sample
+ 1.0
+ jar
+
+
+ UTF-8
+
+
+
+
+ org.springframework.boot
+ spring-boot-starter-web
+
+
+ org.springframework.boot
+ spring-boot-starter-test
+ test
+
+
+ junit
+ junit
+ test
+
+
+
diff --git a/unitestgen-sample/src/main/java/github/plugin/unitestgen/App.java b/unitestgen-sample/src/main/java/github/plugin/unitestgen/App.java
new file mode 100644
index 0000000..b52e77b
--- /dev/null
+++ b/unitestgen-sample/src/main/java/github/plugin/unitestgen/App.java
@@ -0,0 +1,13 @@
+package github.plugin.unitestgen;
+
+/**
+ * Hello world!
+ *
+ */
+public class App
+{
+ public static void main( String[] args )
+ {
+ System.out.println( "Hello World!" );
+ }
+}
diff --git a/unitestgen-sample/src/main/java/github/plugin/unitestgen/model/ParseModel.java b/unitestgen-sample/src/main/java/github/plugin/unitestgen/model/ParseModel.java
new file mode 100644
index 0000000..e0d89ac
--- /dev/null
+++ b/unitestgen-sample/src/main/java/github/plugin/unitestgen/model/ParseModel.java
@@ -0,0 +1,74 @@
+package github.plugin.unitestgen.model;
+
+import java.util.Date;
+
+public class ParseModel {
+ private String username;
+
+ private Integer int1;
+ private Integer int2;
+
+ private Date date;
+
+ private Boolean bool;
+
+ private boolean bool2;
+
+ private ParseModel parse;
+
+ public String getUsername() {
+ return username;
+ }
+
+ public void setUsername(String username) {
+ this.username = username;
+ }
+
+ public Integer getInt1() {
+ return int1;
+ }
+
+ public void setInt1(Integer int1) {
+ this.int1 = int1;
+ }
+
+ public Integer getInt2() {
+ return int2;
+ }
+
+ public void setInt2(Integer int2) {
+ this.int2 = int2;
+ }
+
+ public Date getDate() {
+ return date;
+ }
+
+ public void setDate(Date date) {
+ this.date = date;
+ }
+
+ public Boolean getBool() {
+ return bool;
+ }
+
+ public void setBool(Boolean bool) {
+ this.bool = bool;
+ }
+
+ public boolean isBool2() {
+ return bool2;
+ }
+
+ public void setBool2(boolean bool2) {
+ this.bool2 = bool2;
+ }
+
+ public ParseModel getParse() {
+ return parse;
+ }
+
+ public void setParse(ParseModel parse) {
+ this.parse = parse;
+ }
+}
diff --git a/unitestgen-sample/src/main/java/github/plugin/unitestgen/model/ParseModel2.java b/unitestgen-sample/src/main/java/github/plugin/unitestgen/model/ParseModel2.java
new file mode 100644
index 0000000..eb9304a
--- /dev/null
+++ b/unitestgen-sample/src/main/java/github/plugin/unitestgen/model/ParseModel2.java
@@ -0,0 +1,74 @@
+package github.plugin.unitestgen.model;
+
+import java.util.Date;
+
+public class ParseModel2 {
+ private String username;
+
+ private Integer int1;
+ private Integer int2;
+
+ private Date date;
+
+ private Boolean bool;
+
+ private boolean bool2;
+
+ private ParseModel2 parse;
+
+ public String getUsername() {
+ return username;
+ }
+
+ public void setUsername(String username) {
+ this.username = username;
+ }
+
+ public Integer getInt1() {
+ return int1;
+ }
+
+ public void setInt1(Integer int1) {
+ this.int1 = int1;
+ }
+
+ public Integer getInt2() {
+ return int2;
+ }
+
+ public void setInt2(Integer int2) {
+ this.int2 = int2;
+ }
+
+ public Date getDate() {
+ return date;
+ }
+
+ public void setDate(Date date) {
+ this.date = date;
+ }
+
+ public Boolean getBool() {
+ return bool;
+ }
+
+ public void setBool(Boolean bool) {
+ this.bool = bool;
+ }
+
+ public boolean isBool2() {
+ return bool2;
+ }
+
+ public void setBool2(boolean bool2) {
+ this.bool2 = bool2;
+ }
+
+ public ParseModel2 getParse() {
+ return parse;
+ }
+
+ public void setParse(ParseModel2 parse) {
+ this.parse = parse;
+ }
+}
diff --git a/unitestgen-sample/src/main/java/github/plugin/unitestgen/repository/ParseRepository.java b/unitestgen-sample/src/main/java/github/plugin/unitestgen/repository/ParseRepository.java
new file mode 100644
index 0000000..bb18fdd
--- /dev/null
+++ b/unitestgen-sample/src/main/java/github/plugin/unitestgen/repository/ParseRepository.java
@@ -0,0 +1,42 @@
+package github.plugin.unitestgen.repository;
+
+import github.plugin.unitestgen.App;
+import github.plugin.unitestgen.model.ParseModel;
+import jdk.nashorn.internal.runtime.arrays.ArrayIndex;
+import jdk.nashorn.internal.runtime.linker.Bootstrap;
+
+import java.text.Format;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+
+
+public class ParseRepository {
+
+ public ParseModel selectById(String str) {
+ return null;
+ }
+
+ public ParseModel selectById(Integer str) {
+ return null;
+ }
+
+
+ public List selectById(Map> options) {
+ if (options.isEmpty()) {
+ return new ArrayList<>();
+ } else {
+ return Collections.singletonList(new ParseModel());
+ }
+ }
+
+ public List