Skip to content
Snippets Groups Projects
Commit b51fad11 authored by Felix Sidokhine's avatar Felix Sidokhine
Browse files

replaced abastract class extensions with simple annotation

Change-Id: I705fba61730eca602999d7482834d1dd1969c89e
parent 49794f37
No related branches found
No related tags found
No related merge requests found
Showing
with 138 additions and 97 deletions
...@@ -35,7 +35,6 @@ import net.jami.jams.common.objects.system.SystemAccountType; ...@@ -35,7 +35,6 @@ import net.jami.jams.common.objects.system.SystemAccountType;
import net.jami.jams.common.objects.user.User; import net.jami.jams.common.objects.user.User;
import net.jami.jams.common.utils.X509Utils; import net.jami.jams.common.utils.X509Utils;
import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
...@@ -43,8 +42,6 @@ import org.junit.jupiter.api.Test; ...@@ -43,8 +42,6 @@ import org.junit.jupiter.api.Test;
import java.io.File; import java.io.File;
import java.io.InputStream; import java.io.InputStream;
import java.math.BigInteger; import java.math.BigInteger;
import java.nio.file.Path;
import java.nio.file.Paths;
class SystemAccountBuilderTest { class SystemAccountBuilderTest {
......
package net.jami.jams.common.annotations;
import javassist.ClassPool;
import javassist.CtClass;
import javassist.CtMethod;
import javassist.bytecode.AnnotationsAttribute;
import javassist.bytecode.MethodInfo;
import javassist.bytecode.annotation.Annotation;
import javassist.bytecode.annotation.ArrayMemberValue;
import javassist.bytecode.annotation.EnumMemberValue;
import lombok.extern.slf4j.Slf4j;
import net.jami.jams.common.objects.user.AccessLevel;
import java.util.HashSet;
@Slf4j
public class ScopedServletAnnotationScanner {
public void scanAndModify(HashSet<String> classes){
classes.parallelStream().forEach(this::processClass);
}
public void processClass(String classname){
try {
CtClass cc = ClassPool.getDefault().get(classname);
cc.defrost();
CtMethod[] ctMethods = cc.getMethods();
boolean classChanged = false;
for (int i = 0; i < ctMethods.length; i++) {
MethodInfo minfo = ctMethods[i].getMethodInfo();
if(minfo.getAttribute(AnnotationsAttribute.visibleTag) == null){
continue;
}
AnnotationsAttribute attr = (AnnotationsAttribute) minfo.getAttribute(AnnotationsAttribute.visibleTag);
Annotation[] an = attr.getAnnotations();
for(int j=0; j<an.length; j++) {
if(an[j].getTypeName().equals(ScopedServletMethod.class.getName())) {
ArrayMemberValue memberValue = (ArrayMemberValue) an[j].getMemberValue("securityGroups");
HashSet<AccessLevel> accessLevels = new HashSet<>();
for(int k =0; k<memberValue.getValue().length; k++){
EnumMemberValue level = (EnumMemberValue) memberValue.getValue()[k];
AccessLevel accessLevel = AccessLevel.valueOf(level.getValue());
classChanged = true;
accessLevels.add(accessLevel);
}
log.info("Detected scoped servlet, modifying method accordingly [" + classname + "]");
//Build the code block that enforces security.
StringBuilder sb = new StringBuilder();
sb.append("{\n");
//So this does not play nice when trying to use hash sets...
sb.append("boolean allowed = false;\n");
sb.append("net.jami.jams.common.objects.user.AccessLevel level = (net.jami.jams.common.objects.user.AccessLevel) req.getAttribute(\"accessLevel\");\n");
accessLevels.forEach(e -> {
sb.append("if(level == net.jami.jams.common.objects.user.AccessLevel.valueOf(\"").append(e.toString()).append("\")) allowed = true;\n");
});
sb.append("if(!allowed){\n");
sb.append("resp.sendError(403,\"No valid access level found!\");\n");
sb.append("return;\n");
sb.append("}\n");
sb.append("}\n");
ctMethods[i].insertBefore(sb.toString());
}
}
}
if(classChanged){
try {
if (cc.isFrozen()) cc.defrost();
cc.writeFile(".");
cc.toClass();
log.info("Successfully modified class " + classname);
}
catch (Exception e){
log.error("Could not persist changes to class " + classname);
}
}
} catch (Exception e) {
//log.info("Could not modify a target class with error {}",e.getMessage());
}
}
}
package net.jami.jams.common.annotations;
import net.jami.jams.common.objects.user.AccessLevel;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface ScopedServletMethod {
public AccessLevel[] securityGroups();
}
package net.jami.jams.common.servlets;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServlet;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import lombok.Getter;
import lombok.Setter;
import net.jami.jams.common.objects.user.AccessLevel;
import java.io.IOException;
import java.util.HashSet;
/*
This is in order to be able to manipulate this more efficiently.
*/
public abstract class ScopedServlet extends HttpServlet {
public static HashSet<AccessLevel> GET_accessLevels = new HashSet<>();
public static HashSet<AccessLevel> PUT_accessLevels = new HashSet<>();
public static HashSet<AccessLevel> DELETE_accessLevels = new HashSet<>();
public static HashSet<AccessLevel> POST_accessLevels = new HashSet<>();
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
if(!(req.getAttribute("accessLevel") instanceof AccessLevel)){
resp.sendError(403,"No valid access level found!");
return;
}
if(!GET_accessLevels.contains((AccessLevel) req.getAttribute("accessLevel"))){
resp.sendError(403,"You do not have enough rights to access this endpoint!");
return;
}
}
@Override
protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
if(!(req.getAttribute("accessLevel") instanceof AccessLevel)){
resp.sendError(403,"No valid access level found!");
return;
}
if(!POST_accessLevels.contains((AccessLevel) req.getAttribute("accessLevel"))){
resp.sendError(403,"You do not have enough rights to access this endpoint!");
return;
}
}
@Override
protected void doPut(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
if(!(req.getAttribute("accessLevel") instanceof AccessLevel)){
resp.sendError(403,"No valid access level found!");
return;
}
if(!PUT_accessLevels.contains((AccessLevel) req.getAttribute("accessLevel"))){
resp.sendError(403,"You do not have enough rights to access this endpoint!");
return;
}
}
@Override
protected void doDelete(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
if(!(req.getAttribute("accessLevel") instanceof AccessLevel)){
resp.sendError(403,"No valid access level found!");
return;
}
if(!DELETE_accessLevels.contains((AccessLevel) req.getAttribute("accessLevel"))){
resp.sendError(403,"You do not have enough rights to access this endpoint!");
return;
}
}
}
...@@ -26,6 +26,7 @@ import com.jsoniter.JsonIterator; ...@@ -26,6 +26,7 @@ import com.jsoniter.JsonIterator;
import javassist.ClassPool; import javassist.ClassPool;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import net.jami.datastore.main.DataStore; import net.jami.datastore.main.DataStore;
import net.jami.jams.common.annotations.ScopedServletAnnotationScanner;
import net.jami.jams.common.authentication.AuthenticationSourceType; import net.jami.jams.common.authentication.AuthenticationSourceType;
import net.jami.jams.common.authentication.local.LocalAuthSettings; import net.jami.jams.common.authentication.local.LocalAuthSettings;
import net.jami.jams.common.authmodule.AuthenticationModule; import net.jami.jams.common.authmodule.AuthenticationModule;
...@@ -41,10 +42,12 @@ import net.jami.jams.server.core.TomcatLauncher; ...@@ -41,10 +42,12 @@ import net.jami.jams.server.core.TomcatLauncher;
import net.jami.jams.server.licensing.LicenseService; import net.jami.jams.server.licensing.LicenseService;
import net.jami.jams.server.startup.AuthModuleLoader; import net.jami.jams.server.startup.AuthModuleLoader;
import net.jami.jams.server.startup.CryptoEngineLoader; import net.jami.jams.server.startup.CryptoEngineLoader;
import net.jami.jams.server.startup.PackageScanner;
import java.io.File; import java.io.File;
import java.io.FileInputStream; import java.io.FileInputStream;
import java.io.InputStream; import java.io.InputStream;
import java.util.HashSet;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
@Slf4j @Slf4j
...@@ -75,6 +78,14 @@ public class Server { ...@@ -75,6 +78,14 @@ public class Server {
public static void main(String[] args) { public static void main(String[] args) {
//This is a fix to drop old cached stuff from the tomcat classloader. //This is a fix to drop old cached stuff from the tomcat classloader.
ClassPool.getDefault().clearImportedPackages(); ClassPool.getDefault().clearImportedPackages();
ScopedServletAnnotationScanner scanner = new ScopedServletAnnotationScanner();
//Here we need to scan and modify our servlets, this is not necessarily the greatest thing ever.
try {
scanner.scanAndModify(PackageScanner.getClasses());
}
catch (Exception e){
log.error("Could not modify the annotated files");
}
switch (args.length) { switch (args.length) {
case 1: case 1:
tomcatLauncher = new TomcatLauncher(Integer.parseInt(args[0])); tomcatLauncher = new TomcatLauncher(Integer.parseInt(args[0]));
......
...@@ -80,6 +80,7 @@ public class TomcatLauncher { ...@@ -80,6 +80,7 @@ public class TomcatLauncher {
log.info("Serving application from: " + new File(System.getProperty("user.dir")).getAbsolutePath()); log.info("Serving application from: " + new File(System.getProperty("user.dir")).getAbsolutePath());
WebResourceRoot resources = new StandardRoot(context); WebResourceRoot resources = new StandardRoot(context);
//We probably need to scan the annotations here, because they should be loaded AFTER the classes are loaded.
if (jarName.contains(".jar")) { if (jarName.contains(".jar")) {
resources.addPreResources(new JarResourceSet(resources, "/WEB-INF/classes", jarName, "/net/jami/jams/server/servlets")); resources.addPreResources(new JarResourceSet(resources, "/WEB-INF/classes", jarName, "/net/jami/jams/server/servlets"));
resources.addPreResources(new JarResourceSet(resources, "/", jarName, "/webapp")); resources.addPreResources(new JarResourceSet(resources, "/", jarName, "/webapp"));
......
...@@ -23,7 +23,6 @@ ...@@ -23,7 +23,6 @@
package net.jami.jams.server.core.workflows; package net.jami.jams.server.core.workflows;
import com.jsoniter.output.JsonStream; import com.jsoniter.output.JsonStream;
import javassist.ClassPool;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import net.jami.jams.common.authentication.AuthenticationSourceType; import net.jami.jams.common.authentication.AuthenticationSourceType;
import net.jami.jams.common.objects.roots.X509Fields; import net.jami.jams.common.objects.roots.X509Fields;
......
...@@ -29,6 +29,7 @@ import jakarta.servlet.annotation.WebServlet; ...@@ -29,6 +29,7 @@ import jakarta.servlet.annotation.WebServlet;
import jakarta.servlet.http.HttpServlet; import jakarta.servlet.http.HttpServlet;
import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse; import jakarta.servlet.http.HttpServletResponse;
import net.jami.jams.common.annotations.ScopedServletMethod;
import net.jami.jams.common.dao.StatementElement; import net.jami.jams.common.dao.StatementElement;
import net.jami.jams.common.dao.StatementList; import net.jami.jams.common.dao.StatementList;
import net.jami.jams.common.objects.devices.Device; import net.jami.jams.common.objects.devices.Device;
...@@ -36,10 +37,8 @@ import net.jami.jams.common.objects.requests.DeviceRegistrationRequest; ...@@ -36,10 +37,8 @@ import net.jami.jams.common.objects.requests.DeviceRegistrationRequest;
import net.jami.jams.common.objects.responses.DeviceRegistrationResponse; import net.jami.jams.common.objects.responses.DeviceRegistrationResponse;
import net.jami.jams.common.objects.responses.DeviceRevocationResponse; import net.jami.jams.common.objects.responses.DeviceRevocationResponse;
import net.jami.jams.common.objects.user.AccessLevel; import net.jami.jams.common.objects.user.AccessLevel;
import net.jami.jams.common.servlets.ScopedServlet;
import net.jami.jams.server.core.workflows.RegisterDeviceFlow; import net.jami.jams.server.core.workflows.RegisterDeviceFlow;
import net.jami.jams.server.core.workflows.RevokeDeviceFlow; import net.jami.jams.server.core.workflows.RevokeDeviceFlow;
import net.jami.jams.server.servlets.api.auth.contacts.ContactServlet;
import java.io.IOException; import java.io.IOException;
...@@ -47,13 +46,8 @@ import static net.jami.jams.server.Server.certificateAuthority; ...@@ -47,13 +46,8 @@ import static net.jami.jams.server.Server.certificateAuthority;
import static net.jami.jams.server.Server.dataStore; import static net.jami.jams.server.Server.dataStore;
@WebServlet("/api/auth/device/*") @WebServlet("/api/auth/device/*")
public class DeviceServlet extends ScopedServlet { public class DeviceServlet extends HttpServlet {
@Override
public void init() throws ServletException {
DeviceServlet.DELETE_accessLevels.add(AccessLevel.USER);
DeviceServlet.POST_accessLevels.add(AccessLevel.USER);
}
/** /**
* @apiVersion 1.0.0 * @apiVersion 1.0.0
* @api {get} /api/auth/device Get device info * @api {get} /api/auth/device Get device info
...@@ -112,8 +106,9 @@ public class DeviceServlet extends ScopedServlet { ...@@ -112,8 +106,9 @@ public class DeviceServlet extends ScopedServlet {
* } * }
*/ */
@Override @Override
@ScopedServletMethod(securityGroups = {AccessLevel.USER})
protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
super.doPost(req,resp); //super.doPost(req,resp);
DeviceRegistrationRequest request = JsonIterator.deserialize(req.getInputStream().readAllBytes(), DeviceRegistrationRequest.class); DeviceRegistrationRequest request = JsonIterator.deserialize(req.getInputStream().readAllBytes(), DeviceRegistrationRequest.class);
DeviceRegistrationResponse devResponse = RegisterDeviceFlow.registerDevice(req.getAttribute("username").toString(),request); DeviceRegistrationResponse devResponse = RegisterDeviceFlow.registerDevice(req.getAttribute("username").toString(),request);
if(devResponse != null) resp.getOutputStream().write(JsonStream.serialize(devResponse).getBytes()); if(devResponse != null) resp.getOutputStream().write(JsonStream.serialize(devResponse).getBytes());
...@@ -161,6 +156,7 @@ public class DeviceServlet extends ScopedServlet { ...@@ -161,6 +156,7 @@ public class DeviceServlet extends ScopedServlet {
* @apiError (500) {null} null device could not be deactivated * @apiError (500) {null} null device could not be deactivated
*/ */
@Override @Override
@ScopedServletMethod(securityGroups = {AccessLevel.USER})
protected void doDelete(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { protected void doDelete(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
super.doDelete(req,resp); super.doDelete(req,resp);
String deviceId = req.getPathInfo().replace("/",""); String deviceId = req.getPathInfo().replace("/","");
......
...@@ -31,8 +31,6 @@ import jakarta.servlet.http.HttpServletResponse; ...@@ -31,8 +31,6 @@ import jakarta.servlet.http.HttpServletResponse;
import net.jami.jams.common.dao.StatementElement; import net.jami.jams.common.dao.StatementElement;
import net.jami.jams.common.dao.StatementList; import net.jami.jams.common.dao.StatementList;
import net.jami.jams.common.objects.devices.Device; import net.jami.jams.common.objects.devices.Device;
import net.jami.jams.common.objects.user.AccessLevel;
import net.jami.jams.common.servlets.ScopedServlet;
import java.io.IOException; import java.io.IOException;
import java.util.List; import java.util.List;
......
...@@ -25,15 +25,15 @@ package net.jami.jams.server.servlets.api.auth.user; ...@@ -25,15 +25,15 @@ package net.jami.jams.server.servlets.api.auth.user;
import com.jsoniter.output.JsonStream; import com.jsoniter.output.JsonStream;
import jakarta.servlet.ServletException; import jakarta.servlet.ServletException;
import jakarta.servlet.annotation.WebServlet; import jakarta.servlet.annotation.WebServlet;
import jakarta.servlet.http.HttpServlet;
import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse; import jakarta.servlet.http.HttpServletResponse;
import net.jami.jams.common.annotations.ScopedServletMethod;
import net.jami.jams.common.authentication.AuthenticationSourceType; import net.jami.jams.common.authentication.AuthenticationSourceType;
import net.jami.jams.common.dao.StatementElement; import net.jami.jams.common.dao.StatementElement;
import net.jami.jams.common.dao.StatementList; import net.jami.jams.common.dao.StatementList;
import net.jami.jams.common.objects.user.AccessLevel; import net.jami.jams.common.objects.user.AccessLevel;
import net.jami.jams.common.objects.user.User; import net.jami.jams.common.objects.user.User;
import net.jami.jams.common.servlets.ScopedServlet;
import net.jami.jams.server.servlets.api.auth.device.DeviceServlet;
import java.io.IOException; import java.io.IOException;
...@@ -41,12 +41,8 @@ import static net.jami.jams.server.Server.certificateAuthority; ...@@ -41,12 +41,8 @@ import static net.jami.jams.server.Server.certificateAuthority;
import static net.jami.jams.server.Server.dataStore; import static net.jami.jams.server.Server.dataStore;
@WebServlet("/api/auth/user") @WebServlet("/api/auth/user")
public class UserServlet extends ScopedServlet { public class UserServlet extends HttpServlet {
@Override
public void init() throws ServletException {
UserServlet.PUT_accessLevels.add(AccessLevel.USER);
}
//User can "read" his own profile. //User can "read" his own profile.
/** /**
* @apiVersion 1.0.0 * @apiVersion 1.0.0
...@@ -74,7 +70,6 @@ public class UserServlet extends ScopedServlet { ...@@ -74,7 +70,6 @@ public class UserServlet extends ScopedServlet {
*/ */
@Override @Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
super.doGet(req,resp);
String username = req.getAttribute("username").toString(); String username = req.getAttribute("username").toString();
StatementList select = new StatementList(); StatementList select = new StatementList();
StatementElement st = new StatementElement("username","=",username,""); StatementElement st = new StatementElement("username","=",username,"");
...@@ -97,8 +92,8 @@ public class UserServlet extends ScopedServlet { ...@@ -97,8 +92,8 @@ public class UserServlet extends ScopedServlet {
* @apiError (500) {null} null could not changed password * @apiError (500) {null} null could not changed password
*/ */
@Override @Override
@ScopedServletMethod(securityGroups = {AccessLevel.USER})
protected void doPut(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { protected void doPut(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
super.doPut(req,resp);
String username = req.getAttribute("username").toString(); String username = req.getAttribute("username").toString();
//Check if he is AD/LDAP - then return a 401, because we can't set such password. //Check if he is AD/LDAP - then return a 401, because we can't set such password.
StatementList select = new StatementList(); StatementList select = new StatementList();
......
...@@ -32,7 +32,6 @@ import jakarta.servlet.ServletException; ...@@ -32,7 +32,6 @@ import jakarta.servlet.ServletException;
import jakarta.servlet.ServletRequest; import jakarta.servlet.ServletRequest;
import jakarta.servlet.ServletResponse; import jakarta.servlet.ServletResponse;
import jakarta.servlet.annotation.WebFilter; import jakarta.servlet.annotation.WebFilter;
import jakarta.servlet.http.HttpServlet;
import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse; import jakarta.servlet.http.HttpServletResponse;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
...@@ -40,7 +39,6 @@ import net.jami.jams.common.objects.user.AccessLevel; ...@@ -40,7 +39,6 @@ import net.jami.jams.common.objects.user.AccessLevel;
import net.jami.jams.server.Server; import net.jami.jams.server.Server;
import java.io.IOException; import java.io.IOException;
import java.lang.reflect.Method;
import static net.jami.jams.server.Server.userAuthenticationModule; import static net.jami.jams.server.Server.userAuthenticationModule;
import static net.jami.jams.server.servlets.filters.JWTValidator.verifyValidity; import static net.jami.jams.server.servlets.filters.JWTValidator.verifyValidity;
......
package net.jami.jams.server.startup;
import java.io.File;
import java.io.FileInputStream;
import java.util.HashSet;
import java.util.zip.ZipEntry;
import java.util.zip.ZipInputStream;
public class PackageScanner {
public static HashSet<String> getClasses() throws Exception {
HashSet<String> classNames = new HashSet<>();
ZipInputStream zip = new ZipInputStream(new FileInputStream(System.getProperty("user.dir") + File.separator + "jams-server.jar"));
for (ZipEntry entry = zip.getNextEntry(); entry != null; entry = zip.getNextEntry()) {
if (!entry.isDirectory() && entry.getName().endsWith(".class")) {
// This ZipEntry represents a class. Now, what class does it represent?
String className = entry.getName().replace('/', '.'); // including ".class"
classNames.add(className.substring(0, className.length() - ".class".length()));
}
}
return classNames;
}
}
File added
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment