1   /*
2    * Copyright 2018 LINE Corporation
3    *
4    * LINE Corporation licenses this file to you under the Apache License,
5    * version 2.0 (the "License"); you may not use this file except in compliance
6    * with the License. You may obtain a copy of the License at:
7    *
8    *   https://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13   * License for the specific language governing permissions and limitations
14   * under the License.
15   */
16  package com.linecorp.centraldogma.server.auth.saml;
17  
import static com.google.common.base.Preconditions.checkArgument;
import static com.linecorp.centraldogma.server.auth.saml.HtmlUtil.getHtmlWithOnload;
import static java.util.Objects.requireNonNull;

import java.io.UnsupportedEncodingException;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.function.Function;
import java.util.function.Supplier;

import javax.annotation.Nullable;

import org.opensaml.core.xml.XMLObject;
import org.opensaml.core.xml.schema.XSString;
import org.opensaml.messaging.context.MessageContext;
import org.opensaml.saml.common.messaging.context.SAMLBindingContext;
import org.opensaml.saml.saml2.core.AuthnRequest;
import org.opensaml.saml.saml2.core.NameIDType;
import org.opensaml.saml.saml2.core.Response;

import com.google.common.base.Strings;

import com.linecorp.armeria.common.AggregatedHttpRequest;
import com.linecorp.armeria.common.HttpRequest;
import com.linecorp.armeria.common.HttpResponse;
import com.linecorp.armeria.common.HttpStatus;
import com.linecorp.armeria.common.MediaType;
import com.linecorp.armeria.server.ServiceRequestContext;
import com.linecorp.armeria.server.saml.InvalidSamlRequestException;
import com.linecorp.armeria.server.saml.SamlBindingProtocol;
import com.linecorp.armeria.server.saml.SamlIdentityProviderConfig;
import com.linecorp.armeria.server.saml.SamlSingleSignOnHandler;
import com.linecorp.centraldogma.server.auth.Session;
import com.linecorp.centraldogma.server.internal.api.HttpApiUtil;

import io.netty.handler.codec.http.QueryStringDecoder;
58  
59  /**
60   * A {@link SamlSingleSignOnHandler} implementation for the Central Dogma server.
61   */
62  final class SamlAuthSsoHandler implements SamlSingleSignOnHandler {
63  
64      private final Supplier<String> sessionIdGenerator;
65      private final Function<Session, CompletableFuture<Void>> loginSessionPropagator;
66      private final Duration sessionValidDuration;
67      private final Function<String, String> loginNameNormalizer;
68      @Nullable
69      private final String subjectLoginNameIdFormat;
70      @Nullable
71      private final String attributeLoginName;
72  
73      SamlAuthSsoHandler(
74              Supplier<String> sessionIdGenerator,
75              Function<Session, CompletableFuture<Void>> loginSessionPropagator,
76              Duration sessionValidDuration, Function<String, String> loginNameNormalizer,
77              @Nullable String subjectLoginNameIdFormat, @Nullable String attributeLoginName) {
78          this.sessionIdGenerator = requireNonNull(sessionIdGenerator, "sessionIdGenerator");
79          this.loginSessionPropagator = requireNonNull(loginSessionPropagator, "loginSessionPropagator");
80          this.sessionValidDuration = requireNonNull(sessionValidDuration, "sessionValidDuration");
81          this.loginNameNormalizer = requireNonNull(loginNameNormalizer, "loginNameNormalizer");
82          checkArgument(!Strings.isNullOrEmpty(subjectLoginNameIdFormat) ||
83                        !Strings.isNullOrEmpty(attributeLoginName),
84                        "a name ID format of a subject or an attribute name should be specified " +
85                        "for finding a login name");
86          this.subjectLoginNameIdFormat = subjectLoginNameIdFormat;
87          this.attributeLoginName = attributeLoginName;
88      }
89  
90      @Override
91      public CompletionStage<Void> beforeInitiatingSso(ServiceRequestContext ctx, HttpRequest req,
92                                                       MessageContext<AuthnRequest> message,
93                                                       SamlIdentityProviderConfig idpConfig) {
94          final QueryStringDecoder decoder = new QueryStringDecoder(req.path(), true);
95          final List<String> ref = decoder.parameters().get("ref");
96          if (ref == null || ref.isEmpty()) {
97              return CompletableFuture.completedFuture(null);
98          }
99  
100         final String relayState = ref.get(0);
101         if (idpConfig.ssoEndpoint().bindingProtocol() == SamlBindingProtocol.HTTP_REDIRECT &&
102             relayState.length() > 80) {
103             return CompletableFuture.completedFuture(null);
104         }
105 
106         final SAMLBindingContext sub = message.getSubcontext(SAMLBindingContext.class, true);
107         assert sub != null : SAMLBindingContext.class.getName();
108         sub.setRelayState(relayState);
109         return CompletableFuture.completedFuture(null);
110     }
111 
112     @Override
113     public HttpResponse loginSucceeded(ServiceRequestContext ctx, AggregatedHttpRequest req,
114                                        MessageContext<Response> message, @Nullable String sessionIndex,
115                                        @Nullable String relayState) {
116         final Response response = requireNonNull(message, "message").getMessage();
117         final String username = Optional.ofNullable(findLoginNameFromSubjects(response))
118                                         .orElseGet(() -> findLoginNameFromAttributes(response));
119         if (Strings.isNullOrEmpty(username)) {
120             return loginFailed(ctx, req, message,
121                                new IllegalStateException("Cannot get a username from the response"));
122         }
123 
124         final String sessionId = sessionIdGenerator.get();
125         final Session session =
126                 new Session(sessionId, loginNameNormalizer.apply(username), sessionValidDuration);
127 
128         final String redirectionScript;
129         if (!Strings.isNullOrEmpty(relayState)) {
130             try {
131                 redirectionScript = "window.location.href='/#" + URLEncoder.encode(relayState, "UTF-8") + '\'';
132             } catch (UnsupportedEncodingException e) {
133                 // Should never reach here.
134                 throw new Error();
135             }
136         } else {
137             redirectionScript = "window.location.href='/'";
138         }
139         return HttpResponse.of(loginSessionPropagator.apply(session).thenApply(
140                 unused -> HttpResponse.of(HttpStatus.OK, MediaType.HTML_UTF_8, getHtmlWithOnload(
141                         "localStorage.setItem('sessionId','" + sessionId + "')",
142                         redirectionScript))));
143     }
144 
145     @Nullable
146     private String findLoginNameFromSubjects(Response response) {
147         if (Strings.isNullOrEmpty(subjectLoginNameIdFormat)) {
148             return null;
149         }
150         return response.getAssertions()
151                        .stream()
152                        .map(s -> s.getSubject().getNameID())
153                        .filter(nameId -> nameId.getFormat().equals(subjectLoginNameIdFormat))
154                        .map(NameIDType::getValue)
155                        .findFirst()
156                        .orElse(null);
157     }
158 
159     @Nullable
160     private String findLoginNameFromAttributes(Response response) {
161         if (Strings.isNullOrEmpty(attributeLoginName)) {
162             return null;
163         }
164         return response.getAssertions()
165                        .stream()
166                        .flatMap(s -> s.getAttributeStatements().stream())
167                        .flatMap(s -> s.getAttributes().stream())
168                        .filter(attr -> attr.getName().equals(attributeLoginName))
169                        .findFirst()
170                        .map(attr -> {
171                            final XMLObject v = attr.getAttributeValues().get(0);
172                            if (v instanceof XSString) {
173                                return ((XSString) v).getValue();
174                            } else {
175                                return null;
176                            }
177                        })
178                        .orElse(null);
179     }
180 
181     @Override
182     public HttpResponse loginFailed(ServiceRequestContext ctx, AggregatedHttpRequest req,
183                                     @Nullable MessageContext<Response> message, Throwable cause) {
184         final HttpStatus status =
185                 cause instanceof InvalidSamlRequestException ? HttpStatus.BAD_REQUEST
186                                                              : HttpStatus.INTERNAL_SERVER_ERROR;
187         return HttpApiUtil.newResponse(ctx, status, cause);
188     }
189 }