|  1 |     | 
     | 
  |  2 |     | 
     | 
  |  3 |     | 
     | 
  |  4 |     | 
     | 
  |  5 |     | 
     | 
  |  6 |     | 
     | 
  |  7 |     | 
     | 
  |  8 |     | 
     | 
  |  9 |     | 
     | 
  |  10 |     | 
     | 
  |  11 |     | 
     | 
  |  12 |     | 
     | 
  |  13 |     | 
     | 
  |  14 |     | 
     | 
  |  15 |     | 
     | 
  |  16 |     | 
   package at.ipsquare.commons.servlet;  | 
  |  17 |     | 
     | 
  |  18 |     | 
   import java.io.IOException;  | 
  |  19 |     | 
   import java.util.Enumeration;  | 
  |  20 |     | 
   import java.util.IdentityHashMap;  | 
  |  21 |     | 
   import java.util.Iterator;  | 
  |  22 |     | 
   import java.util.LinkedHashMap;  | 
  |  23 |     | 
   import java.util.Map;  | 
  |  24 |     | 
     | 
  |  25 |     | 
   import javax.servlet.Filter;  | 
  |  26 |     | 
   import javax.servlet.FilterChain;  | 
  |  27 |     | 
   import javax.servlet.FilterConfig;  | 
  |  28 |     | 
   import javax.servlet.ServletException;  | 
  |  29 |     | 
   import javax.servlet.ServletRequest;  | 
  |  30 |     | 
   import javax.servlet.ServletResponse;  | 
  |  31 |     | 
     | 
  |  32 |     | 
   import at.ipsquare.commons.core.interfaces.AbstractUnitOfWork;  | 
  |  33 |     | 
   import at.ipsquare.commons.core.interfaces.UnitOfWork;  | 
  |  34 |     | 
   import at.ipsquare.commons.core.util.Classes;  | 
  |  35 |     | 
   import at.ipsquare.commons.hibernate.HibernateRepository;  | 
  |  36 |     | 
   import at.ipsquare.commons.hibernate.HibernateRepositoryProvider;  | 
  |  37 |     | 
     | 
  |  38 |     | 
     | 
  |  39 |     | 
     | 
  |  40 |     | 
     | 
  |  41 |     | 
     | 
  |  42 |     | 
     | 
  |  43 |     | 
     | 
  |  44 |     | 
     | 
  |  45 |     | 
     | 
  |  46 |     | 
     | 
  |  47 |     | 
     | 
  |  48 |    52 |    public final class HibernateUnitOfWorkFilter implements Filter  | 
  |  49 |     | 
   { | 
  |  50 |     | 
       private Map<String, HibernateRepository> repoMap;  | 
  |  51 |     | 
       private RequestMatcher requestMatcher;  | 
  |  52 |     | 
         | 
  |  53 |     | 
       @Override  | 
  |  54 |     | 
       public void destroy()  | 
  |  55 |     | 
       { | 
  |  56 |     | 
             | 
  |  57 |    4 |        }  | 
  |  58 |     | 
         | 
  |  59 |     | 
       @Override  | 
  |  60 |     | 
       public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain) throws IOException, ServletException  | 
  |  61 |     | 
       { | 
  |  62 |    20 |            if(requestMatcher.matches(req))  | 
  |  63 |    16 |                recurseThroughRepos((repoMap != null ? repoMap.entrySet().iterator() : null), req, res, chain);  | 
  |  64 |     | 
           else  | 
  |  65 |    4 |                chain.doFilter(req, res);  | 
  |  66 |    8 |        }  | 
  |  67 |     | 
         | 
  |  68 |     | 
       private static void  recurseThroughRepos(final Iterator<Map.Entry<String, HibernateRepository>> iter, final ServletRequest req, final ServletResponse res, final FilterChain chain) throws IOException, ServletException  | 
  |  69 |     | 
       { | 
  |  70 |    48 |            if(iter == null || !iter.hasNext())  | 
  |  71 |     | 
           { | 
  |  72 |    16 |                chain.doFilter(req, res);  | 
  |  73 |    4 |                return;  | 
  |  74 |     | 
           }  | 
  |  75 |     | 
             | 
  |  76 |    32 |            Map.Entry<String, HibernateRepository> entry = iter.next();  | 
  |  77 |    32 |            final String name = entry.getKey();  | 
  |  78 |    32 |            HibernateRepository repo = entry.getValue();  | 
  |  79 |     | 
             | 
  |  80 |    32 |            repo.executeUnitOfWork(new AbstractUnitOfWork<Void>()  | 
  |  81 |    64 |            { | 
  |  82 |     | 
               @Override  | 
  |  83 |     | 
               public Void execute() throws Exception  | 
  |  84 |     | 
               { | 
  |  85 |    32 |                    recurseThroughRepos(iter, req, res, chain);  | 
  |  86 |    8 |                    return null;  | 
  |  87 |     | 
               }  | 
  |  88 |     | 
                 | 
  |  89 |     | 
               @Override  | 
  |  90 |     | 
               public String getName()  | 
  |  91 |     | 
               { | 
  |  92 |    0 |                    return name;  | 
  |  93 |     | 
               }  | 
  |  94 |     | 
           });  | 
  |  95 |    8 |        }  | 
  |  96 |     | 
     | 
  |  97 |     | 
       @Override  | 
  |  98 |     | 
       public void init(FilterConfig cfg) throws ServletException  | 
  |  99 |     | 
       { | 
  |  100 |    20 |            String includePattern = null;  | 
  |  101 |    20 |            String excludePattern = null;  | 
  |  102 |     | 
             | 
  |  103 |    20 |            Map<String, HibernateRepository> newRepoMap = new LinkedHashMap<String, HibernateRepository>();  | 
  |  104 |    20 |            Enumeration<?> paramNames = cfg.getInitParameterNames();  | 
  |  105 |    20 |            if(paramNames != null)  | 
  |  106 |     | 
           { | 
  |  107 |    72 |                while(paramNames.hasMoreElements())  | 
  |  108 |     | 
               { | 
  |  109 |    56 |                    Object elem = paramNames.nextElement();  | 
  |  110 |    56 |                    if(elem != null)  | 
  |  111 |     | 
                   { | 
  |  112 |    56 |                        String name = elem.toString();  | 
  |  113 |    56 |                        String value = cfg.getInitParameter(name);  | 
  |  114 |     | 
                         | 
  |  115 |    56 |                        if(InitParameterNames.INCLUDE_PATH_PATTERN.equals(name))  | 
  |  116 |    12 |                            includePattern = value;  | 
  |  117 |    44 |                        else if(InitParameterNames.EXCLUDE_PATH_PATTERN.equals(name))  | 
  |  118 |    12 |                            excludePattern = value;  | 
  |  119 |    32 |                        else if(value != null)  | 
  |  120 |     | 
                       { | 
  |  121 |    32 |                          HibernateRepository repo = loadHibernateRepository(value);  | 
  |  122 |    28 |                          newRepoMap.put(name, repo);  | 
  |  123 |     | 
                       }  | 
  |  124 |     | 
                   }  | 
  |  125 |    52 |                }  | 
  |  126 |     | 
           }  | 
  |  127 |     | 
             | 
  |  128 |    16 |            if(includePattern != null || excludePattern != null)  | 
  |  129 |    12 |                requestMatcher = new PathPatternRequestMatcher(includePattern, excludePattern);  | 
  |  130 |     | 
           else  | 
  |  131 |    4 |                requestMatcher = TrivialRequestMatcher.ANYTHING;  | 
  |  132 |     | 
             | 
  |  133 |    16 |            checkForIdenticalRepos(newRepoMap);  | 
  |  134 |    12 |            repoMap = newRepoMap;  | 
  |  135 |    12 |        }  | 
  |  136 |     | 
         | 
  |  137 |     | 
       private static void checkForIdenticalRepos(Map<String, HibernateRepository> repoMap)  | 
  |  138 |     | 
       { | 
  |  139 |    16 |            Map<HibernateRepository, String> reverseRepoMap = new IdentityHashMap<HibernateRepository, String>();  | 
  |  140 |    16 |            for(Map.Entry<String, HibernateRepository> entry : repoMap.entrySet())  | 
  |  141 |     | 
           { | 
  |  142 |    24 |              String oldKey = reverseRepoMap.put(entry.getValue(), entry.getKey());  | 
  |  143 |    24 |              if(oldKey != null)  | 
  |  144 |     | 
             { | 
  |  145 |    4 |                throw new ServletConfigurationError(  | 
  |  146 |     | 
                   "Attempting to register the identical repositories ('" + entry.getValue() + "') with different names ('" + oldKey + "', '" + entry.getKey() + "')."); | 
  |  147 |     | 
             }  | 
  |  148 |    20 |            }  | 
  |  149 |    12 |        }  | 
  |  150 |     | 
     | 
  |  151 |     | 
       private static HibernateRepositoryProvider loadProvider(String className)  | 
  |  152 |     | 
       { | 
  |  153 |    32 |            Class<?> clazz = loadClass(className);  | 
  |  154 |    28 |            if(!HibernateRepositoryProvider.class.isAssignableFrom(clazz))  | 
  |  155 |    0 |                throw new ServletConfigurationError("'" + clazz.getCanonicalName() + "' does not implement '" + HibernateRepositoryProvider.class.getCanonicalName() + "'."); | 
  |  156 |     | 
     | 
  |  157 |     | 
           try  | 
  |  158 |     | 
           { | 
  |  159 |    28 |                HibernateRepositoryProvider provider = (HibernateRepositoryProvider) clazz.newInstance();  | 
  |  160 |    28 |                return provider;  | 
  |  161 |     | 
           }  | 
  |  162 |    0 |            catch(InstantiationException e)  | 
  |  163 |     | 
           { | 
  |  164 |    0 |                throw new ServletConfigurationError(unableToInstantineErrorString(clazz), e);  | 
  |  165 |     | 
           }  | 
  |  166 |    0 |            catch(IllegalAccessException e)  | 
  |  167 |     | 
           { | 
  |  168 |    0 |              throw new ServletConfigurationError(unableToInstantineErrorString(clazz), e);  | 
  |  169 |     | 
           }  | 
  |  170 |     | 
       }  | 
  |  171 |     | 
         | 
  |  172 |     | 
       private static HibernateRepository loadHibernateRepository(String providerClassName)  | 
  |  173 |     | 
       { | 
  |  174 |    32 |          HibernateRepositoryProvider provider = loadProvider(providerClassName);  | 
  |  175 |    28 |          HibernateRepository provided = provider.get();  | 
  |  176 |     | 
           | 
  |  177 |    28 |          if(provided == null)  | 
  |  178 |     | 
         { | 
  |  179 |    0 |            throw new ServletConfigurationError(  | 
  |  180 |     | 
               "Expected '"  + provider.getClass().getCanonicalName() + ".get()' to return an instance of HibernateRepository but got null.");  | 
  |  181 |     | 
         }  | 
  |  182 |     | 
           | 
  |  183 |    28 |          return provided;  | 
  |  184 |     | 
       }  | 
  |  185 |     | 
         | 
  |  186 |     | 
       private static String unableToInstantineErrorString(Class<?> clazz)  | 
  |  187 |     | 
       { | 
  |  188 |    0 |            return "Unable to instantine '" + clazz.getCanonicalName() + "'.";  | 
  |  189 |     | 
       }  | 
  |  190 |     | 
         | 
  |  191 |     | 
       private static Class<?> loadClass(String className)  | 
  |  192 |     | 
       { | 
  |  193 |     | 
           try  | 
  |  194 |     | 
           { | 
  |  195 |    32 |                return Classes.forName(className);  | 
  |  196 |     | 
           }  | 
  |  197 |    4 |            catch(ClassNotFoundException e1)  | 
  |  198 |     | 
           { | 
  |  199 |    4 |                throw new ServletConfigurationError("Could not load class '" + className + "'."); | 
  |  200 |     | 
           }  | 
  |  201 |     | 
       }  | 
  |  202 |     | 
   }  |