package com.codeborne.selenide.impl;
import com.codeborne.selenide.SelenideElement;
import com.codeborne.selenide.commands.Commands;
import com.codeborne.selenide.ex.InvalidStateException;
import com.codeborne.selenide.ex.UIAssertionError;
import com.codeborne.selenide.logevents.SelenideLog;
import com.codeborne.selenide.logevents.SelenideLogger;
import org.openqa.selenium.InvalidElementStateException;
import org.openqa.selenium.WebDriverException;
import java.io.FileNotFoundException;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.HashSet;
import java.util.Set;
import static com.codeborne.selenide.Condition.exist;
import static com.codeborne.selenide.Configuration.AssertionMode.SOFT;
import static com.codeborne.selenide.Configuration.*;
import static com.codeborne.selenide.Selenide.sleep;
import static com.codeborne.selenide.logevents.ErrorsCollector.validateAssertionMode;
import static com.codeborne.selenide.logevents.LogEvent.EventStatus.PASS;
import static java.lang.System.currentTimeMillis;
import static java.util.Arrays.asList;
class SelenideElementProxy implements InvocationHandler {
private static final Set<String> methodsToSkipLogging = new HashSet<>(asList(
"toWebElement",
"toString"
));
private static final Set<String> methodsForSoftAssertion = new HashSet<>(asList(
"should",
"shouldBe",
"shouldHave",
"shouldNot",
"shouldNotHave",
"shouldNotBe",
"waitUntil",
"waitWhile"
));
private final WebElementSource webElementSource;
protected SelenideElementProxy(WebElementSource webElementSource) {
this.webElementSource = webElementSource;
}
@Override
public Object invoke(Object proxy, Method method, Object... args) throws Throwable {
if (methodsToSkipLogging.contains(method.getName()))
return Commands.getInstance().execute(proxy, webElementSource, method.getName(), args);
validateAssertionMode();
long timeoutMs = getTimeoutMs(method, args);
long pollingIntervalMs = getPollingIntervalMs(method, args);
SelenideLog log = SelenideLogger.beginStep(webElementSource.getSearchCriteria(), method.getName(), args);
try {
Object result = dispatchAndRetry(timeoutMs, pollingIntervalMs, proxy, method, args);
SelenideLogger.commitStep(log, PASS);
return result;
}
catch (Error error) {
SelenideLogger.commitStep(log, error);
if (assertionMode == SOFT && methodsForSoftAssertion.contains(method.getName()))
return proxy;
else
throw UIAssertionError.wrap(error, timeoutMs);
}
catch (RuntimeException error) {
SelenideLogger.commitStep(log, error);
throw error;
}
}
protected Object dispatchAndRetry(long timeoutMs, long pollingIntervalMs,
Object proxy, Method method, Object[] args) throws Throwable, Error {
final long startTime = currentTimeMillis();
Throwable lastError;
do {
try {
if (SelenideElement.class.isAssignableFrom(method.getDeclaringClass())) {
return Commands.getInstance().execute(proxy, webElementSource, method.getName(), args);
}
return method.invoke(webElementSource.getWebElement(), args);
}
catch (InvocationTargetException e) {
lastError = e.getTargetException();
}
catch (Throwable e) {
lastError = e;
}
if (Cleanup.of.isInvalidSelectorError(lastError)) {
throw Cleanup.of.wrap(lastError);
}
else if (!shouldRetryAfterError(lastError)) {
throw lastError;
}
sleep(pollingIntervalMs);
}
while (currentTimeMillis() - startTime <= timeoutMs);
if (lastError instanceof UIAssertionError) {
throw lastError;
}
else if (lastError instanceof InvalidElementStateException) {
throw new InvalidStateException(lastError);
}
else if (lastError instanceof WebDriverException) {
throw webElementSource.createElementNotFoundError(exist, lastError);
}
throw lastError;
}
static boolean shouldRetryAfterError(Throwable e) {
if (e instanceof FileNotFoundException) return false;
if (e instanceof IllegalArgumentException) return false;
if (e instanceof ReflectiveOperationException) return false;
return e instanceof Exception || e instanceof AssertionError;
}
private long getTimeoutMs(Method method, Object[] args) {
return isWaitCommand(method) ?
args.length == 3 ? (Long) args[args.length - 2] : (Long) args[args.length - 1] :
timeout;
}
private long getPollingIntervalMs(Method method, Object[] args) {
return isWaitCommand(method) && args.length == 3 ? (Long) args[args.length - 1] : pollingInterval;
}
private boolean isWaitCommand(Method method) {
return "waitUntil".equals(method.getName()) || "waitWhile".equals(method.getName());
}
}