You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

130 lines
2.9 KiB

package lib
import (
"bytes"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"errors"
"fmt"
"io"
"net/http"
"net/http/httputil"
"net/url"
"strings"
)
var ErrPath = errors.New("invalid path")
type BasicAuthTransport struct {
Kind string
Username string
Password string
TrimPrefix string
NewPrefix string
URL string
Region string
DisableSSL bool
http.Transport
}
// RoundTrip implements http.RoundTripper.
//
// The request path is rejected with ErrPath when it contains "../" or ends in
// a ".." segment, checked both before and after rewriting. Otherwise a leading
// TrimPrefix is replaced by NewPrefix, credentials are attached when Username
// is non-empty, and the request is sent through the embedded http.Transport.
//
// All rewriting happens on a clone of req, so the caller's URL and headers are
// left untouched, as the http.RoundTripper contract requires.
func (t *BasicAuthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	// climbsOut reports whether p contains a parent-directory segment. The
	// "../" test alone misses a trailing ".." such as "/bucket/..".
	climbsOut := func(p string) bool {
		return strings.Contains(p, "../") || p == ".." || strings.HasSuffix(p, "/..")
	}
	// reject closes the request body, which a RoundTripper must do even when
	// it fails, and reports ErrPath.
	reject := func() (*http.Response, error) {
		if req.Body != nil {
			_ = req.Body.Close() // best effort; the request is refused anyway
		}
		return nil, ErrPath
	}

	if climbsOut(req.URL.Path) {
		return reject()
	}

	out := req.Clone(req.Context())
	out.URL.Path = t.NewPrefix + strings.TrimPrefix(out.URL.Path, t.TrimPrefix)
	// Stripping TrimPrefix can expose a bare "..", e.g. "/api.." with
	// TrimPrefix "/api", so validate the path that is actually sent.
	if climbsOut(out.URL.Path) {
		return reject()
	}

	if t.Username != "" {
		switch t.Kind {
		case "s3":
			// TODO: sign the request with AWS Signature V4 (access key
			// Username, secret Password, endpoint URL, Region, DisableSSL).
			// Until that is implemented, no credentials are added for s3.
		default:
			creds := fmt.Sprintf("%s:%s", t.Username, t.Password)
			out.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(creds)))
		}
	}
	return t.Transport.RoundTrip(out)
}
// Proxy returns an http.Handler that reverse-proxies requests to v.endpoint.
//
// Outgoing requests get X-Forwarded-* headers and are routed to the scheme,
// host and base path of v.endpoint. They pass through a BasicAuthTransport
// that replaces a leading trimPrefix in the path with newPrefix and attaches
// v's credentials. Responses have the Server header and every X-Amz-* header
// removed, and 404 responses are returned with an empty body.
//
// When v.cacert is non-empty, its PEM certificates form the only trusted root
// pool for TLS connections to the endpoint.
//
// An error is returned if v.endpoint cannot be parsed or if v.cacert holds no
// usable PEM certificate.
func (v *vfs) Proxy(trimPrefix, newPrefix string) (http.Handler, error) {
	parsedURL, err := url.Parse(v.endpoint)
	if err != nil {
		return nil, fmt.Errorf("parsing endpoint: %w", err)
	}

	proxy := httputil.ReverseProxy{
		Rewrite: func(r *httputil.ProxyRequest) {
			r.SetXForwarded()
			r.SetURL(parsedURL)
		},
		ModifyResponse: func(resp *http.Response) error {
			// Strip headers that identify the upstream server.
			resp.Header.Del("Server")
			for k := range resp.Header {
				if strings.HasPrefix(k, "X-Amz-") {
					resp.Header.Del(k)
				}
			}
			if resp.StatusCode == http.StatusNotFound {
				// ReverseProxy closes whatever resp.Body holds after this hook.
				// That will be the empty replacement below, so the upstream
				// body must be drained and closed here. Otherwise its
				// connection is never released. Errors are ignored because
				// the body is being discarded anyway.
				const maxDrain = 64 << 10
				_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, maxDrain))
				_ = resp.Body.Close()

				resp.Body = io.NopCloser(bytes.NewReader(nil))
				resp.Header.Del("Content-Type")
				resp.Header.Set("Content-Length", "0")
				resp.ContentLength = 0
			}
			return nil
		},
	}

	// NOTE(review): DisableSSL is derived from the absence of a CA cert,
	// which looks unrelated to whether the endpoint uses TLS. It is only
	// meant for the unimplemented S3 signer; confirm the intent.
	transport := BasicAuthTransport{
		Kind:       v.kind,
		Username:   v.accessKeyID,
		Password:   v.secretKey,
		TrimPrefix: trimPrefix,
		NewPrefix:  newPrefix,
		URL:        v.endpoint,
		Region:     v.region,
		DisableSSL: v.cacert == "",
	}
	if v.cacert != "" {
		caCertPool := x509.NewCertPool()
		// An unparsable bundle would otherwise leave an empty root pool and
		// make every TLS handshake fail with an opaque verification error.
		if !caCertPool.AppendCertsFromPEM([]byte(v.cacert)) {
			return nil, errors.New("invalid cacert: no PEM certificates could be parsed")
		}
		transport.TLSClientConfig = &tls.Config{
			RootCAs: caCertPool,
		}
	}
	proxy.Transport = &transport
	return &proxy, nil
}