JUnitParameterizedRunner.java
import org.junit.Test;
import org.junit.runner.Description;
import org.junit.runner.notification.Failure;
import org.junit.runner.notification.RunListener;
import org.junit.runner.notification.RunNotifier;
import org.junit.runners.BlockJUnit4ClassRunner;
import org.junit.runners.model.FrameworkMethod;
import org.junit.runners.model.InitializationError;
import org.junit.runners.model.Statement;
import org.junit.runners.model.TestClass;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.*;
/**
 * A JUnit 4 runner that supports parameterized test methods.
 *
 * <p>A {@code @Test} method that declares parameters must also carry {@code @Parameters}. Its
 * {@code name} selects a data provider: a {@code public static} method annotated with
 * {@code @ParametersData} of the same name that returns {@code Object[][]}. Each row of that
 * array becomes one execution of the test method, and every row must hold exactly as many
 * values as the method has parameters. Providers are looked up on {@code @Parameters.dataClass};
 * when that is {@code Object.class} (presumably the annotation's default — confirm), the test
 * class itself is searched.
 *
 * <p>Example:
 * <pre>
 * &#64;RunWith(JUnitParameterizedRunner.class)
 * public class AddTest {
 *     &#64;ParametersData(name = "testData")
 *     public static Object[][] data() {
 *         return new Object[][]{
 *             {1, 2, 3},
 *             {4, 5, 9}
 *         };
 *     }
 *
 *     &#64;Test
 *     &#64;Parameters(name = "testData") // or &#64;Parameters(name = "testData", dataClass = AddTest.class)
 *     public void test(int a, int b, int expected) {
 *         Assert.assertEquals(expected, add(a, b));
 *     }
 *
 *     public int add(int a, int b) {
 *         return a + b;
 *     }
 * }
 * </pre>
 */
public class JUnitParameterizedRunner extends BlockJUnit4ClassRunner {
    private static final Logger logger = LoggerFactory.getLogger(JUnitParameterizedRunner.class);

    /** Provider results keyed by {@link #getParamsKey(Class, String)}; filled lazily on first use. */
    private final Map<String, Object[][]> parametersMap = new HashMap<>();

    /** Per test method, how many data rows have already been handed out by {@link #getParameters}. */
    private final Map<Method, Integer> getParamsIndexMap = new HashMap<>();

    public JUnitParameterizedRunner(Class<?> klass) throws InitializationError {
        super(klass);
    }

    /**
     * Builds the statement that invokes {@code method} on {@code test}. Methods with parameters
     * receive the next unused data row as their arguments; no-arg methods receive none.
     */
    @Override
    protected Statement methodInvoker(FrameworkMethod method, Object test) {
        Object[] args = method.getMethod().getParameterCount() > 0 ? getParameters(method) : null;
        return new ParameterInvokeMethod(method, test, args);
    }

    /**
     * Validates {@code @Test} methods as public, non-static and void-returning, but, unlike
     * JUnit's default check, allows them to declare parameters.
     */
    @Override
    protected void validateTestMethods(List<Throwable> errors) {
        validatePublicVoidMethods(Test.class, false, errors);
    }

    /**
     * Returns the test methods to run: each parameterized method is repeated once per data row,
     * and methods without parameters appear once.
     *
     * <p>NOTE(review): repetitions are the same {@link FrameworkMethod} instance, so JUnit likely
     * reports them under a single {@link Description}; {@link #getParameters} distinguishes the
     * runs only by counting invocations.
     *
     * @throws NotParametersConfigException if a method with parameters lacks {@code @Parameters}
     * @throws InvalidParametersConfigException if that method's data is missing, empty or mis-sized
     * @throws RunParametersProviderException if a data provider fails
     */
    @Override
    protected List<FrameworkMethod> getChildren() {
        List<FrameworkMethod> children = new ArrayList<>();
        for (FrameworkMethod method : super.getChildren()) {
            if (method.getMethod().getParameterCount() > 0) {
                children.addAll(Collections.nCopies(getParametersList(method).length, method));
            } else {
                children.add(method);
            }
        }
        return children;
    }

    /**
     * Runs one child while logging its start, finish, failure and ignored events.
     */
    @Override
    protected void runChild(FrameworkMethod method, RunNotifier notifier) {
        String methodDesc = describe(method.getMethod());
        RunListener runListener = new RunListener() {
            @Override
            public void testFailure(Failure failure) {
                logger.info("**** Case Failure: {}\n", methodDesc);
            }

            @Override
            public void testIgnored(Description description) {
                logger.info("---- Case Ignored: {}", methodDesc);
            }
        };
        notifier.addFirstListener(runListener);
        try {
            logger.info(">>>> Case Started {}", methodDesc);
            super.runChild(method, notifier);
            logger.info("<<<< Case Finished: {}\n", methodDesc);
        } finally {
            // Always detach, otherwise the listener would keep logging later tests' events.
            notifier.removeListener(runListener);
        }
    }

    /** Formats {@code m} as {@code declaringClass#name(paramType, ...)} for log messages. */
    private static String describe(Method m) {
        StringJoiner types = new StringJoiner(", ");
        for (Class<?> type : m.getParameterTypes()) {
            types.add(type.getTypeName());
        }
        return String.format("%s#%s(%s)", m.getDeclaringClass().getTypeName(), m.getName(), types);
    }

    /**
     * Adds a validation error for every method annotated with {@code annotation} that is not
     * public, void-returning, and static (or non-static) as requested by {@code isStatic}.
     */
    protected void validatePublicVoidMethods(Class<? extends Annotation> annotation,
            boolean isStatic, List<Throwable> errors) {
        List<FrameworkMethod> methods = getTestClass().getAnnotatedMethods(annotation);
        for (FrameworkMethod eachTestMethod : methods) {
            eachTestMethod.validatePublicVoid(isStatic, errors);
        }
    }

    /**
     * Invokes every {@code public static} {@code @ParametersData} method of {@code clazz} (or of
     * the test class when {@code clazz} is {@code Object.class}) and caches each result under the
     * key built from {@code clazz} and the provider's {@code name}.
     *
     * @throws RunParametersProviderException if a provider throws or does not return Object[][]
     */
    private void loadParameters(Class<?> clazz) {
        TestClass testClass = clazz != Object.class ? new TestClass(clazz) : getTestClass();
        for (FrameworkMethod method : testClass.getAnnotatedMethods(ParametersData.class)) {
            int modifiers = method.getMethod().getModifiers();
            // Providers are invoked without an instance, so only public static ones are used;
            // any other annotated method is skipped.
            if (!Modifier.isStatic(modifiers) || !Modifier.isPublic(modifiers)) {
                continue;
            }
            ParametersData provider = method.getAnnotation(ParametersData.class);
            try {
                String key = getParamsKey(clazz, provider.name());
                parametersMap.put(key, (Object[][]) method.invokeExplosively(null));
            } catch (Throwable throwable) {
                // Keep the original throwable as the cause so its stack trace is not lost.
                throw new RunParametersProviderException(String.format(
                        "run %s failed, method:%s,exception:%s:%s",
                        provider.name(), method.getName(), throwable, throwable.getMessage()),
                        throwable);
            }
        }
    }

    /** Builds the cache key identifying the provider {@code name} declared on {@code clazz}. */
    private static String getParamsKey(Class<?> clazz, String name) {
        return String.format("@Parameters(%s, %s)", clazz.getName(), name);
    }

    /**
     * Returns all data rows configured for {@code method}, loading its provider on first use.
     *
     * @throws NotParametersConfigException if {@code method} lacks {@code @Parameters}
     * @throws InvalidParametersConfigException if the data is missing or empty, or if any row is
     *     null or does not match the method's parameter count
     */
    private Object[][] getParametersList(FrameworkMethod method) {
        Parameters parameters = method.getAnnotation(Parameters.class);
        if (parameters == null) {
            throw new NotParametersConfigException(String.format(
                    "%s#%s method need to configure the @Parameters annotation.",
                    method.getMethod().getDeclaringClass().getTypeName(), method.getName()));
        }
        String key = getParamsKey(parameters.dataClass(), parameters.name());
        Object[][] data = parametersMap.get(key);
        if (data == null) {
            loadParameters(parameters.dataClass());
            data = parametersMap.get(key);
        }
        if (!isWellFormed(data, method.getMethod().getParameterCount())) {
            throw new InvalidParametersConfigException(String.format("Wrong %s about %s#%s method.",
                    key, method.getMethod().getDeclaringClass().getTypeName(), method.getName()));
        }
        return data;
    }

    /** Returns true when {@code data} has at least one row and every row has exactly {@code arity} values. */
    private static boolean isWellFormed(Object[][] data, int arity) {
        if (data == null || data.length == 0) {
            return false;
        }
        for (Object[] row : data) {
            if (row == null || row.length != arity) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns the next unused data row for {@code method}, consuming rows in order (one per call),
     * or {@code null} once every row has been handed out.
     */
    private Object[] getParameters(FrameworkMethod method) {
        Object[][] rows = getParametersList(method);
        Method key = method.getMethod();
        int index = getParamsIndexMap.getOrDefault(key, 0);
        if (index >= rows.length) {
            return null;
        }
        getParamsIndexMap.put(key, index + 1);
        return rows[index];
    }

    /** Statement that reflectively invokes a test method with a fixed argument array. */
    private static final class ParameterInvokeMethod extends Statement {
        private final FrameworkMethod fTestMethod;
        private final Object fTarget;
        private final Object[] args;

        ParameterInvokeMethod(FrameworkMethod testMethod, Object target, Object[] params) {
            fTestMethod = testMethod;
            fTarget = target;
            args = params;
        }

        @Override
        public void evaluate() throws Throwable {
            fTestMethod.invokeExplosively(fTarget, args);
        }
    }

    /** Thrown when a {@code @ParametersData} provider cannot be run or returns an unusable value. */
    public static class RunParametersProviderException extends RuntimeException {
        public RunParametersProviderException(String msg) {
            super(msg);
        }

        public RunParametersProviderException(String msg, Throwable cause) {
            super(msg, cause);
        }
    }

    /** Thrown when a test method declares parameters but has no {@code @Parameters} annotation. */
    public static class NotParametersConfigException extends RuntimeException {
        public NotParametersConfigException(String msg) {
            super(msg);
        }
    }

    /** Thrown when the data configured for a parameterized test method is missing or mis-sized. */
    public static class InvalidParametersConfigException extends RuntimeException {
        public InvalidParametersConfigException(String msg) {
            super(msg);
        }
    }
}