Compare commits
44 Commits
2ec1738721
...
simple_cli
Author | SHA1 | Date | |
---|---|---|---|
5ac529ce26
|
|||
fa5d1e2689
|
|||
04c83cccb9
|
|||
8e4a336510
|
|||
cb28c07ff4
|
|||
d5db656ca2
|
|||
369d445637
|
|||
5709bfd21d
|
|||
606e85d467
|
|||
72c0188071
|
|||
028c084cdd
|
|||
f2b046056b
|
|||
985ed9943a
|
|||
f8a550883d
|
|||
252b49ae6a
|
|||
4b3d64c5cd
|
|||
799bf784aa
|
|||
14c78536de
|
|||
32bfd109b9
|
|||
a578beea0d
|
|||
3ac7e488af
|
|||
253f3dcdac
|
|||
b44d59bd21
|
|||
a3c1ae5615
|
|||
a601f1ceec
|
|||
ccc0a58f88
|
|||
824ca781d4
|
|||
cd4ebf9dc7
|
|||
71164ee85a
|
|||
841f7aa0de
|
|||
3a968df15b
|
|||
e8d8e8d70b
|
|||
25ee1d3299
|
|||
9d7ad260f2
|
|||
732fbacc61
|
|||
9478437262
|
|||
d8878eba09
|
|||
c55052ad5b
|
|||
a7466e5c77
|
|||
b86ee0dac4
|
|||
ec90717ad7
|
|||
3f417b0088
|
|||
9870b79854
|
|||
02643c1197
|
@@ -1,19 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
set -e
|
||||
if [[ ! -d "/home/williamp/chatservice_concept" ]]; then
|
||||
echo "Cannot find source directory; Did you move it?"
|
||||
echo "(Looking for "/home/williamp/chatservice_concept")"
|
||||
echo 'Cannot force reload with this script - use "direnv reload" manually and then try again'
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# rebuild the cache forcefully
|
||||
_nix_direnv_force_reload=1 direnv exec "/home/williamp/chatservice_concept" true
|
||||
|
||||
# Update the mtime for .envrc.
|
||||
# This will cause direnv to reload again - but without re-building.
|
||||
touch "/home/williamp/chatservice_concept/.envrc"
|
||||
|
||||
# Also update the timestamp of whatever profile_rc we have.
|
||||
# This makes sure that we know we are up to date.
|
||||
touch -r "/home/williamp/chatservice_concept/.envrc" "/home/williamp/chatservice_concept/.direnv"/*.rc
|
@@ -1 +0,0 @@
|
||||
/nix/store/gks035qaj52pl3ygwlicprsbqxw0wvja-source
|
@@ -1 +0,0 @@
|
||||
/nix/store/salp9r9j3pj9cwqf06wchs16hy8g882k-source
|
@@ -1 +0,0 @@
|
||||
/nix/store/78sjjah7cnj7zyhh9kq3yj1440rx0h56-nix-shell-env
|
File diff suppressed because it is too large
Load Diff
9
.gitignore
vendored
9
.gitignore
vendored
@@ -7,6 +7,7 @@
|
||||
*.dll
|
||||
*.so
|
||||
*.dylib
|
||||
__debug_bin*
|
||||
|
||||
# Test binary, built with `go test -c`
|
||||
*.test
|
||||
@@ -22,4 +23,10 @@ go.work
|
||||
go.work.sum
|
||||
|
||||
# env file
|
||||
.env
|
||||
.env
|
||||
|
||||
# Direnv directory
|
||||
.direnv/
|
||||
|
||||
# Vscode directory
|
||||
.vscode/
|
15
.vscode/launch.json
vendored
15
.vscode/launch.json
vendored
@@ -1,15 +0,0 @@
|
||||
{
|
||||
// Use IntelliSense to learn about possible attributes.
|
||||
// Hover to view descriptions of existing attributes.
|
||||
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Launch Package",
|
||||
"type": "go",
|
||||
"request": "launch",
|
||||
"mode": "auto",
|
||||
"program": "main.go"
|
||||
}
|
||||
]
|
||||
}
|
7
LICENSE
Normal file
7
LICENSE
Normal file
@@ -0,0 +1,7 @@
|
||||
Copyright 2025 William Peebles
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE
|
67
api/api.go
67
api/api.go
@@ -3,23 +3,49 @@ package api
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
|
||||
"git.dubyatp.xyz/chat-api-server/db"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
"github.com/go-chi/cors"
|
||||
"github.com/go-chi/docgen"
|
||||
"github.com/go-chi/render"
|
||||
)
|
||||
|
||||
var routes = flag.Bool("routes", false, "Generate API route documentation")
|
||||
|
||||
func RequestLog(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
slog.Debug("api: request received",
|
||||
"request_uri", r.RequestURI,
|
||||
"source_ip", r.RemoteAddr,
|
||||
"user_agent", r.UserAgent())
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func Start() {
|
||||
|
||||
db.InitScyllaDB()
|
||||
defer db.CloseScyllaDB()
|
||||
|
||||
flag.Parse()
|
||||
|
||||
r := chi.NewRouter()
|
||||
|
||||
r.Use(cors.Handler(cors.Options{
|
||||
AllowedOrigins: []string{"http://localhost:5000"},
|
||||
AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
|
||||
AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"},
|
||||
ExposedHeaders: []string{"Link"},
|
||||
AllowCredentials: true,
|
||||
MaxAge: 300, // Maximum value for preflight request cache
|
||||
}))
|
||||
|
||||
r.Use(middleware.RequestID)
|
||||
r.Use(middleware.Logger)
|
||||
r.Use(RequestLog)
|
||||
r.Use(middleware.Recoverer)
|
||||
r.Use(middleware.URLFormat)
|
||||
r.Use(render.SetContentType(render.ContentTypeJSON))
|
||||
@@ -36,13 +62,50 @@ func Start() {
|
||||
panic("oh no")
|
||||
})
|
||||
|
||||
r.Route("/whoami", func(r chi.Router) {
|
||||
r.Use(SessionAuthMiddleware)
|
||||
r.Use(LoginCtx)
|
||||
r.Get("/", Whoami)
|
||||
})
|
||||
|
||||
r.Route("/messages", func(r chi.Router) {
|
||||
r.Use(SessionAuthMiddleware) // Protect with authentication
|
||||
|
||||
r.Get("/", ListMessages)
|
||||
r.Route("/{messageID}", func(r chi.Router) {
|
||||
r.Use(MessageCtx) // Load message
|
||||
r.Get("/", GetMessage)
|
||||
r.Delete("/", DeleteMessage)
|
||||
r.Post("/edit", EditMessage)
|
||||
})
|
||||
r.Post("/new", NewMessage)
|
||||
r.Route("/new", func(r chi.Router) {
|
||||
r.Use(LoginCtx)
|
||||
r.Post("/", NewMessage)
|
||||
})
|
||||
})
|
||||
|
||||
r.Route("/users", func(r chi.Router) {
|
||||
r.Use(SessionAuthMiddleware) // Protect with authentication
|
||||
|
||||
r.Get("/", ListUsers)
|
||||
r.Route("/{userID}", func(r chi.Router) {
|
||||
r.Use(UserCtx) // Load user
|
||||
r.Get("/", GetUser)
|
||||
})
|
||||
})
|
||||
|
||||
r.Route("/login", func(r chi.Router) {
|
||||
r.Post("/", Login)
|
||||
})
|
||||
|
||||
r.Route("/logout", func(r chi.Router) {
|
||||
r.Use(SessionAuthMiddleware)
|
||||
|
||||
r.Post("/", Logout)
|
||||
})
|
||||
|
||||
r.Route("/register", func(r chi.Router) {
|
||||
r.Post("/", NewUser)
|
||||
})
|
||||
|
||||
if *routes {
|
||||
|
226
api/auth.go
Normal file
226
api/auth.go
Normal file
@@ -0,0 +1,226 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
var jwtSecret = []byte(os.Getenv("JWT_SECRET"))
|
||||
|
||||
func hashToken(token string) string {
|
||||
hash := sha256.Sum256([]byte(token))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
func Login(w http.ResponseWriter, r *http.Request) {
|
||||
err := r.ParseMultipartForm(64 << 10)
|
||||
if err != nil {
|
||||
http.Error(w, "Unable to parse form", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
username := r.FormValue("username")
|
||||
password := r.FormValue("password")
|
||||
if username == "" || password == "" {
|
||||
http.Error(w, "Username and password cannot be empty", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
user, err := dbGetUserByName(username)
|
||||
if err != nil {
|
||||
http.Error(w, "Invalid username or password", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
err = validatePassword(user.Password, password)
|
||||
if err != nil {
|
||||
http.Error(w, "Invalid username or password", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
sessionToken := CreateSession(user.ID)
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "session_token",
|
||||
Value: sessionToken,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Secure: false,
|
||||
})
|
||||
|
||||
slog.Info("auth: login successful", "userID", user.ID, "userName", user.Name)
|
||||
w.Write([]byte("Login successful"))
|
||||
}
|
||||
|
||||
func Logout(w http.ResponseWriter, r *http.Request) {
|
||||
cookie, err := r.Cookie("session_token")
|
||||
if err != nil {
|
||||
http.Error(w, "No session cookie found. You are already logged out", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
sessionToken := cookie.Value
|
||||
userID, valid := ValidateSession(sessionToken)
|
||||
if !valid {
|
||||
http.Error(w, "Session cookie could not be validated. You are already logged out", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
user, err := dbGetUser(userID.String())
|
||||
if err != nil {
|
||||
http.Error(w, "Session cookie validated but user could not be found", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
DeleteSession(sessionToken)
|
||||
|
||||
cookie.Expires = time.Now()
|
||||
http.SetCookie(w, cookie)
|
||||
|
||||
slog.Debug("auth: logout successful", "userID", user.ID, "userName", user.Name)
|
||||
w.Write([]byte(fmt.Sprintf("%v has been logged out", user.Name)))
|
||||
|
||||
}
|
||||
|
||||
type Session struct {
|
||||
Token string
|
||||
UserID uuid.UUID
|
||||
Expiry time.Time
|
||||
}
|
||||
|
||||
func CreateSession(userID uuid.UUID) string {
|
||||
|
||||
expiry := time.Now().Add(7 * 24 * time.Hour)
|
||||
|
||||
claims := jwt.MapClaims{
|
||||
"userID": userID.String(),
|
||||
"exp": expiry.Unix(), // 7 day token expiry
|
||||
"iat": time.Now().Unix(),
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenString, err := token.SignedString(jwtSecret)
|
||||
if err != nil {
|
||||
slog.Error("auth: failed to create JWT", "error", err)
|
||||
return ""
|
||||
}
|
||||
|
||||
hashedToken := hashToken(tokenString)
|
||||
session := Session{
|
||||
Token: hashedToken,
|
||||
UserID: userID,
|
||||
Expiry: expiry,
|
||||
}
|
||||
dbAddSession(&session)
|
||||
|
||||
slog.Debug("auth: new session created", "userID", session.UserID)
|
||||
return tokenString
|
||||
}
|
||||
|
||||
func ValidateSession(sessionToken string) (uuid.UUID, bool) {
|
||||
token, err := jwt.Parse(sessionToken, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
return jwtSecret, nil
|
||||
})
|
||||
if err != nil || !token.Valid {
|
||||
slog.Debug("auth: session token invalid, rejecting")
|
||||
return uuid.Nil, false
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
slog.Debug("auth: could not map claims from JWT")
|
||||
return uuid.Nil, false
|
||||
}
|
||||
|
||||
userIDStr, ok := claims["userID"].(string)
|
||||
if !ok {
|
||||
slog.Debug("auth: userID claim is not a string")
|
||||
return uuid.Nil, false
|
||||
}
|
||||
|
||||
userID, err := uuid.Parse(userIDStr)
|
||||
if err != nil {
|
||||
slog.Debug("auth: failed to parse userID as uuid", "error", err)
|
||||
return uuid.Nil, false
|
||||
}
|
||||
|
||||
hashedToken := hashToken(sessionToken)
|
||||
|
||||
session, err := dbGetSession(hashedToken)
|
||||
if err != nil {
|
||||
slog.Debug("auth: failed to retrieve session from db", "error", err)
|
||||
return uuid.Nil, false
|
||||
}
|
||||
if time.Now().After(session.Expiry) {
|
||||
slog.Debug("auth: session is expired (or otherwise invalid) in db")
|
||||
return uuid.Nil, false
|
||||
}
|
||||
|
||||
slog.Debug("auth: session validated", "userID", session.UserID)
|
||||
return userID, true
|
||||
}
|
||||
|
||||
func DeleteSession(sessionToken string) bool {
|
||||
|
||||
hashedToken := hashToken(sessionToken)
|
||||
|
||||
err := dbDeleteSession(hashedToken)
|
||||
if err != nil {
|
||||
slog.Error("auth: failed to delete session", "error", err)
|
||||
return false
|
||||
}
|
||||
|
||||
slog.Debug("auth: session deleted", "token", hashedToken)
|
||||
return true
|
||||
}
|
||||
|
||||
type contextKey string
|
||||
|
||||
const userIDKey contextKey = "userID"
|
||||
|
||||
func SessionAuthMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
cookie, err := r.Cookie("session_token")
|
||||
if err != nil {
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
sessionToken := cookie.Value
|
||||
userID, valid := ValidateSession(sessionToken)
|
||||
if !valid {
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Add username to request context
|
||||
ctx := context.WithValue(r.Context(), userIDKey, userID)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
func (u *UserPayload) Render(w http.ResponseWriter, r *http.Request) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func hashPassword(password string) (string, error) {
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
return string(hashedPassword), err
|
||||
}
|
||||
|
||||
func validatePassword(hashedPassword, password string) error {
|
||||
return bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password))
|
||||
}
|
274
api/db.go
274
api/db.go
@@ -3,97 +3,239 @@ package api
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
"log/slog"
|
||||
|
||||
"git.dubyatp.xyz/chat-api-server/db"
|
||||
"github.com/gocql/gocql"
|
||||
)
|
||||
|
||||
func dbGetUser(id int64) (*User, error) {
|
||||
data := db.ExecDB("users")
|
||||
if data == nil {
|
||||
return nil, errors.New("failed to load users database")
|
||||
func dbGetUser(id string) (*User, error) {
|
||||
query := `SELECT id, name, password FROM users WHERE id = ?`
|
||||
var user User
|
||||
err := db.Session.Query(query, id).Scan(&user.ID, &user.Name, &user.Password)
|
||||
|
||||
if err == gocql.ErrNotFound {
|
||||
slog.Debug("db: user not found", "userid", id)
|
||||
return nil, errors.New("User not found")
|
||||
} else if err != nil {
|
||||
slog.Error("db: failed to query user", "error", err)
|
||||
return nil, fmt.Errorf("failed to query user")
|
||||
}
|
||||
|
||||
users := data["users"].([]interface{})
|
||||
for _, u := range users {
|
||||
user := u.(map[string]interface{})
|
||||
if int64(user["ID"].(float64)) == id {
|
||||
return &User{
|
||||
ID: int64(user["ID"].(float64)),
|
||||
Name: user["Name"].(string),
|
||||
}, nil
|
||||
}
|
||||
slog.Debug("db: user found", "userid", user.ID, "username", user.Name)
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func dbGetUserByName(username string) (*User, error) {
|
||||
query := `SELECT id, name, password FROM users WHERE name = ?`
|
||||
var user User
|
||||
err := db.Session.Query(query, username).Scan(&user.ID, &user.Name, &user.Password)
|
||||
if err == gocql.ErrNotFound {
|
||||
slog.Debug("db: user not found", "username", username)
|
||||
return nil, errors.New("User not found")
|
||||
} else if err != nil {
|
||||
slog.Error("db: failed to query user", "error", err)
|
||||
return nil, fmt.Errorf("failed to query user")
|
||||
}
|
||||
return nil, errors.New("User not found")
|
||||
|
||||
slog.Debug("db: user found", "userid", user.ID, "username", user.Name)
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func dbGetAllUsers() ([]*User, error) {
|
||||
query := `SELECT id, name, password FROM users`
|
||||
iter := db.Session.Query(query).Iter()
|
||||
defer iter.Close()
|
||||
|
||||
var users []*User
|
||||
for {
|
||||
user := &User{}
|
||||
if !iter.Scan(&user.ID, &user.Name, &user.Password) {
|
||||
break
|
||||
}
|
||||
users = append(users, user)
|
||||
}
|
||||
|
||||
if err := iter.Close(); err != nil {
|
||||
slog.Error("db: failed to iterate users", "error", err)
|
||||
return nil, fmt.Errorf("failed to iterate users")
|
||||
}
|
||||
|
||||
if len(users) == 0 {
|
||||
slog.Debug("db: no users found")
|
||||
return nil, errors.New("no users found")
|
||||
}
|
||||
|
||||
slog.Debug("db: user list returned")
|
||||
return users, nil
|
||||
}
|
||||
|
||||
func dbGetMessage(id string) (*Message, error) {
|
||||
data := db.ExecDB("messages")
|
||||
if data == nil {
|
||||
return nil, errors.New("failed to load messages database")
|
||||
query := `SELECT id, body, edited, timestamp, userid FROM messages WHERE id = ?`
|
||||
var message Message
|
||||
err := db.Session.Query(query, id).Scan(
|
||||
&message.ID,
|
||||
&message.Body,
|
||||
&message.Edited,
|
||||
&message.Timestamp,
|
||||
&message.UserID)
|
||||
if err == gocql.ErrNotFound {
|
||||
slog.Debug("db: message not found", "messageid", id)
|
||||
return nil, errors.New("Message not found")
|
||||
} else if err != nil {
|
||||
slog.Error("db: failed to query message", "error", err)
|
||||
return nil, fmt.Errorf("failed to query message")
|
||||
}
|
||||
|
||||
messages := data["messages"].([]interface{})
|
||||
for _, m := range messages {
|
||||
message := m.(map[string]interface{})
|
||||
if message["ID"].(string) == id {
|
||||
timestamp, err := time.Parse(time.RFC3339, message["Timestamp"].(string))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse timestamp: %v", err)
|
||||
}
|
||||
return &Message{
|
||||
ID: message["ID"].(string),
|
||||
UserID: int64(message["UserID"].(float64)),
|
||||
Body: message["Body"].(string),
|
||||
Timestamp: timestamp,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
return nil, errors.New("Message not found")
|
||||
slog.Debug("db: message found", "messageid", message.ID)
|
||||
return &message, nil
|
||||
}
|
||||
|
||||
func dbGetAllMessages() ([]*Message, error) {
|
||||
data := db.ExecDB("messages")
|
||||
//println(data)
|
||||
if data == nil {
|
||||
return nil, errors.New("failed to load messages database")
|
||||
query := `SELECT id, body, edited, timestamp, userid FROM messages`
|
||||
iter := db.Session.Query(query).Iter()
|
||||
defer iter.Close()
|
||||
|
||||
var messages []*Message
|
||||
for {
|
||||
message := &Message{}
|
||||
if !iter.Scan(
|
||||
&message.ID,
|
||||
&message.Body,
|
||||
&message.Edited,
|
||||
&message.Timestamp,
|
||||
&message.UserID) {
|
||||
break
|
||||
}
|
||||
messages = append(messages, message)
|
||||
}
|
||||
|
||||
messages := data["messages"].([]interface{})
|
||||
var result []*Message
|
||||
for _, m := range messages {
|
||||
message := m.(map[string]interface{})
|
||||
timestamp, err := time.Parse(time.RFC3339, message["Timestamp"].(string))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse timestamp: %v", err)
|
||||
}
|
||||
result = append(result, &Message{
|
||||
ID: message["ID"].(string),
|
||||
UserID: int64(message["UserID"].(float64)),
|
||||
Body: message["Body"].(string),
|
||||
Timestamp: timestamp,
|
||||
})
|
||||
if err := iter.Close(); err != nil {
|
||||
slog.Error("db: failed to iterate messages", "error", err)
|
||||
return nil, fmt.Errorf("failed to iterate messages")
|
||||
}
|
||||
if len(result) == 0 {
|
||||
|
||||
if len(messages) == 0 {
|
||||
slog.Debug("db: no messages found")
|
||||
return nil, errors.New("no messages found")
|
||||
}
|
||||
return result, nil
|
||||
|
||||
slog.Debug("db: message list returned")
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
func dbAddUser(id int64, name string) error {
|
||||
user := map[string]interface{}{
|
||||
"ID": float64(id), // JSON numbers are float64 by default
|
||||
"Name": name,
|
||||
func dbAddSession(session *Session) error {
|
||||
query := `INSERT INTO sessions (jwttoken, userid, expiry) VALUES (?, ?, ?)`
|
||||
err := db.Session.Query(query, session.Token, session.UserID, session.Expiry).Exec()
|
||||
if err != nil {
|
||||
slog.Error("db: failed to add session", "error", err)
|
||||
return fmt.Errorf("failed to add session")
|
||||
}
|
||||
return db.AddUser(user)
|
||||
|
||||
slog.Debug("db: session added", "userID", session.UserID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func dbGetSession(jwtToken string) (*Session, error) {
|
||||
query := `SELECT jwttoken, userid, expiry FROM sessions WHERE jwttoken = ?`
|
||||
var session Session
|
||||
err := db.Session.Query(query, jwtToken).Scan(
|
||||
&session.Token,
|
||||
&session.UserID,
|
||||
&session.Expiry)
|
||||
if err == gocql.ErrNotFound {
|
||||
slog.Debug("db: session not found")
|
||||
return nil, errors.New("Session not found")
|
||||
} else if err != nil {
|
||||
slog.Error("db: failed to query session", "error", err)
|
||||
return nil, fmt.Errorf("failed to query session")
|
||||
}
|
||||
|
||||
return &session, nil
|
||||
}
|
||||
|
||||
func dbDeleteSession(jwtToken string) error {
|
||||
query := `DELETE FROM sessions WHERE jwttoken = ?`
|
||||
|
||||
err := db.Session.Query(query, jwtToken).Exec()
|
||||
|
||||
if err != nil {
|
||||
slog.Error("db: failed to delete session")
|
||||
return fmt.Errorf("failed to delete session")
|
||||
}
|
||||
|
||||
slog.Debug("db: session deleted")
|
||||
return nil
|
||||
}
|
||||
|
||||
func dbAddUser(user *User) error {
|
||||
query := `INSERT INTO users (id, name, password) VALUES (?, ?, ?)`
|
||||
err := db.Session.Query(query, user.ID, user.Name, user.Password).Exec()
|
||||
if err != nil {
|
||||
slog.Error("db: failed to add user", "error", err, "userid", user.ID, "username", user.Name)
|
||||
return fmt.Errorf("failed to add user")
|
||||
}
|
||||
|
||||
slog.Debug("db: user added", "userid", user.ID, "username", user.Name)
|
||||
return nil
|
||||
}
|
||||
|
||||
func dbAddMessage(message *Message) error {
|
||||
dbMessage := map[string]interface{}{
|
||||
"ID": message.ID,
|
||||
"UserID": message.UserID, // JSON numbers are float64
|
||||
"Body": message.Body,
|
||||
"Timestamp": message.Timestamp,
|
||||
query := `INSERT INTO messages (id, body, edited, timestamp, userid)
|
||||
VALUES (?, ?, ?, ?, ?)`
|
||||
err := db.Session.Query(query,
|
||||
message.ID,
|
||||
message.Body,
|
||||
nil,
|
||||
message.Timestamp,
|
||||
message.UserID).Exec()
|
||||
if err != nil {
|
||||
slog.Error("db: failed to add message", "error", err, "messageid", message.ID)
|
||||
return fmt.Errorf("failed to add message")
|
||||
}
|
||||
return db.AddMessage(dbMessage)
|
||||
|
||||
slog.Debug("db: message added", "messageid", message.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func dbUpdateMessage(updatedMessage *Message) error {
|
||||
var edited interface{}
|
||||
if updatedMessage.Edited.IsZero() {
|
||||
edited = nil
|
||||
} else {
|
||||
edited = updatedMessage.Edited
|
||||
}
|
||||
|
||||
query := `UPDATE messages
|
||||
SET body = ?, edited = ?, timestamp = ?
|
||||
WHERE ID = ?`
|
||||
|
||||
err := db.Session.Query(query,
|
||||
updatedMessage.Body,
|
||||
edited,
|
||||
updatedMessage.Timestamp,
|
||||
updatedMessage.ID).Exec()
|
||||
|
||||
if err != nil {
|
||||
slog.Error("db: failed to update message", "error", err, "messageid", updatedMessage.ID)
|
||||
return fmt.Errorf("failed to update message")
|
||||
}
|
||||
|
||||
slog.Debug("db: message updated", "messageid", updatedMessage.ID)
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
func dbDeleteMessage(id string) error {
|
||||
query := `DELETE FROM messages WHERE ID = ?`
|
||||
|
||||
err := db.Session.Query(query, id).Exec()
|
||||
|
||||
if err != nil {
|
||||
slog.Error("db: failed to delete message", "error", err, "messageid", id)
|
||||
return fmt.Errorf("failed to delete message")
|
||||
}
|
||||
|
||||
slog.Debug("db: message deleted", "messageid", id)
|
||||
return nil
|
||||
}
|
||||
|
138
api/message.go
138
api/message.go
@@ -3,6 +3,7 @@ package api
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
@@ -15,86 +16,179 @@ import (
|
||||
|
||||
func MessageCtx(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
slog.Debug("message: entering MessageCtx middleware")
|
||||
var message *Message
|
||||
var err error
|
||||
|
||||
if messageID := chi.URLParam(r, "messageID"); messageID != "" {
|
||||
slog.Debug("message: fetching message", "messageID", messageID)
|
||||
message, err = dbGetMessage(messageID)
|
||||
} else {
|
||||
slog.Error("message: messageID not found in URL parameters")
|
||||
render.Render(w, r, ErrNotFound)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
slog.Error("message: failed to fetch message", "messageID", chi.URLParam(r, "messageID"), "error", err)
|
||||
render.Render(w, r, ErrNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
slog.Debug("message: successfully fetched message", "messageID", message.ID)
|
||||
ctx := context.WithValue(r.Context(), messageKey{}, message)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
func GetMessage(w http.ResponseWriter, r *http.Request) {
|
||||
slog.Debug("message: entering GetMessage handler")
|
||||
message, ok := r.Context().Value(messageKey{}).(*Message)
|
||||
if !ok || message == nil {
|
||||
slog.Error("message: message not found in context")
|
||||
render.Render(w, r, ErrNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
slog.Debug("message: rendering message", "messageID", message.ID)
|
||||
if err := render.Render(w, r, NewMessageResponse(message)); err != nil {
|
||||
slog.Error("message: failed to render message response", "messageID", message.ID, "error", err)
|
||||
render.Render(w, r, ErrRender(err))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func EditMessage(w http.ResponseWriter, r *http.Request) {
|
||||
slog.Debug("message: entering EditMessage handler")
|
||||
message, ok := r.Context().Value(messageKey{}).(*Message)
|
||||
if !ok || message == nil {
|
||||
slog.Error("message: message not found in context")
|
||||
render.Render(w, r, ErrNotFound)
|
||||
return
|
||||
}
|
||||
err := r.ParseMultipartForm(64 << 10)
|
||||
if err != nil {
|
||||
slog.Error("message: failed to parse multipart form", "error", err)
|
||||
http.Error(w, "Unable to parse form", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
body := r.FormValue("body")
|
||||
if body == "" {
|
||||
slog.Error("message: message body is empty")
|
||||
http.Error(w, "Message body cannot be empty", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
slog.Debug("message: updating message", "messageID", message.ID)
|
||||
message.Body = body
|
||||
editedTime := time.Now()
|
||||
message.Edited = &editedTime
|
||||
|
||||
err = dbUpdateMessage(message)
|
||||
if err != nil {
|
||||
slog.Error("message: failed to update message", "messageID", message.ID, "error", err)
|
||||
render.Render(w, r, ErrRender(err))
|
||||
return
|
||||
}
|
||||
|
||||
slog.Debug("message: successfully updated message", "messageID", message.ID)
|
||||
if err := render.Render(w, r, NewMessageResponse(message)); err != nil {
|
||||
slog.Error("message: failed to render updated message response", "messageID", message.ID, "error", err)
|
||||
render.Render(w, r, ErrRender(err))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func DeleteMessage(w http.ResponseWriter, r *http.Request) {
|
||||
slog.Debug("message: entering DeleteMessage handler")
|
||||
message, ok := r.Context().Value(messageKey{}).(*Message)
|
||||
if !ok || message == nil {
|
||||
slog.Error("message: message not found in context")
|
||||
render.Render(w, r, ErrNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
slog.Debug("message: deleting message", "messageID", message.ID)
|
||||
err := dbDeleteMessage(message.ID.String())
|
||||
if err != nil {
|
||||
slog.Error("message: failed to delete message", "messageID", message.ID, "error", err)
|
||||
render.Render(w, r, ErrRender(err))
|
||||
return
|
||||
}
|
||||
|
||||
slog.Debug("message: successfully deleted message", "messageID", message.ID)
|
||||
if err := render.Render(w, r, NewMessageResponse(message)); err != nil {
|
||||
slog.Error("message: failed to render deleted message response", "messageID", message.ID, "error", err)
|
||||
render.Render(w, r, ErrRender(err))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func ListMessages(w http.ResponseWriter, r *http.Request) {
|
||||
slog.Debug("message: entering ListMessages handler")
|
||||
dbMessages, err := dbGetAllMessages()
|
||||
if err != nil {
|
||||
slog.Error("message: failed to fetch messages", "error", err)
|
||||
render.Render(w, r, ErrRender(err))
|
||||
return
|
||||
}
|
||||
|
||||
slog.Debug("message: successfully fetched messages", "count", len(dbMessages))
|
||||
if err := render.RenderList(w, r, NewMessageListResponse(dbMessages)); err != nil {
|
||||
slog.Error("message: failed to render message list response", "error", err)
|
||||
render.Render(w, r, ErrRender(err))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func newMessageID() string {
|
||||
return "msg_" + uuid.New().String()
|
||||
func newMessageID() uuid.UUID {
|
||||
return uuid.New()
|
||||
}
|
||||
|
||||
func NewMessage(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Invalid request method", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
var msg Message
|
||||
err := json.NewDecoder(r.Body).Decode(&msg)
|
||||
slog.Debug("message: entering NewMessage handler")
|
||||
err := r.ParseMultipartForm(64 << 10)
|
||||
if err != nil {
|
||||
http.Error(w, "Invalid JSON", http.StatusBadRequest)
|
||||
slog.Error("message: failed to parse multipart form", "error", err)
|
||||
http.Error(w, "Unable to parse form", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
msg.ID = newMessageID()
|
||||
var user = r.Context().Value(userKey{}).(*User)
|
||||
body := r.FormValue("body")
|
||||
|
||||
msg.Timestamp = time.Now()
|
||||
if body == "" {
|
||||
slog.Error("message: message body is empty")
|
||||
http.Error(w, "Invalid body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
msg := Message{
|
||||
ID: newMessageID(),
|
||||
UserID: user.ID,
|
||||
Body: body,
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
|
||||
slog.Debug("message: creating new message", "messageID", msg.ID)
|
||||
err = dbAddMessage(&msg)
|
||||
if err != nil {
|
||||
slog.Error("message: failed to add new message", "messageID", msg.ID, "error", err)
|
||||
render.Render(w, r, ErrRender(err))
|
||||
return
|
||||
}
|
||||
|
||||
slog.Debug("message: successfully created new message", "messageID", msg.ID)
|
||||
render.Render(w, r, NewMessageResponse(&msg))
|
||||
}
|
||||
|
||||
type messageKey struct{}
|
||||
|
||||
type Message struct {
|
||||
ID string `json:"id"`
|
||||
UserID int64 `json:"user_id"`
|
||||
Body string `json:"body"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
ID uuid.UUID `json:"id"`
|
||||
UserID uuid.UUID `json:"user_id"`
|
||||
Body string `json:"body"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Edited *time.Time `json:"edited"`
|
||||
}
|
||||
|
||||
type MessageRequest struct {
|
||||
@@ -115,19 +209,27 @@ type MessageResponse struct {
|
||||
|
||||
func (m MessageResponse) MarshalJSON() ([]byte, error) {
|
||||
type OrderedMessageResponse struct {
|
||||
ID string `json:"id"`
|
||||
UserID int64 `json:"user_id"`
|
||||
ID uuid.UUID `json:"id"`
|
||||
UserID uuid.UUID `json:"user_id"`
|
||||
Body string `json:"body"`
|
||||
Timestamp string `json:"timestamp"`
|
||||
Edited *string `json:"edited,omitempty"` // Use a pointer to allow null values
|
||||
User *UserPayload `json:"user,omitempty"`
|
||||
Elapsed int64 `json:"elapsed"`
|
||||
}
|
||||
|
||||
var edited *string
|
||||
if m.Message.Edited != nil { // Check if Edited is not the zero value
|
||||
editedStr := m.Message.Edited.Format(time.RFC3339)
|
||||
edited = &editedStr
|
||||
}
|
||||
|
||||
ordered := OrderedMessageResponse{
|
||||
ID: m.Message.ID,
|
||||
UserID: m.Message.UserID,
|
||||
Body: m.Message.Body,
|
||||
Timestamp: m.Message.Timestamp.Format(time.RFC3339),
|
||||
Edited: edited, // Null if Edited is zero
|
||||
User: m.User,
|
||||
Elapsed: m.Elapsed,
|
||||
}
|
||||
|
@@ -10,7 +10,7 @@ func NewMessageResponse(message *Message) *MessageResponse {
|
||||
resp := &MessageResponse{Message: message}
|
||||
|
||||
if resp.User == nil {
|
||||
if user, _ := dbGetUser(resp.UserID); user != nil {
|
||||
if user, _ := dbGetUser(resp.UserID.String()); user != nil {
|
||||
resp.User = NewUserPayloadResponse(user)
|
||||
}
|
||||
}
|
||||
@@ -30,6 +30,14 @@ func NewMessageListResponse(messages []*Message) []render.Renderer {
|
||||
return list
|
||||
}
|
||||
|
||||
func NewUserListResponse(users []*User) []render.Renderer {
|
||||
list := []render.Renderer{}
|
||||
for _, user := range users {
|
||||
list = append(list, NewUserPayloadResponse(user))
|
||||
}
|
||||
return list
|
||||
}
|
||||
|
||||
func NewUserPayloadResponse(user *User) *UserPayload {
|
||||
return &UserPayload{User: user}
|
||||
}
|
||||
|
161
api/user.go
161
api/user.go
@@ -1,8 +1,165 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/render"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func UserCtx(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
slog.Debug("user: entering UserCtx middleware")
|
||||
var user *User
|
||||
var err error
|
||||
|
||||
if userID := chi.URLParam(r, "userID"); userID != "" {
|
||||
slog.Debug("user: fetching user", "userID", userID)
|
||||
user, err = dbGetUser(userID)
|
||||
} else {
|
||||
slog.Error("user: userID not found in URL parameters")
|
||||
render.Render(w, r, ErrNotFound)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
slog.Error("user: failed to fetch user", "userID", user.ID, "error", err)
|
||||
render.Render(w, r, ErrNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
slog.Debug("user: successfully fetched user", "userID", user.ID)
|
||||
ctx := context.WithValue(r.Context(), userKey{}, user)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
func Whoami(w http.ResponseWriter, r *http.Request) {
|
||||
slog.Debug("user: entering Whoami handler")
|
||||
user, ok := r.Context().Value(userKey{}).(*User)
|
||||
if !ok || user == nil {
|
||||
slog.Debug("user: anonymous user")
|
||||
w.Write([]byte("anonymous"))
|
||||
return
|
||||
}
|
||||
|
||||
slog.Debug("user: returning user name", "userID", user.ID, "userName", user.Name)
|
||||
w.Write([]byte(user.Name))
|
||||
}
|
||||
|
||||
func LoginCtx(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
slog.Debug("user: entering LoginCtx middleware")
|
||||
userID, ok := r.Context().Value(userIDKey).(uuid.UUID)
|
||||
if !ok || userID == uuid.Nil {
|
||||
slog.Debug("user: no user ID provided, assuming anonymous user")
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
slog.Debug("user: fetching user by user ID", "userID", userID)
|
||||
user, err := dbGetUser(userID.String())
|
||||
if err != nil {
|
||||
slog.Error("user: failed to fetch user by user ID", "userID", userID, "error", err)
|
||||
render.Render(w, r, ErrNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
slog.Debug("user: successfully fetched user", "userID", user.ID, "username", user.Name)
|
||||
ctx := context.WithValue(r.Context(), userKey{}, user)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
func GetUser(w http.ResponseWriter, r *http.Request) {
|
||||
slog.Debug("user: entering GetUser handler")
|
||||
user, ok := r.Context().Value(userKey{}).(*User)
|
||||
if !ok || user == nil {
|
||||
slog.Error("user: user not found in context")
|
||||
render.Render(w, r, ErrNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
slog.Debug("user: rendering user", "userID", user.ID, "userName", user.Name)
|
||||
if err := render.Render(w, r, NewUserPayloadResponse(user)); err != nil {
|
||||
slog.Error("user: failed to render user response", "userID", user.ID, "error", err)
|
||||
render.Render(w, r, ErrRender(err))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func ListUsers(w http.ResponseWriter, r *http.Request) {
|
||||
slog.Debug("user: entering ListUsers handler")
|
||||
dbUsers, err := dbGetAllUsers()
|
||||
if err != nil {
|
||||
slog.Error("user: failed to fetch users", "error", err)
|
||||
render.Render(w, r, ErrRender(err))
|
||||
return
|
||||
}
|
||||
|
||||
slog.Debug("user: successfully fetched users", "count", len(dbUsers))
|
||||
if err := render.RenderList(w, r, NewUserListResponse(dbUsers)); err != nil {
|
||||
slog.Error("user: failed to render user list response", "error", err)
|
||||
render.Render(w, r, ErrRender(err))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func newUserID() uuid.UUID {
|
||||
return uuid.New()
|
||||
}
|
||||
|
||||
func NewUser(w http.ResponseWriter, r *http.Request) {
|
||||
slog.Debug("user: entering NewUser handler")
|
||||
err := r.ParseMultipartForm(64 << 10)
|
||||
if err != nil {
|
||||
slog.Error("user: failed to parse multipart form", "error", err)
|
||||
http.Error(w, "Unable to parse form", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
newUserName := r.FormValue("name")
|
||||
password := r.FormValue("password")
|
||||
if newUserName == "" || password == "" {
|
||||
slog.Error("user: username or password is empty")
|
||||
http.Error(w, "Username and password cannot be empty", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
slog.Debug("user: hashing password for new user", "userName", newUserName)
|
||||
hashedPassword, err := hashPassword(password)
|
||||
if err != nil {
|
||||
slog.Error("user: failed to hash password", "error", err)
|
||||
http.Error(w, "Unable to hash password", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
newUser := User{
|
||||
ID: newUserID(),
|
||||
Name: newUserName,
|
||||
Password: hashedPassword,
|
||||
}
|
||||
|
||||
slog.Debug("user: adding new user to database", "userID", newUser.ID, "userName", newUser.Name)
|
||||
err = dbAddUser(&newUser)
|
||||
if err != nil {
|
||||
slog.Error("user: failed to add new user", "userID", newUser.ID, "userName", newUser.Name, "error", err)
|
||||
render.Render(w, r, ErrRender(err))
|
||||
return
|
||||
}
|
||||
|
||||
slog.Debug("user: successfully added new user", "userID", newUser.ID, "userName", newUser.Name)
|
||||
render.Render(w, r, NewUserPayloadResponse(&newUser))
|
||||
}
|
||||
|
||||
type userKey struct{}
|
||||
|
||||
type User struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
ID uuid.UUID `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Password string `json:"-"`
|
||||
}
|
||||
|
||||
type UserPayload struct {
|
||||
|
@@ -1,95 +0,0 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
)
|
||||
|
||||
func ExecDB(db_name string) map[string]interface{} {
|
||||
var result map[string]interface{}
|
||||
|
||||
if db_name == "users" {
|
||||
users_db, err := os.Open("./test_data/users.json")
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return nil
|
||||
}
|
||||
fmt.Println("Successfully opened Users DB")
|
||||
defer users_db.Close()
|
||||
|
||||
byteValue, _ := io.ReadAll(users_db)
|
||||
var users []interface{}
|
||||
json.Unmarshal(byteValue, &users)
|
||||
result = map[string]interface{}{"users": users}
|
||||
|
||||
} else if db_name == "messages" {
|
||||
messages_db, err := os.Open("./test_data/messages.json")
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return nil
|
||||
}
|
||||
fmt.Println("Successfully opened Messages DB")
|
||||
defer messages_db.Close()
|
||||
|
||||
byteValue, _ := io.ReadAll(messages_db)
|
||||
var messages []interface{}
|
||||
json.Unmarshal(byteValue, &messages)
|
||||
result = map[string]interface{}{"messages": messages}
|
||||
|
||||
} else {
|
||||
fmt.Println("Invalid DB name")
|
||||
return nil
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func WriteDB(db_name string, data interface{}) error {
|
||||
var filePath string
|
||||
|
||||
switch db_name {
|
||||
case "users":
|
||||
filePath = "./test_data/users.json"
|
||||
case "messages":
|
||||
filePath = "./test_data/messages.json"
|
||||
default:
|
||||
return fmt.Errorf("invalid database name: %s", db_name)
|
||||
}
|
||||
|
||||
jsonData, err := json.MarshalIndent(data, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("error marshaling data to JSON: %v", err)
|
||||
}
|
||||
|
||||
err = os.WriteFile(filePath, jsonData, 0644)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error writing to file: %v", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Successfully wrote to %s DB\n", db_name)
|
||||
return nil
|
||||
}
|
||||
|
||||
func AddUser(user map[string]interface{}) error {
|
||||
currentData := ExecDB("users")
|
||||
if currentData == nil {
|
||||
return fmt.Errorf("error reading users database")
|
||||
}
|
||||
|
||||
users := currentData["users"].([]interface{})
|
||||
users = append(users, user)
|
||||
return WriteDB("users", users)
|
||||
}
|
||||
|
||||
func AddMessage(message map[string]interface{}) error {
|
||||
currentData := ExecDB("messages")
|
||||
if currentData == nil {
|
||||
return fmt.Errorf("error reading messages database")
|
||||
}
|
||||
|
||||
messages := currentData["messages"].([]interface{})
|
||||
messages = append(messages, message)
|
||||
return WriteDB("messages", messages)
|
||||
}
|
30
db/scylla.go
Normal file
30
db/scylla.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"os"
|
||||
|
||||
"github.com/gocql/gocql"
|
||||
)
|
||||
|
||||
var Session *gocql.Session
|
||||
|
||||
func InitScyllaDB() {
|
||||
cluster := gocql.NewCluster(os.Getenv("SCYLLA_CLUSTER")) // Replace with your ScyllaDB cluster IPs
|
||||
cluster.Keyspace = os.Getenv("SCYLLA_KEYSPACE") // Replace with your keyspace
|
||||
cluster.Consistency = gocql.Quorum
|
||||
|
||||
session, err := cluster.CreateSession()
|
||||
if err != nil {
|
||||
slog.Error("Failed to connect to ScyllaDB", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
Session = session
|
||||
slog.Info("Connected to ScyllaDB")
|
||||
}
|
||||
|
||||
func CloseScyllaDB() {
|
||||
if Session != nil {
|
||||
Session.Close()
|
||||
}
|
||||
}
|
137
docs/routes.md
137
docs/routes.md
@@ -7,10 +7,24 @@
|
||||
- [Logger]()
|
||||
- [Recoverer]()
|
||||
- [URLFormat]()
|
||||
- [git.dubyatp.xyz/chat-api-server/api.Start.SetContentType.func5]()
|
||||
- [git.dubyatp.xyz/chat-api-server/api.Start.SetContentType.func8]()
|
||||
- **/**
|
||||
- _GET_
|
||||
- [Start.func1]()
|
||||
- _GET_
|
||||
- [Start.func1]()
|
||||
|
||||
</details>
|
||||
<details>
|
||||
<summary>`/login`</summary>
|
||||
|
||||
- [RequestID]()
|
||||
- [Logger]()
|
||||
- [Recoverer]()
|
||||
- [URLFormat]()
|
||||
- [git.dubyatp.xyz/chat-api-server/api.Start.SetContentType.func8]()
|
||||
- **/login**
|
||||
- **/**
|
||||
- _POST_
|
||||
- [Login]()
|
||||
|
||||
</details>
|
||||
<details>
|
||||
@@ -20,11 +34,27 @@
|
||||
- [Logger]()
|
||||
- [Recoverer]()
|
||||
- [URLFormat]()
|
||||
- [git.dubyatp.xyz/chat-api-server/api.Start.SetContentType.func5]()
|
||||
- [git.dubyatp.xyz/chat-api-server/api.Start.SetContentType.func8]()
|
||||
- **/messages**
|
||||
- **/**
|
||||
- _GET_
|
||||
- [ListMessages]()
|
||||
- [SessionAuthMiddleware]()
|
||||
- **/**
|
||||
- _GET_
|
||||
- [ListMessages]()
|
||||
|
||||
</details>
|
||||
<details>
|
||||
<summary>`/messages/new`</summary>
|
||||
|
||||
- [RequestID]()
|
||||
- [Logger]()
|
||||
- [Recoverer]()
|
||||
- [URLFormat]()
|
||||
- [git.dubyatp.xyz/chat-api-server/api.Start.SetContentType.func8]()
|
||||
- **/messages**
|
||||
- [SessionAuthMiddleware]()
|
||||
- **/new**
|
||||
- _POST_
|
||||
- [NewMessage]()
|
||||
|
||||
</details>
|
||||
<details>
|
||||
@@ -34,13 +64,33 @@
|
||||
- [Logger]()
|
||||
- [Recoverer]()
|
||||
- [URLFormat]()
|
||||
- [git.dubyatp.xyz/chat-api-server/api.Start.SetContentType.func5]()
|
||||
- [git.dubyatp.xyz/chat-api-server/api.Start.SetContentType.func8]()
|
||||
- **/messages**
|
||||
- **/{messageID}**
|
||||
- [MessageCtx]()
|
||||
- **/**
|
||||
- _GET_
|
||||
- [GetMessage]()
|
||||
- [SessionAuthMiddleware]()
|
||||
- **/{messageID}**
|
||||
- [MessageCtx]()
|
||||
- **/**
|
||||
- _GET_
|
||||
- [GetMessage]()
|
||||
- _DELETE_
|
||||
- [DeleteMessage]()
|
||||
|
||||
</details>
|
||||
<details>
|
||||
<summary>`/messages/{messageID}/edit`</summary>
|
||||
|
||||
- [RequestID]()
|
||||
- [Logger]()
|
||||
- [Recoverer]()
|
||||
- [URLFormat]()
|
||||
- [git.dubyatp.xyz/chat-api-server/api.Start.SetContentType.func8]()
|
||||
- **/messages**
|
||||
- [SessionAuthMiddleware]()
|
||||
- **/{messageID}**
|
||||
- [MessageCtx]()
|
||||
- **/edit**
|
||||
- _POST_
|
||||
- [EditMessage]()
|
||||
|
||||
</details>
|
||||
<details>
|
||||
@@ -50,10 +100,10 @@
|
||||
- [Logger]()
|
||||
- [Recoverer]()
|
||||
- [URLFormat]()
|
||||
- [git.dubyatp.xyz/chat-api-server/api.Start.SetContentType.func5]()
|
||||
- [git.dubyatp.xyz/chat-api-server/api.Start.SetContentType.func8]()
|
||||
- **/panic**
|
||||
- _GET_
|
||||
- [Start.func3]()
|
||||
- _GET_
|
||||
- [Start.func3]()
|
||||
|
||||
</details>
|
||||
<details>
|
||||
@@ -63,11 +113,58 @@
|
||||
- [Logger]()
|
||||
- [Recoverer]()
|
||||
- [URLFormat]()
|
||||
- [git.dubyatp.xyz/chat-api-server/api.Start.SetContentType.func5]()
|
||||
- [git.dubyatp.xyz/chat-api-server/api.Start.SetContentType.func8]()
|
||||
- **/ping**
|
||||
- _GET_
|
||||
- [Start.func2]()
|
||||
- _GET_
|
||||
- [Start.func2]()
|
||||
|
||||
</details>
|
||||
<details>
|
||||
<summary>`/register`</summary>
|
||||
|
||||
- [RequestID]()
|
||||
- [Logger]()
|
||||
- [Recoverer]()
|
||||
- [URLFormat]()
|
||||
- [git.dubyatp.xyz/chat-api-server/api.Start.SetContentType.func8]()
|
||||
- **/register**
|
||||
- **/**
|
||||
- _POST_
|
||||
- [NewUser]()
|
||||
|
||||
</details>
|
||||
<details>
|
||||
<summary>`/users`</summary>
|
||||
|
||||
- [RequestID]()
|
||||
- [Logger]()
|
||||
- [Recoverer]()
|
||||
- [URLFormat]()
|
||||
- [git.dubyatp.xyz/chat-api-server/api.Start.SetContentType.func8]()
|
||||
- **/users**
|
||||
- [SessionAuthMiddleware]()
|
||||
- **/**
|
||||
- _GET_
|
||||
- [ListUsers]()
|
||||
|
||||
</details>
|
||||
<details>
|
||||
<summary>`/users/{userID}`</summary>
|
||||
|
||||
- [RequestID]()
|
||||
- [Logger]()
|
||||
- [Recoverer]()
|
||||
- [URLFormat]()
|
||||
- [git.dubyatp.xyz/chat-api-server/api.Start.SetContentType.func8]()
|
||||
- **/users**
|
||||
- [SessionAuthMiddleware]()
|
||||
- **/{userID}**
|
||||
- [UserCtx]()
|
||||
- **/**
|
||||
- _GET_
|
||||
- [GetUser]()
|
||||
|
||||
</details>
|
||||
|
||||
Total # of routes: 5
|
||||
Total # of routes: 11
|
||||
|
||||
|
4
example_client/svelte-client/.gitignore
vendored
Normal file
4
example_client/svelte-client/.gitignore
vendored
Normal file
@@ -0,0 +1,4 @@
|
||||
/node_modules/
|
||||
/public/build/
|
||||
|
||||
.DS_Store
|
107
example_client/svelte-client/README.md
Normal file
107
example_client/svelte-client/README.md
Normal file
@@ -0,0 +1,107 @@
|
||||
# This repo is no longer maintained. Consider using `npm init vite` and selecting the `svelte` option or — if you want a full-fledged app framework — use [SvelteKit](https://kit.svelte.dev), the official application framework for Svelte.
|
||||
|
||||
---
|
||||
|
||||
# svelte app
|
||||
|
||||
This is a project template for [Svelte](https://svelte.dev) apps. It lives at https://github.com/sveltejs/template.
|
||||
|
||||
To create a new project based on this template using [degit](https://github.com/Rich-Harris/degit):
|
||||
|
||||
```bash
|
||||
npx degit sveltejs/template svelte-app
|
||||
cd svelte-app
|
||||
```
|
||||
|
||||
*Note that you will need to have [Node.js](https://nodejs.org) installed.*
|
||||
|
||||
|
||||
## Get started
|
||||
|
||||
Install the dependencies...
|
||||
|
||||
```bash
|
||||
cd svelte-app
|
||||
npm install
|
||||
```
|
||||
|
||||
...then start [Rollup](https://rollupjs.org):
|
||||
|
||||
```bash
|
||||
npm run dev
|
||||
```
|
||||
|
||||
Navigate to [localhost:8080](http://localhost:8080). You should see your app running. Edit a component file in `src`, save it, and reload the page to see your changes.
|
||||
|
||||
By default, the server will only respond to requests from localhost. To allow connections from other computers, edit the `sirv` commands in package.json to include the option `--host 0.0.0.0`.
|
||||
|
||||
If you're using [Visual Studio Code](https://code.visualstudio.com/) we recommend installing the official extension [Svelte for VS Code](https://marketplace.visualstudio.com/items?itemName=svelte.svelte-vscode). If you are using other editors you may need to install a plugin in order to get syntax highlighting and intellisense.
|
||||
|
||||
## Building and running in production mode
|
||||
|
||||
To create an optimised version of the app:
|
||||
|
||||
```bash
|
||||
npm run build
|
||||
```
|
||||
|
||||
You can run the newly built app with `npm run start`. This uses [sirv](https://github.com/lukeed/sirv), which is included in your package.json's `dependencies` so that the app will work when you deploy to platforms like [Heroku](https://heroku.com).
|
||||
|
||||
|
||||
## Single-page app mode
|
||||
|
||||
By default, sirv will only respond to requests that match files in `public`. This is to maximise compatibility with static fileservers, allowing you to deploy your app anywhere.
|
||||
|
||||
If you're building a single-page app (SPA) with multiple routes, sirv needs to be able to respond to requests for *any* path. You can make it so by editing the `"start"` command in package.json:
|
||||
|
||||
```js
|
||||
"start": "sirv public --single"
|
||||
```
|
||||
|
||||
## Using TypeScript
|
||||
|
||||
This template comes with a script to set up a TypeScript development environment, you can run it immediately after cloning the template with:
|
||||
|
||||
```bash
|
||||
node scripts/setupTypeScript.js
|
||||
```
|
||||
|
||||
Or remove the script via:
|
||||
|
||||
```bash
|
||||
rm scripts/setupTypeScript.js
|
||||
```
|
||||
|
||||
If you want to use `baseUrl` or `path` aliases within your `tsconfig`, you need to set up `@rollup/plugin-alias` to tell Rollup to resolve the aliases. For more info, see [this StackOverflow question](https://stackoverflow.com/questions/63427935/setup-tsconfig-path-in-svelte).
|
||||
|
||||
## Deploying to the web
|
||||
|
||||
### With [Vercel](https://vercel.com)
|
||||
|
||||
Install `vercel` if you haven't already:
|
||||
|
||||
```bash
|
||||
npm install -g vercel
|
||||
```
|
||||
|
||||
Then, from within your project folder:
|
||||
|
||||
```bash
|
||||
cd public
|
||||
vercel deploy --name my-project
|
||||
```
|
||||
|
||||
### With [surge](https://surge.sh/)
|
||||
|
||||
Install `surge` if you haven't already:
|
||||
|
||||
```bash
|
||||
npm install -g surge
|
||||
```
|
||||
|
||||
Then, from within your project folder:
|
||||
|
||||
```bash
|
||||
npm run build
|
||||
surge public my-project.surge.sh
|
||||
```
|
1349
example_client/svelte-client/package-lock.json
generated
Normal file
1349
example_client/svelte-client/package-lock.json
generated
Normal file
File diff suppressed because it is too large
Load Diff
26
example_client/svelte-client/package.json
Normal file
26
example_client/svelte-client/package.json
Normal file
@@ -0,0 +1,26 @@
|
||||
{
|
||||
"name": "svelte-app",
|
||||
"version": "1.0.0",
|
||||
"private": true,
|
||||
"type": "module",
|
||||
"scripts": {
|
||||
"build": "rollup -c",
|
||||
"dev": "rollup -c -w",
|
||||
"start": "sirv public --single"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@rollup/plugin-commonjs": "^24.0.0",
|
||||
"@rollup/plugin-node-resolve": "^15.0.0",
|
||||
"@rollup/plugin-terser": "^0.4.0",
|
||||
"rollup": "^3.15.0",
|
||||
"rollup-plugin-css-only": "^4.3.0",
|
||||
"rollup-plugin-livereload": "^2.0.0",
|
||||
"rollup-plugin-svelte": "^7.1.2",
|
||||
"svelte": "^3.55.0"
|
||||
},
|
||||
"dependencies": {
|
||||
"axios": "^1.9.0",
|
||||
"sirv-cli": "^2.0.0",
|
||||
"svelte-spa-router": "^4.0.1"
|
||||
}
|
||||
}
|
BIN
example_client/svelte-client/public/favicon.png
Normal file
BIN
example_client/svelte-client/public/favicon.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 3.1 KiB |
63
example_client/svelte-client/public/global.css
Normal file
63
example_client/svelte-client/public/global.css
Normal file
@@ -0,0 +1,63 @@
|
||||
html, body {
|
||||
position: relative;
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
}
|
||||
|
||||
body {
|
||||
color: #333;
|
||||
margin: 0;
|
||||
padding: 8px;
|
||||
box-sizing: border-box;
|
||||
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Oxygen-Sans, Ubuntu, Cantarell, "Helvetica Neue", sans-serif;
|
||||
}
|
||||
|
||||
a {
|
||||
color: rgb(0,100,200);
|
||||
text-decoration: none;
|
||||
}
|
||||
|
||||
a:hover {
|
||||
text-decoration: underline;
|
||||
}
|
||||
|
||||
a:visited {
|
||||
color: rgb(0,80,160);
|
||||
}
|
||||
|
||||
label {
|
||||
display: block;
|
||||
}
|
||||
|
||||
input, button, select, textarea {
|
||||
font-family: inherit;
|
||||
font-size: inherit;
|
||||
-webkit-padding: 0.4em 0;
|
||||
padding: 0.4em;
|
||||
margin: 0 0 0.5em 0;
|
||||
box-sizing: border-box;
|
||||
border: 1px solid #ccc;
|
||||
border-radius: 2px;
|
||||
}
|
||||
|
||||
input:disabled {
|
||||
color: #ccc;
|
||||
}
|
||||
|
||||
button {
|
||||
color: #333;
|
||||
background-color: #f4f4f4;
|
||||
outline: none;
|
||||
}
|
||||
|
||||
button:disabled {
|
||||
color: #999;
|
||||
}
|
||||
|
||||
button:not(:disabled):active {
|
||||
background-color: #ddd;
|
||||
}
|
||||
|
||||
button:focus {
|
||||
border-color: #666;
|
||||
}
|
18
example_client/svelte-client/public/index.html
Normal file
18
example_client/svelte-client/public/index.html
Normal file
@@ -0,0 +1,18 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset='utf-8'>
|
||||
<meta name='viewport' content='width=device-width,initial-scale=1'>
|
||||
|
||||
<title>Svelte app</title>
|
||||
|
||||
<link rel='icon' type='image/png' href='/favicon.png'>
|
||||
<link rel='stylesheet' href='/global.css'>
|
||||
<link rel='stylesheet' href='/build/bundle.css'>
|
||||
|
||||
<script defer src='/build/bundle.js'></script>
|
||||
</head>
|
||||
|
||||
<body>
|
||||
</body>
|
||||
</html>
|
78
example_client/svelte-client/rollup.config.js
Normal file
78
example_client/svelte-client/rollup.config.js
Normal file
@@ -0,0 +1,78 @@
|
||||
import { spawn } from 'child_process';
|
||||
import svelte from 'rollup-plugin-svelte';
|
||||
import commonjs from '@rollup/plugin-commonjs';
|
||||
import terser from '@rollup/plugin-terser';
|
||||
import resolve from '@rollup/plugin-node-resolve';
|
||||
import livereload from 'rollup-plugin-livereload';
|
||||
import css from 'rollup-plugin-css-only';
|
||||
|
||||
const production = !process.env.ROLLUP_WATCH;
|
||||
|
||||
// Starts `npm run start -- --dev` once the first bundle has been written,
// and kills that child process when the Rollup process exits.
// Only wired into the plugin list in watch (dev) mode.
function serve() {
	let server = null;

	const stopServer = () => {
		if (server) server.kill(0);
	};

	return {
		writeBundle() {
			if (server) return;
			server = spawn('npm', ['run', 'start', '--', '--dev'], {
				stdio: ['ignore', 'inherit', 'inherit'],
				shell: true
			});

			process.on('SIGTERM', stopServer);
			process.on('exit', stopServer);
		}
	};
}
||||
|
||||
// Rollup build configuration: bundle src/main.js into a single IIFE at
// public/build/bundle.js, with dev server + livereload in watch mode and
// minification in production builds.
export default {
	input: 'src/main.js',
	output: {
		sourcemap: true,
		format: 'iife',
		name: 'app',
		file: 'public/build/bundle.js'
	},
	plugins: [
		// Compile .svelte components; run-time dev checks only outside production.
		svelte({
			compilerOptions: {
				dev: !production
			}
		}),
		// Extract component CSS into its own file — better for performance.
		css({ output: 'bundle.css' }),
		// Resolve bare imports from node_modules and convert CommonJS modules;
		// see https://github.com/rollup/plugins/tree/master/packages/commonjs
		resolve({
			browser: true,
			dedupe: ['svelte'],
			exportConditions: ['svelte']
		}),
		commonjs(),
		// Dev only: launch `npm run start` once the bundle has been generated.
		!production && serve(),
		// Dev only: refresh the browser when anything under public/ changes.
		!production && livereload('public'),
		// Production only (npm run build): minify the bundle.
		production && terser()
	],
	watch: {
		clearScreen: false
	}
};
|
134
example_client/svelte-client/scripts/setupTypeScript.js
Normal file
134
example_client/svelte-client/scripts/setupTypeScript.js
Normal file
@@ -0,0 +1,134 @@
|
||||
// @ts-check
|
||||
|
||||
/** This script modifies the project to support TS code in .svelte files like:
|
||||
|
||||
<script lang="ts">
|
||||
export let name: string;
|
||||
</script>
|
||||
|
||||
As well as validating the code for CI.
|
||||
*/
|
||||
|
||||
/** To work on this script:
|
||||
rm -rf test-template template && git clone sveltejs/template test-template && node scripts/setupTypeScript.js test-template
|
||||
*/
|
||||
|
||||
import fs from "fs"
|
||||
import path from "path"
|
||||
import { argv } from "process"
|
||||
import url from 'url';
|
||||
|
||||
const __filename = url.fileURLToPath(import.meta.url);
|
||||
const __dirname = url.fileURLToPath(new URL('.', import.meta.url));
|
||||
const projectRoot = argv[2] || path.join(__dirname, "..")
|
||||
|
||||
// Add deps to pkg.json
|
||||
const packageJSON = JSON.parse(fs.readFileSync(path.join(projectRoot, "package.json"), "utf8"))
|
||||
packageJSON.devDependencies = Object.assign(packageJSON.devDependencies, {
|
||||
"svelte-check": "^3.0.0",
|
||||
"svelte-preprocess": "^5.0.0",
|
||||
"@rollup/plugin-typescript": "^11.0.0",
|
||||
"typescript": "^4.9.0",
|
||||
"tslib": "^2.5.0",
|
||||
"@tsconfig/svelte": "^3.0.0"
|
||||
})
|
||||
|
||||
// Add script for checking
|
||||
packageJSON.scripts = Object.assign(packageJSON.scripts, {
|
||||
"check": "svelte-check"
|
||||
})
|
||||
|
||||
// Write the package JSON
|
||||
fs.writeFileSync(path.join(projectRoot, "package.json"), JSON.stringify(packageJSON, null, " "))
|
||||
|
||||
// mv src/main.js to main.ts - note, we need to edit rollup.config.js for this too
|
||||
const beforeMainJSPath = path.join(projectRoot, "src", "main.js")
|
||||
const afterMainTSPath = path.join(projectRoot, "src", "main.ts")
|
||||
fs.renameSync(beforeMainJSPath, afterMainTSPath)
|
||||
|
||||
// Switch the app.svelte file to use TS
|
||||
const appSveltePath = path.join(projectRoot, "src", "App.svelte")
|
||||
let appFile = fs.readFileSync(appSveltePath, "utf8")
|
||||
appFile = appFile.replace("<script>", '<script lang="ts">')
|
||||
appFile = appFile.replace("export let name;", 'export let name: string;')
|
||||
fs.writeFileSync(appSveltePath, appFile)
|
||||
|
||||
// Edit rollup config
|
||||
const rollupConfigPath = path.join(projectRoot, "rollup.config.js")
|
||||
let rollupConfig = fs.readFileSync(rollupConfigPath, "utf8")
|
||||
|
||||
// Edit imports
|
||||
rollupConfig = rollupConfig.replace(`'rollup-plugin-css-only';`, `'rollup-plugin-css-only';
|
||||
import sveltePreprocess from 'svelte-preprocess';
|
||||
import typescript from '@rollup/plugin-typescript';`)
|
||||
|
||||
// Replace name of entry point
|
||||
rollupConfig = rollupConfig.replace(`'src/main.js'`, `'src/main.ts'`)
|
||||
|
||||
// Add preprocessor
|
||||
rollupConfig = rollupConfig.replace(
|
||||
'compilerOptions:',
|
||||
'preprocess: sveltePreprocess({ sourceMap: !production }),\n\t\t\tcompilerOptions:'
|
||||
);
|
||||
|
||||
// Add TypeScript
|
||||
rollupConfig = rollupConfig.replace(
|
||||
'commonjs(),',
|
||||
'commonjs(),\n\t\ttypescript({\n\t\t\tsourceMap: !production,\n\t\t\tinlineSources: !production\n\t\t}),'
|
||||
);
|
||||
fs.writeFileSync(rollupConfigPath, rollupConfig)
|
||||
|
||||
// Add svelte.config.js
|
||||
const tsconfig = `{
|
||||
"extends": "@tsconfig/svelte/tsconfig.json",
|
||||
|
||||
"include": ["src/**/*"],
|
||||
"exclude": ["node_modules/*", "__sapper__/*", "public/*"]
|
||||
}`
|
||||
const tsconfigPath = path.join(projectRoot, "tsconfig.json")
|
||||
fs.writeFileSync(tsconfigPath, tsconfig)
|
||||
|
||||
// Add TSConfig
|
||||
const svelteConfig = `import sveltePreprocess from 'svelte-preprocess';
|
||||
|
||||
export default {
|
||||
preprocess: sveltePreprocess()
|
||||
};
|
||||
`
|
||||
const svelteConfigPath = path.join(projectRoot, "svelte.config.js")
|
||||
fs.writeFileSync(svelteConfigPath, svelteConfig)
|
||||
|
||||
// Add global.d.ts
|
||||
const dtsPath = path.join(projectRoot, "src", "global.d.ts")
|
||||
fs.writeFileSync(dtsPath, `/// <reference types="svelte" />`)
|
||||
|
||||
// Delete this script, but not during testing
|
||||
if (!argv[2]) {
|
||||
// Remove the script
|
||||
fs.unlinkSync(path.join(__filename))
|
||||
|
||||
// Check for Mac's DS_store file, and if it's the only one left remove it
|
||||
const remainingFiles = fs.readdirSync(path.join(__dirname))
|
||||
if (remainingFiles.length === 1 && remainingFiles[0] === '.DS_store') {
|
||||
fs.unlinkSync(path.join(__dirname, '.DS_store'))
|
||||
}
|
||||
|
||||
// Check if the scripts folder is empty
|
||||
if (fs.readdirSync(path.join(__dirname)).length === 0) {
|
||||
// Remove the scripts folder
|
||||
fs.rmdirSync(path.join(__dirname))
|
||||
}
|
||||
}
|
||||
|
||||
// Adds the extension recommendation
|
||||
fs.mkdirSync(path.join(projectRoot, ".vscode"), { recursive: true })
|
||||
fs.writeFileSync(path.join(projectRoot, ".vscode", "extensions.json"), `{
|
||||
"recommendations": ["svelte.svelte-vscode"]
|
||||
}
|
||||
`)
|
||||
|
||||
console.log("Converted to TypeScript.")
|
||||
|
||||
if (fs.existsSync(path.join(projectRoot, "node_modules"))) {
|
||||
console.log("\nYou will need to re-run your dependency manager to get started.")
|
||||
}
|
6
example_client/svelte-client/src/App.svelte
Normal file
6
example_client/svelte-client/src/App.svelte
Normal file
@@ -0,0 +1,6 @@
|
||||
<script>
	// Root component: delegates all rendering to the hash-based SPA router.
	import Router from 'svelte-spa-router';
	import routes from './routes.js';
</script>

<Router {routes} />
36
example_client/svelte-client/src/api.js
Normal file
36
example_client/svelte-client/src/api.js
Normal file
@@ -0,0 +1,36 @@
|
||||
import axios from 'axios';
|
||||
|
||||
const API_BASE_URL = 'http://localhost:3000';
|
||||
|
||||
// POST the credentials to /login as a multipart form; cookies are sent and
// stored (withCredentials) so the session survives subsequent requests.
export const login = async (username, password) => {
	const form = new FormData();
	form.append('username', username);
	form.append('password', password);

	const { data } = await axios.post(`${API_BASE_URL}/login`, form, {
		headers: {
			'Content-Type': 'multipart/form-data',
		},
		withCredentials: true,
	});
	return data;
};
|
||||
|
||||
export const fetchMessages = async () => {
|
||||
const response = await axios.get(`${API_BASE_URL}/messages`, { withCredentials: true });
|
||||
return response.data;
|
||||
};
|
||||
|
||||
export const createMessage = async (body) => {
|
||||
const formData = new FormData();
|
||||
formData.append('body', body);
|
||||
|
||||
const response = await axios.post(`${API_BASE_URL}/messages/new`, formData, {
|
||||
headers: {
|
||||
'Content-Type': 'multipart/form-data',
|
||||
},
|
||||
withCredentials: true,
|
||||
});
|
||||
return response.data;
|
||||
};
|
10
example_client/svelte-client/src/main.js
Normal file
10
example_client/svelte-client/src/main.js
Normal file
@@ -0,0 +1,10 @@
|
||||
import App from './App.svelte';
|
||||
|
||||
const app = new App({
|
||||
target: document.body,
|
||||
props: {
|
||||
name: 'world'
|
||||
}
|
||||
});
|
||||
|
||||
export default app;
|
7
example_client/svelte-client/src/routes.js
Normal file
7
example_client/svelte-client/src/routes.js
Normal file
@@ -0,0 +1,7 @@
|
||||
import Login from './routes/Login.svelte';
|
||||
import Messages from './routes/Messages.svelte';
|
||||
|
||||
export default {
|
||||
'/': Login,
|
||||
'/messages': Messages,
|
||||
};
|
22
example_client/svelte-client/src/routes/Login.svelte
Normal file
22
example_client/svelte-client/src/routes/Login.svelte
Normal file
@@ -0,0 +1,22 @@
|
||||
<script>
|
||||
import { login } from '../api.js';
|
||||
import { push } from 'svelte-spa-router';
|
||||
let username = '';
|
||||
let password = '';
|
||||
let error = '';
|
||||
|
||||
const handleLogin = async () => {
|
||||
try {
|
||||
await login(username, password);
|
||||
push('/messages');
|
||||
} catch (err) {
|
||||
error = err;
|
||||
}
|
||||
};
|
||||
</script>
|
||||
|
||||
<h1>Login</h1>
|
||||
<input type="text" bind:value={username} placeholder="Username" />
|
||||
<input type="password" bind:value={password} placeholder="Password" />
|
||||
<button on:click={handleLogin}>Login</button>
|
||||
<p style="color: red">{error}</p>
|
26
example_client/svelte-client/src/routes/Messages.svelte
Normal file
26
example_client/svelte-client/src/routes/Messages.svelte
Normal file
@@ -0,0 +1,26 @@
|
||||
<script>
|
||||
import { fetchMessages, createMessage } from '../api.js';
|
||||
let messages = [];
|
||||
let newMessage = '';
|
||||
|
||||
const loadMessages = async () => {
|
||||
messages = await fetchMessages();
|
||||
};
|
||||
|
||||
const handleCreateMessage = async () => {
|
||||
await createMessage(newMessage);
|
||||
newMessage = '';
|
||||
await loadMessages();
|
||||
};
|
||||
|
||||
loadMessages();
|
||||
</script>
|
||||
|
||||
<h1>Messages</h1>
|
||||
<ul>
|
||||
{#each messages as message}
|
||||
<li>{message.user.name} - {message.body} - {message.timestamp}</li>
|
||||
{/each}
|
||||
</ul>
|
||||
<input type="text" bind:value={newMessage} placeholder="New message" />
|
||||
<button on:click={handleCreateMessage}>Send</button>
|
8
flake.lock
generated
8
flake.lock
generated
@@ -2,16 +2,16 @@
|
||||
"nodes": {
|
||||
"nixpkgs": {
|
||||
"locked": {
|
||||
"lastModified": 1735264675,
|
||||
"narHash": "sha256-MgdXpeX2GuJbtlBrH9EdsUeWl/yXEubyvxM1G+yO4Ak=",
|
||||
"lastModified": 1743827369,
|
||||
"narHash": "sha256-rpqepOZ8Eo1zg+KJeWoq1HAOgoMCDloqv5r2EAa9TSA=",
|
||||
"owner": "NixOS",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "d49da4c08359e3c39c4e27c74ac7ac9b70085966",
|
||||
"rev": "42a1c966be226125b48c384171c44c651c236c22",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"id": "nixpkgs",
|
||||
"ref": "nixos-24.11",
|
||||
"ref": "nixos-unstable",
|
||||
"type": "indirect"
|
||||
}
|
||||
},
|
||||
|
@@ -1,7 +1,9 @@
|
||||
{
|
||||
description = "Unnamed Chat Server API";
|
||||
|
||||
inputs.nixpkgs.url = "nixpkgs/nixos-24.11";
|
||||
inputs = {
|
||||
nixpkgs.url = "nixpkgs/nixos-unstable";
|
||||
};
|
||||
|
||||
outputs = { self, nixpkgs }:
|
||||
let
|
||||
@@ -33,10 +35,11 @@
|
||||
default = pkgs.mkShell {
|
||||
hardeningDisable = [ "fortify" ];
|
||||
buildInputs = [
|
||||
pkgs.bashInteractive
|
||||
pkgs.go
|
||||
pkgs.delve
|
||||
];
|
||||
};
|
||||
});
|
||||
};
|
||||
}
|
||||
}
|
||||
|
22
go.mod
22
go.mod
@@ -1,12 +1,30 @@
|
||||
module git.dubyatp.xyz/chat-api-server
|
||||
|
||||
go 1.23
|
||||
go 1.23.0
|
||||
|
||||
toolchain go1.23.3
|
||||
|
||||
require (
|
||||
github.com/go-chi/chi/v5 v5.2.0
|
||||
github.com/go-chi/docgen v1.3.0
|
||||
github.com/go-chi/render v1.0.3
|
||||
github.com/gocql/gocql v1.7.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/joho/godotenv v1.5.1
|
||||
)
|
||||
|
||||
require github.com/ajg/form v1.5.1 // indirect
|
||||
require github.com/go-chi/cors v1.2.1
|
||||
|
||||
require (
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2
|
||||
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed // indirect
|
||||
github.com/klauspost/compress v1.17.9 // indirect
|
||||
gopkg.in/inf.v0 v0.9.1 // indirect
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/ajg/form v1.5.1 // indirect
|
||||
golang.org/x/crypto v0.36.0
|
||||
)
|
||||
|
||||
replace github.com/gocql/gocql => github.com/scylladb/gocql v1.14.5
|
||||
|
56
go.sum
56
go.sum
@@ -1,14 +1,70 @@
|
||||
github.com/ajg/form v1.5.1 h1:t9c7v8JUKu/XxOGBU0yjNpaMloxGEJhUkqFRq0ibGeU=
|
||||
github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY=
|
||||
github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932 h1:mXoPYz/Ul5HYEDvkta6I8/rnYM5gSdSV2tJ6XbZuEtY=
|
||||
github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932/go.mod h1:NOuUCSz6Q9T7+igc/hlvDOUdtWKryOrtFyIVABv/p7k=
|
||||
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4YnC6+E63dPcxHo2sUxDIu8g3QgEJdRY=
|
||||
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/go-chi/chi/v5 v5.0.1/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8=
|
||||
github.com/go-chi/chi/v5 v5.2.0 h1:Aj1EtB0qR2Rdo2dG4O94RIU35w2lvQSj6BRA4+qwFL0=
|
||||
github.com/go-chi/chi/v5 v5.2.0/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8=
|
||||
github.com/go-chi/cors v1.2.1 h1:xEC8UT3Rlp2QuWNEr4Fs/c2EAGVKBwy/1vHx3bppil4=
|
||||
github.com/go-chi/cors v1.2.1/go.mod h1:sSbTewc+6wYHBBCW7ytsFSn836hqM7JxpglAy2Vzc58=
|
||||
github.com/go-chi/docgen v1.3.0 h1:dmDJ2I+EJfCTrxfgxQDwfR/OpZLTRFKe7EKB8v7yuxI=
|
||||
github.com/go-chi/docgen v1.3.0/go.mod h1:G9W0G551cs2BFMSn/cnGwX+JBHEloAgo17MBhyrnhPI=
|
||||
github.com/go-chi/render v1.0.1/go.mod h1:pq4Rr7HbnsdaeHagklXub+p6Wd16Af5l9koip1OvJns=
|
||||
github.com/go-chi/render v1.0.3 h1:AsXqd2a1/INaIfUSKq3G5uA8weYx20FOsM7uSoCyyt4=
|
||||
github.com/go-chi/render v1.0.3/go.mod h1:/gr3hVkmYR0YlEy3LxCuVRFzEu9Ruok+gFqbIofjao0=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4=
|
||||
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed h1:5upAirOpQc1Q53c0bnx2ufif5kANL7bfZWcc6VJWJd8=
|
||||
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed/go.mod h1:tMWxXQ9wFIaZeTI9F+hmhFiGpFmhOHzyShyFUhRm0H4=
|
||||
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
|
||||
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
|
||||
github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA=
|
||||
github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
|
||||
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
|
||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/scylladb/gocql v1.14.5 h1:lyJKf0m/Vate+8MGiVeRhQNpLVVsL21gvp89zEZdltI=
|
||||
github.com/scylladb/gocql v1.14.5/go.mod h1:1efi3H0Gr72WCR0W+i+d63FmwmJhDL/zfAC0gMJHVlM=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
|
||||
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
|
||||
golang.org/x/net v0.0.0-20220526153639-5463443f8c37/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
|
||||
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc=
|
||||
gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw=
|
||||
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
||||
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
sigs.k8s.io/yaml v1.3.0 h1:a2VclLzOGrwOHDiV8EfBGhvjHvP46CtW5j6POvhYGGo=
|
||||
sigs.k8s.io/yaml v1.3.0/go.mod h1:GeOyir5tyXNByN85N/dRIT9es5UQNerPYEKK56eTBm8=
|
||||
|
31
log/log.go
Normal file
31
log/log.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package log
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"log/slog"
|
||||
"os"
|
||||
)
|
||||
|
||||
func Logger() {
|
||||
// Get logger arguments
|
||||
var loglevelStr string
|
||||
flag.StringVar(&loglevelStr, "loglevel", "ERROR", "set log level")
|
||||
flag.Parse()
|
||||
|
||||
loglevel := new(slog.LevelVar)
|
||||
if loglevelStr == "DEBUG" {
|
||||
loglevel.Set(slog.LevelDebug)
|
||||
} else if loglevelStr == "INFO" {
|
||||
loglevel.Set(slog.LevelInfo)
|
||||
} else if loglevelStr == "WARN" {
|
||||
loglevel.Set(slog.LevelWarn)
|
||||
} else {
|
||||
loglevel.Set(slog.LevelError)
|
||||
}
|
||||
// Start logger
|
||||
logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{
|
||||
Level: loglevel,
|
||||
}))
|
||||
slog.SetDefault(logger)
|
||||
slog.Debug("Logging started", "level", loglevel)
|
||||
}
|
34
main.go
34
main.go
@@ -1,9 +1,43 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"os"
|
||||
|
||||
"git.dubyatp.xyz/chat-api-server/api"
|
||||
"git.dubyatp.xyz/chat-api-server/log"
|
||||
"github.com/joho/godotenv"
|
||||
)
|
||||
|
||||
func checkEnvVars(keys []string) (bool, []string) {
|
||||
var missing []string
|
||||
for _, key := range keys {
|
||||
if _, ok := os.LookupEnv(key); !ok {
|
||||
missing = append(missing, key)
|
||||
}
|
||||
}
|
||||
return len(missing) == 0, missing
|
||||
}
|
||||
|
||||
func main() {
|
||||
|
||||
log.Logger() // initialize logger
|
||||
|
||||
requiredEnvVars := []string{"SCYLLA_CLUSTER", "SCYLLA_KEYSPACE"}
|
||||
|
||||
// Initialize dotenv file
|
||||
err := godotenv.Load()
|
||||
if err != nil {
|
||||
slog.Info("No .env file loaded, will try OS environment variables")
|
||||
}
|
||||
|
||||
// Check if environment variables were defined by the OS before erroring out
|
||||
exists, missingVars := checkEnvVars(requiredEnvVars)
|
||||
if !exists {
|
||||
slog.Error("Missing environment variables", "missing", missingVars)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
slog.Info("Starting the API Server...")
|
||||
api.Start()
|
||||
}
|
||||
|
@@ -1,44 +0,0 @@
|
||||
[
|
||||
{
|
||||
"Body": "hello",
|
||||
"ID": "1",
|
||||
"Timestamp": "2024-12-25T05:00:40Z",
|
||||
"UserID": 1
|
||||
},
|
||||
{
|
||||
"Body": "world",
|
||||
"ID": "2",
|
||||
"Timestamp": "2024-12-25T05:00:43Z",
|
||||
"UserID": 2
|
||||
},
|
||||
{
|
||||
"Body": "abababa",
|
||||
"ID": "3",
|
||||
"Timestamp": "2024-12-25T05:01:20Z",
|
||||
"UserID": 1
|
||||
},
|
||||
{
|
||||
"Body": "bitch",
|
||||
"ID": "4",
|
||||
"Timestamp": "2024-12-25T05:05:55Z",
|
||||
"UserID": 2
|
||||
},
|
||||
{
|
||||
"Body": "NIBBA",
|
||||
"ID": "5",
|
||||
"Timestamp": "2025-03-24T14:48:28.249221047-04:00",
|
||||
"UserID": 1
|
||||
},
|
||||
{
|
||||
"Body": "nibby",
|
||||
"ID": "6",
|
||||
"Timestamp": "2025-03-24T14:49:03.246929039-04:00",
|
||||
"UserID": 1
|
||||
},
|
||||
{
|
||||
"Body": "aaaaababananana",
|
||||
"ID": "msg_60f70a47-3be2-4315-869a-d6f151ec262a",
|
||||
"Timestamp": "2025-03-24T15:01:07.14371835-04:00",
|
||||
"UserID": 1
|
||||
}
|
||||
]
|
@@ -1,10 +0,0 @@
|
||||
[
|
||||
{
|
||||
"ID": 1,
|
||||
"Name": "duby"
|
||||
},
|
||||
{
|
||||
"ID": 2,
|
||||
"Name": "astolfo"
|
||||
}
|
||||
]
|
21
vendor/github.com/go-chi/cors/LICENSE
generated
vendored
Normal file
21
vendor/github.com/go-chi/cors/LICENSE
generated
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
Copyright (c) 2014 Olivier Poitrey <rs@dailymotion.com>
|
||||
Copyright (c) 2016-Present https://github.com/go-chi authors
|
||||
|
||||
MIT License
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of
|
||||
this software and associated documentation files (the "Software"), to deal in
|
||||
the Software without restriction, including without limitation the rights to
|
||||
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
||||
the Software, and to permit persons to whom the Software is furnished to do so,
|
||||
subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
||||
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
||||
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
||||
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
||||
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
39
vendor/github.com/go-chi/cors/README.md
generated
vendored
Normal file
39
vendor/github.com/go-chi/cors/README.md
generated
vendored
Normal file
@@ -0,0 +1,39 @@
|
||||
# CORS net/http middleware
|
||||
|
||||
[go-chi/cors](https://github.com/go-chi/cors) is a fork of [github.com/rs/cors](https://github.com/rs/cors) that
|
||||
provides a `net/http` compatible middleware for performing preflight CORS checks on the server side. These headers
|
||||
are required for using the browser native [Fetch API](https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API).
|
||||
|
||||
This middleware is designed to be used as a top-level middleware on the [chi](https://github.com/go-chi/chi) router.
|
||||
Applying with within a `r.Group()` or using `With()` will not work without routes matching `OPTIONS` added.
|
||||
|
||||
## Usage
|
||||
|
||||
```go
|
||||
func main() {
|
||||
r := chi.NewRouter()
|
||||
|
||||
// Basic CORS
|
||||
// for more ideas, see: https://developer.github.com/v3/#cross-origin-resource-sharing
|
||||
r.Use(cors.Handler(cors.Options{
|
||||
// AllowedOrigins: []string{"https://foo.com"}, // Use this to allow specific origin hosts
|
||||
AllowedOrigins: []string{"https://*", "http://*"},
|
||||
// AllowOriginFunc: func(r *http.Request, origin string) bool { return true },
|
||||
AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
|
||||
AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"},
|
||||
ExposedHeaders: []string{"Link"},
|
||||
AllowCredentials: false,
|
||||
MaxAge: 300, // Maximum value not ignored by any of major browsers
|
||||
}))
|
||||
|
||||
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("welcome"))
|
||||
})
|
||||
|
||||
http.ListenAndServe(":3000", r)
|
||||
}
|
||||
```
|
||||
|
||||
## Credits
|
||||
|
||||
All credit for the original work of this middleware goes out to [github.com/rs](github.com/rs).
|
400
vendor/github.com/go-chi/cors/cors.go
generated
vendored
Normal file
400
vendor/github.com/go-chi/cors/cors.go
generated
vendored
Normal file
@@ -0,0 +1,400 @@
|
||||
// cors package is net/http handler to handle CORS related requests
|
||||
// as defined by http://www.w3.org/TR/cors/
|
||||
//
|
||||
// You can configure it by passing an option struct to cors.New:
|
||||
//
|
||||
// c := cors.New(cors.Options{
|
||||
// AllowedOrigins: []string{"foo.com"},
|
||||
// AllowedMethods: []string{"GET", "POST", "DELETE"},
|
||||
// AllowCredentials: true,
|
||||
// })
|
||||
//
|
||||
// Then insert the handler in the chain:
|
||||
//
|
||||
// handler = c.Handler(handler)
|
||||
//
|
||||
// See Options documentation for more options.
|
||||
//
|
||||
// The resulting handler is a standard net/http handler.
|
||||
package cors
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Options is a configuration container to setup the CORS middleware.
|
||||
type Options struct {
|
||||
// AllowedOrigins is a list of origins a cross-domain request can be executed from.
|
||||
// If the special "*" value is present in the list, all origins will be allowed.
|
||||
// An origin may contain a wildcard (*) to replace 0 or more characters
|
||||
// (i.e.: http://*.domain.com). Usage of wildcards implies a small performance penalty.
|
||||
// Only one wildcard can be used per origin.
|
||||
// Default value is ["*"]
|
||||
AllowedOrigins []string
|
||||
|
||||
// AllowOriginFunc is a custom function to validate the origin. It takes the origin
|
||||
// as argument and returns true if allowed or false otherwise. If this option is
|
||||
// set, the content of AllowedOrigins is ignored.
|
||||
AllowOriginFunc func(r *http.Request, origin string) bool
|
||||
|
||||
// AllowedMethods is a list of methods the client is allowed to use with
|
||||
// cross-domain requests. Default value is simple methods (HEAD, GET and POST).
|
||||
AllowedMethods []string
|
||||
|
||||
// AllowedHeaders is list of non simple headers the client is allowed to use with
|
||||
// cross-domain requests.
|
||||
// If the special "*" value is present in the list, all headers will be allowed.
|
||||
// Default value is [] but "Origin" is always appended to the list.
|
||||
AllowedHeaders []string
|
||||
|
||||
// ExposedHeaders indicates which headers are safe to expose to the API of a CORS
|
||||
// API specification
|
||||
ExposedHeaders []string
|
||||
|
||||
// AllowCredentials indicates whether the request can include user credentials like
|
||||
// cookies, HTTP authentication or client side SSL certificates.
|
||||
AllowCredentials bool
|
||||
|
||||
// MaxAge indicates how long (in seconds) the results of a preflight request
|
||||
// can be cached
|
||||
MaxAge int
|
||||
|
||||
// OptionsPassthrough instructs preflight to let other potential next handlers to
|
||||
// process the OPTIONS method. Turn this on if your application handles OPTIONS.
|
||||
OptionsPassthrough bool
|
||||
|
||||
// Debugging flag adds additional output to debug server side CORS issues
|
||||
Debug bool
|
||||
}
|
||||
|
||||
// Logger generic interface for logger
|
||||
type Logger interface {
|
||||
Printf(string, ...interface{})
|
||||
}
|
||||
|
||||
// Cors http handler
|
||||
type Cors struct {
|
||||
// Debug logger
|
||||
Log Logger
|
||||
|
||||
// Normalized list of plain allowed origins
|
||||
allowedOrigins []string
|
||||
|
||||
// List of allowed origins containing wildcards
|
||||
allowedWOrigins []wildcard
|
||||
|
||||
// Optional origin validator function
|
||||
allowOriginFunc func(r *http.Request, origin string) bool
|
||||
|
||||
// Normalized list of allowed headers
|
||||
allowedHeaders []string
|
||||
|
||||
// Normalized list of allowed methods
|
||||
allowedMethods []string
|
||||
|
||||
// Normalized list of exposed headers
|
||||
exposedHeaders []string
|
||||
maxAge int
|
||||
|
||||
// Set to true when allowed origins contains a "*"
|
||||
allowedOriginsAll bool
|
||||
|
||||
// Set to true when allowed headers contains a "*"
|
||||
allowedHeadersAll bool
|
||||
|
||||
allowCredentials bool
|
||||
optionPassthrough bool
|
||||
}
|
||||
|
||||
// New creates a new Cors handler with the provided options.
|
||||
func New(options Options) *Cors {
|
||||
c := &Cors{
|
||||
exposedHeaders: convert(options.ExposedHeaders, http.CanonicalHeaderKey),
|
||||
allowOriginFunc: options.AllowOriginFunc,
|
||||
allowCredentials: options.AllowCredentials,
|
||||
maxAge: options.MaxAge,
|
||||
optionPassthrough: options.OptionsPassthrough,
|
||||
}
|
||||
if options.Debug && c.Log == nil {
|
||||
c.Log = log.New(os.Stdout, "[cors] ", log.LstdFlags)
|
||||
}
|
||||
|
||||
// Normalize options
|
||||
// Note: for origins and methods matching, the spec requires a case-sensitive matching.
|
||||
// As it may error prone, we chose to ignore the spec here.
|
||||
|
||||
// Allowed Origins
|
||||
if len(options.AllowedOrigins) == 0 {
|
||||
if options.AllowOriginFunc == nil {
|
||||
// Default is all origins
|
||||
c.allowedOriginsAll = true
|
||||
}
|
||||
} else {
|
||||
c.allowedOrigins = []string{}
|
||||
c.allowedWOrigins = []wildcard{}
|
||||
for _, origin := range options.AllowedOrigins {
|
||||
// Normalize
|
||||
origin = strings.ToLower(origin)
|
||||
if origin == "*" {
|
||||
// If "*" is present in the list, turn the whole list into a match all
|
||||
c.allowedOriginsAll = true
|
||||
c.allowedOrigins = nil
|
||||
c.allowedWOrigins = nil
|
||||
break
|
||||
} else if i := strings.IndexByte(origin, '*'); i >= 0 {
|
||||
// Split the origin in two: start and end string without the *
|
||||
w := wildcard{origin[0:i], origin[i+1:]}
|
||||
c.allowedWOrigins = append(c.allowedWOrigins, w)
|
||||
} else {
|
||||
c.allowedOrigins = append(c.allowedOrigins, origin)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Allowed Headers
|
||||
if len(options.AllowedHeaders) == 0 {
|
||||
// Use sensible defaults
|
||||
c.allowedHeaders = []string{"Origin", "Accept", "Content-Type"}
|
||||
} else {
|
||||
// Origin is always appended as some browsers will always request for this header at preflight
|
||||
c.allowedHeaders = convert(append(options.AllowedHeaders, "Origin"), http.CanonicalHeaderKey)
|
||||
for _, h := range options.AllowedHeaders {
|
||||
if h == "*" {
|
||||
c.allowedHeadersAll = true
|
||||
c.allowedHeaders = nil
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Allowed Methods
|
||||
if len(options.AllowedMethods) == 0 {
|
||||
// Default is spec's "simple" methods
|
||||
c.allowedMethods = []string{http.MethodGet, http.MethodPost, http.MethodHead}
|
||||
} else {
|
||||
c.allowedMethods = convert(options.AllowedMethods, strings.ToUpper)
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
// Handler creates a new Cors handler with passed options.
|
||||
func Handler(options Options) func(next http.Handler) http.Handler {
|
||||
c := New(options)
|
||||
return c.Handler
|
||||
}
|
||||
|
||||
// AllowAll create a new Cors handler with permissive configuration allowing all
|
||||
// origins with all standard methods with any header and credentials.
|
||||
func AllowAll() *Cors {
|
||||
return New(Options{
|
||||
AllowedOrigins: []string{"*"},
|
||||
AllowedMethods: []string{
|
||||
http.MethodHead,
|
||||
http.MethodGet,
|
||||
http.MethodPost,
|
||||
http.MethodPut,
|
||||
http.MethodPatch,
|
||||
http.MethodDelete,
|
||||
},
|
||||
AllowedHeaders: []string{"*"},
|
||||
AllowCredentials: false,
|
||||
})
|
||||
}
|
||||
|
||||
// Handler apply the CORS specification on the request, and add relevant CORS headers
|
||||
// as necessary.
|
||||
func (c *Cors) Handler(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == http.MethodOptions && r.Header.Get("Access-Control-Request-Method") != "" {
|
||||
c.logf("Handler: Preflight request")
|
||||
c.handlePreflight(w, r)
|
||||
// Preflight requests are standalone and should stop the chain as some other
|
||||
// middleware may not handle OPTIONS requests correctly. One typical example
|
||||
// is authentication middleware ; OPTIONS requests won't carry authentication
|
||||
// headers (see #1)
|
||||
if c.optionPassthrough {
|
||||
next.ServeHTTP(w, r)
|
||||
} else {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
} else {
|
||||
c.logf("Handler: Actual request")
|
||||
c.handleActualRequest(w, r)
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// handlePreflight handles pre-flight CORS requests
|
||||
func (c *Cors) handlePreflight(w http.ResponseWriter, r *http.Request) {
|
||||
headers := w.Header()
|
||||
origin := r.Header.Get("Origin")
|
||||
|
||||
if r.Method != http.MethodOptions {
|
||||
c.logf("Preflight aborted: %s!=OPTIONS", r.Method)
|
||||
return
|
||||
}
|
||||
// Always set Vary headers
|
||||
// see https://github.com/rs/cors/issues/10,
|
||||
// https://github.com/rs/cors/commit/dbdca4d95feaa7511a46e6f1efb3b3aa505bc43f#commitcomment-12352001
|
||||
headers.Add("Vary", "Origin")
|
||||
headers.Add("Vary", "Access-Control-Request-Method")
|
||||
headers.Add("Vary", "Access-Control-Request-Headers")
|
||||
|
||||
if origin == "" {
|
||||
c.logf("Preflight aborted: empty origin")
|
||||
return
|
||||
}
|
||||
if !c.isOriginAllowed(r, origin) {
|
||||
c.logf("Preflight aborted: origin '%s' not allowed", origin)
|
||||
return
|
||||
}
|
||||
|
||||
reqMethod := r.Header.Get("Access-Control-Request-Method")
|
||||
if !c.isMethodAllowed(reqMethod) {
|
||||
c.logf("Preflight aborted: method '%s' not allowed", reqMethod)
|
||||
return
|
||||
}
|
||||
reqHeaders := parseHeaderList(r.Header.Get("Access-Control-Request-Headers"))
|
||||
if !c.areHeadersAllowed(reqHeaders) {
|
||||
c.logf("Preflight aborted: headers '%v' not allowed", reqHeaders)
|
||||
return
|
||||
}
|
||||
if c.allowedOriginsAll {
|
||||
headers.Set("Access-Control-Allow-Origin", "*")
|
||||
} else {
|
||||
headers.Set("Access-Control-Allow-Origin", origin)
|
||||
}
|
||||
// Spec says: Since the list of methods can be unbounded, simply returning the method indicated
|
||||
// by Access-Control-Request-Method (if supported) can be enough
|
||||
headers.Set("Access-Control-Allow-Methods", strings.ToUpper(reqMethod))
|
||||
if len(reqHeaders) > 0 {
|
||||
|
||||
// Spec says: Since the list of headers can be unbounded, simply returning supported headers
|
||||
// from Access-Control-Request-Headers can be enough
|
||||
headers.Set("Access-Control-Allow-Headers", strings.Join(reqHeaders, ", "))
|
||||
}
|
||||
if c.allowCredentials {
|
||||
headers.Set("Access-Control-Allow-Credentials", "true")
|
||||
}
|
||||
if c.maxAge > 0 {
|
||||
headers.Set("Access-Control-Max-Age", strconv.Itoa(c.maxAge))
|
||||
}
|
||||
c.logf("Preflight response headers: %v", headers)
|
||||
}
|
||||
|
||||
// handleActualRequest handles simple cross-origin requests, actual request or redirects
|
||||
func (c *Cors) handleActualRequest(w http.ResponseWriter, r *http.Request) {
|
||||
headers := w.Header()
|
||||
origin := r.Header.Get("Origin")
|
||||
|
||||
// Always set Vary, see https://github.com/rs/cors/issues/10
|
||||
headers.Add("Vary", "Origin")
|
||||
if origin == "" {
|
||||
c.logf("Actual request no headers added: missing origin")
|
||||
return
|
||||
}
|
||||
if !c.isOriginAllowed(r, origin) {
|
||||
c.logf("Actual request no headers added: origin '%s' not allowed", origin)
|
||||
return
|
||||
}
|
||||
|
||||
// Note that spec does define a way to specifically disallow a simple method like GET or
|
||||
// POST. Access-Control-Allow-Methods is only used for pre-flight requests and the
|
||||
// spec doesn't instruct to check the allowed methods for simple cross-origin requests.
|
||||
// We think it's a nice feature to be able to have control on those methods though.
|
||||
if !c.isMethodAllowed(r.Method) {
|
||||
c.logf("Actual request no headers added: method '%s' not allowed", r.Method)
|
||||
|
||||
return
|
||||
}
|
||||
if c.allowedOriginsAll {
|
||||
headers.Set("Access-Control-Allow-Origin", "*")
|
||||
} else {
|
||||
headers.Set("Access-Control-Allow-Origin", origin)
|
||||
}
|
||||
if len(c.exposedHeaders) > 0 {
|
||||
headers.Set("Access-Control-Expose-Headers", strings.Join(c.exposedHeaders, ", "))
|
||||
}
|
||||
if c.allowCredentials {
|
||||
headers.Set("Access-Control-Allow-Credentials", "true")
|
||||
}
|
||||
c.logf("Actual response added headers: %v", headers)
|
||||
}
|
||||
|
||||
// convenience method. checks if a logger is set.
|
||||
func (c *Cors) logf(format string, a ...interface{}) {
|
||||
if c.Log != nil {
|
||||
c.Log.Printf(format, a...)
|
||||
}
|
||||
}
|
||||
|
||||
// isOriginAllowed checks if a given origin is allowed to perform cross-domain requests
|
||||
// on the endpoint
|
||||
func (c *Cors) isOriginAllowed(r *http.Request, origin string) bool {
|
||||
if c.allowOriginFunc != nil {
|
||||
return c.allowOriginFunc(r, origin)
|
||||
}
|
||||
if c.allowedOriginsAll {
|
||||
return true
|
||||
}
|
||||
origin = strings.ToLower(origin)
|
||||
for _, o := range c.allowedOrigins {
|
||||
if o == origin {
|
||||
return true
|
||||
}
|
||||
}
|
||||
for _, w := range c.allowedWOrigins {
|
||||
if w.match(origin) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isMethodAllowed checks if a given method can be used as part of a cross-domain request
|
||||
// on the endpoint
|
||||
func (c *Cors) isMethodAllowed(method string) bool {
|
||||
if len(c.allowedMethods) == 0 {
|
||||
// If no method allowed, always return false, even for preflight request
|
||||
return false
|
||||
}
|
||||
method = strings.ToUpper(method)
|
||||
if method == http.MethodOptions {
|
||||
// Always allow preflight requests
|
||||
return true
|
||||
}
|
||||
for _, m := range c.allowedMethods {
|
||||
if m == method {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// areHeadersAllowed checks if a given list of headers are allowed to used within
|
||||
// a cross-domain request.
|
||||
func (c *Cors) areHeadersAllowed(requestedHeaders []string) bool {
|
||||
if c.allowedHeadersAll || len(requestedHeaders) == 0 {
|
||||
return true
|
||||
}
|
||||
for _, header := range requestedHeaders {
|
||||
header = http.CanonicalHeaderKey(header)
|
||||
found := false
|
||||
for _, h := range c.allowedHeaders {
|
||||
if h == header {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
70
vendor/github.com/go-chi/cors/utils.go
generated
vendored
Normal file
70
vendor/github.com/go-chi/cors/utils.go
generated
vendored
Normal file
@@ -0,0 +1,70 @@
|
||||
package cors
|
||||
|
||||
import "strings"
|
||||
|
||||
// toLower is the byte distance between an upper-case ASCII letter and its
// lower-case counterpart ('a'-'A' == 32).
const toLower = 'a' - 'A'

// converter transforms a single string; used with convert to normalize lists.
type converter func(string) string

// wildcard represents an origin pattern split around a single '*',
// e.g. "http://*.example.com" becomes {prefix: "http://", suffix: ".example.com"}.
type wildcard struct {
	prefix string
	suffix string
}

// match reports whether s starts with w.prefix, ends with w.suffix, and is long
// enough that the prefix and suffix do not overlap.
func (w wildcard) match(s string) bool {
	// Summing the lengths avoids the string concatenation that
	// len(w.prefix+w.suffix) would allocate on every call.
	return len(s) >= len(w.prefix)+len(w.suffix) &&
		strings.HasPrefix(s, w.prefix) &&
		strings.HasSuffix(s, w.suffix)
}

// convert converts a list of string using the passed converter function
func convert(s []string, c converter) []string {
	// Pre-size the output: the result always has exactly len(s) elements.
	out := make([]string, 0, len(s))
	for _, i := range s {
		out = append(out, c(i))
	}
	return out
}
|
||||
|
||||
// parseHeaderList tokenize + normalize a string containing a list of headers.
//
// Headers are separated by commas and/or spaces; each token is rewritten into
// canonical MIME form (e.g. "x-request-id" -> "X-Request-Id"). Bytes other
// than ASCII letters, digits, '-', '_' and '.' are dropped.
func parseHeaderList(headerList string) []string {
	const delta = 'a' - 'A' // distance between lower- and upper-case ASCII letters

	l := len(headerList)
	h := make([]byte, 0, l)
	upper := true

	// Estimate the number of headers in order to allocate the right slice size:
	// n comma separators yield at most n+1 headers (the original estimate of n
	// undercounted by one, forcing an extra append growth).
	t := 1
	for i := 0; i < l; i++ {
		if headerList[i] == ',' {
			t++
		}
	}
	headers := make([]string, 0, t)

	for i := 0; i < l; i++ {
		b := headerList[i]
		switch {
		case b >= 'a' && b <= 'z':
			if upper {
				h = append(h, b-delta) // first letter of a word: upper-case it
			} else {
				h = append(h, b)
			}
		case b >= 'A' && b <= 'Z':
			if !upper {
				h = append(h, b+delta) // mid-word capital: lower-case it
			} else {
				h = append(h, b)
			}
		case b == '-' || b == '_' || b == '.' || (b >= '0' && b <= '9'):
			h = append(h, b)
		}

		if b == ' ' || b == ',' || i == l-1 {
			if len(h) > 0 {
				// Flush the found header
				headers = append(headers, string(h))
				h = h[:0]
				upper = true
			}
		} else {
			// The character after a '-' starts a new word and is upper-cased.
			upper = b == '-'
		}
	}
	return headers
}
|
5
vendor/github.com/gocql/gocql/.gitignore
generated
vendored
Normal file
5
vendor/github.com/gocql/gocql/.gitignore
generated
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
gocql-fuzz
|
||||
fuzz-corpus
|
||||
fuzz-work
|
||||
gocql.test
|
||||
.idea
|
148
vendor/github.com/gocql/gocql/AUTHORS
generated
vendored
Normal file
148
vendor/github.com/gocql/gocql/AUTHORS
generated
vendored
Normal file
@@ -0,0 +1,148 @@
|
||||
# This source file refers to The gocql Authors for copyright purposes.
|
||||
|
||||
Christoph Hack <christoph@tux21b.org>
|
||||
Jonathan Rudenberg <jonathan@titanous.com>
|
||||
Thorsten von Eicken <tve@rightscale.com>
|
||||
Matt Robenolt <mattr@disqus.com>
|
||||
Phillip Couto <phillip.couto@stemstudios.com>
|
||||
Niklas Korz <korz.niklask@gmail.com>
|
||||
Nimi Wariboko Jr <nimi@channelmeter.com>
|
||||
Ghais Issa <ghais.issa@gmail.com>
|
||||
Sasha Klizhentas <klizhentas@gmail.com>
|
||||
Konstantin Cherkasov <k.cherkasoff@gmail.com>
|
||||
Ben Hood <0x6e6562@gmail.com>
|
||||
Pete Hopkins <phopkins@gmail.com>
|
||||
Chris Bannister <c.bannister@gmail.com>
|
||||
Maxim Bublis <b@codemonkey.ru>
|
||||
Alex Zorin <git@zor.io>
|
||||
Kasper Middelboe Petersen <me@phant.dk>
|
||||
Harpreet Sawhney <harpreet.sawhney@gmail.com>
|
||||
Charlie Andrews <charlieandrews.cwa@gmail.com>
|
||||
Stanislavs Koikovs <stanislavs.koikovs@gmail.com>
|
||||
Dan Forest <bonjour@dan.tf>
|
||||
Miguel Serrano <miguelvps@gmail.com>
|
||||
Stefan Radomski <gibheer@zero-knowledge.org>
|
||||
Josh Wright <jshwright@gmail.com>
|
||||
Jacob Rhoden <jacob.rhoden@gmail.com>
|
||||
Ben Frye <benfrye@gmail.com>
|
||||
Fred McCann <fred@sharpnoodles.com>
|
||||
Dan Simmons <dan@simmons.io>
|
||||
Muir Manders <muir@retailnext.net>
|
||||
Sankar P <sankar.curiosity@gmail.com>
|
||||
Julien Da Silva <julien.dasilva@gmail.com>
|
||||
Dan Kennedy <daniel@firstcs.co.uk>
|
||||
Nick Dhupia<nick.dhupia@gmail.com>
|
||||
Yasuharu Goto <matope.ono@gmail.com>
|
||||
Jeremy Schlatter <jeremy.schlatter@gmail.com>
|
||||
Matthias Kadenbach <matthias.kadenbach@gmail.com>
|
||||
Dean Elbaz <elbaz.dean@gmail.com>
|
||||
Mike Berman <evencode@gmail.com>
|
||||
Dmitriy Fedorenko <c0va23@gmail.com>
|
||||
Zach Marcantel <zmarcantel@gmail.com>
|
||||
James Maloney <jamessagan@gmail.com>
|
||||
Ashwin Purohit <purohit@gmail.com>
|
||||
Dan Kinder <dkinder.is.me@gmail.com>
|
||||
Oliver Beattie <oliver@obeattie.com>
|
||||
Justin Corpron <jncorpron@gmail.com>
|
||||
Miles Delahunty <miles.delahunty@gmail.com>
|
||||
Zach Badgett <zach.badgett@gmail.com>
|
||||
Maciek Sakrejda <maciek@heroku.com>
|
||||
Jeff Mitchell <jeffrey.mitchell@gmail.com>
|
||||
Baptiste Fontaine <b@ptistefontaine.fr>
|
||||
Matt Heath <matt@mattheath.com>
|
||||
Jamie Cuthill <jamie.cuthill@gmail.com>
|
||||
Adrian Casajus <adriancasajus@gmail.com>
|
||||
John Weldon <johnweldon4@gmail.com>
|
||||
Adrien Bustany <adrien@bustany.org>
|
||||
Andrey Smirnov <smirnov.andrey@gmail.com>
|
||||
Adam Weiner <adamsweiner@gmail.com>
|
||||
Daniel Cannon <daniel@danielcannon.co.uk>
|
||||
Johnny Bergström <johnny@joonix.se>
|
||||
Adriano Orioli <orioli.adriano@gmail.com>
|
||||
Claudiu Raveica <claudiu.raveica@gmail.com>
|
||||
Artem Chernyshev <artem.0xD2@gmail.com>
|
||||
Ference Fu <fym201@msn.com>
|
||||
LOVOO <opensource@lovoo.com>
|
||||
nikandfor <nikandfor@gmail.com>
|
||||
Anthony Woods <awoods@raintank.io>
|
||||
Alexander Inozemtsev <alexander.inozemtsev@gmail.com>
|
||||
Rob McColl <rob@robmccoll.com>; <rmccoll@ionicsecurity.com>
|
||||
Viktor Tönköl <viktor.toenkoel@motionlogic.de>
|
||||
Ian Lozinski <ian.lozinski@gmail.com>
|
||||
Michael Highstead <highstead@gmail.com>
|
||||
Sarah Brown <esbie.is@gmail.com>
|
||||
Caleb Doxsey <caleb@datadoghq.com>
|
||||
Frederic Hemery <frederic.hemery@datadoghq.com>
|
||||
Pekka Enberg <penberg@scylladb.com>
|
||||
Mark M <m.mim95@gmail.com>
|
||||
Bartosz Burclaf <burclaf@gmail.com>
|
||||
Marcus King <marcusking01@gmail.com>
|
||||
Andrew de Andrade <andrew@deandrade.com.br>
|
||||
Robert Nix <robert@nicerobot.org>
|
||||
Nathan Youngman <git@nathany.com>
|
||||
Charles Law <charles.law@gmail.com>; <claw@conduce.com>
|
||||
Nathan Davies <nathanjamesdavies@gmail.com>
|
||||
Bo Blanton <bo.blanton@gmail.com>
|
||||
Vincent Rischmann <me@vrischmann.me>
|
||||
Jesse Claven <jesse.claven@gmail.com>
|
||||
Derrick Wippler <thrawn01@gmail.com>
|
||||
Leigh McCulloch <leigh@leighmcculloch.com>
|
||||
Ron Kuris <swcafe@gmail.com>
|
||||
Raphael Gavache <raphael.gavache@gmail.com>
|
||||
Yasser Abdolmaleki <yasser@yasser.ca>
|
||||
Krishnanand Thommandra <devtkrishna@gmail.com>
|
||||
Blake Atkinson <me@blakeatkinson.com>
|
||||
Dharmendra Parsaila <d4dharmu@gmail.com>
|
||||
Nayef Ghattas <nayef.ghattas@datadoghq.com>
|
||||
Michał Matczuk <mmatczuk@gmail.com>
|
||||
Ben Krebsbach <ben.krebsbach@gmail.com>
|
||||
Vivian Mathews <vivian.mathews.3@gmail.com>
|
||||
Sascha Steinbiss <satta@debian.org>
|
||||
Seth Rosenblum <seth.t.rosenblum@gmail.com>
|
||||
Javier Zunzunegui <javier.zunzunegui.b@gmail.com>
|
||||
Luke Hines <lukehines@protonmail.com>
|
||||
Zhixin Wen <john.wenzhixin@hotmail.com>
|
||||
Chang Liu <changliu.it@gmail.com>
|
||||
Ingo Oeser <nightlyone@gmail.com>
|
||||
Luke Hines <lukehines@protonmail.com>
|
||||
Jacob Greenleaf <jacob@jacobgreenleaf.com>
|
||||
Alex Lourie <alex@instaclustr.com>; <djay.il@gmail.com>
|
||||
Marco Cadetg <cadetg@gmail.com>
|
||||
Karl Matthias <karl@matthias.org>
|
||||
Thomas Meson <zllak@hycik.org>
|
||||
Martin Sucha <martin.sucha@kiwi.com>; <git@mm.ms47.eu>
|
||||
Pavel Buchinchik <p.buchinchik@gmail.com>
|
||||
Rintaro Okamura <rintaro.okamura@gmail.com>
|
||||
Ivan Boyarkin <ivan.boyarkin@kiwi.com>; <mr.vanboy@gmail.com>
|
||||
Yura Sokolov <y.sokolov@joom.com>; <funny.falcon@gmail.com>
|
||||
Jorge Bay <jorgebg@apache.org>
|
||||
Dmitriy Kozlov <hummerd@mail.ru>
|
||||
Alexey Romanovsky <alexus1024+gocql@gmail.com>
|
||||
Jaume Marhuenda Beltran <jaumemarhuenda@gmail.com>
|
||||
Piotr Dulikowski <piodul@scylladb.com>
|
||||
Árni Dagur <arni@dagur.eu>
|
||||
Tushar Das <tushar.das5@gmail.com>
|
||||
Maxim Vladimirskiy <horkhe@gmail.com>
|
||||
Bogdan-Ciprian Rusu <bogdanciprian.rusu@crowdstrike.com>
|
||||
Yuto Doi <yutodoi.seattle@gmail.com>
|
||||
Krishna Vadali <tejavadali@gmail.com>
|
||||
Jens-W. Schicke-Uffmann <drahflow@gmx.de>
|
||||
Ondrej Polakovič <ondrej.polakovic@kiwi.com>
|
||||
Sergei Karetnikov <sergei.karetnikov@gmail.com>
|
||||
Stefan Miklosovic <smiklosovic@apache.org>
|
||||
Adam Burk <amburk@gmail.com>
|
||||
Valerii Ponomarov <kiparis.kh@gmail.com>
|
||||
Neal Turett <neal.turett@datadoghq.com>
|
||||
Doug Schaapveld <djschaap@gmail.com>
|
||||
Steven Seidman <steven.seidman@datadoghq.com>
|
||||
Wojciech Przytuła <wojciech.przytula@scylladb.com>
|
||||
João Reis <joao.reis@datastax.com>
|
||||
Lauro Ramos Venancio <lauro.venancio@incognia.com>
|
||||
Dmitry Kropachev <dmitry.kropachev@gmail.com>
|
||||
Oliver Boyle <pleasedontspamme4321+gocql@gmail.com>
|
||||
Jackson Fleming <jackson.fleming@instaclustr.com>
|
||||
Sylwia Szunejko <sylwia.szunejko@scylladb.com>
|
||||
Karol Baryła <karol.baryla@scylladb.com>
|
||||
Marcin Mazurek <marcinek.mazurek@gmail.com>
|
||||
Moguchev Leonid Alekseevich <lmoguchev@ozon.ru>
|
||||
Julien Lefevre <julien.lefevr@gmail.com>
|
78
vendor/github.com/gocql/gocql/CONTRIBUTING.md
generated
vendored
Normal file
78
vendor/github.com/gocql/gocql/CONTRIBUTING.md
generated
vendored
Normal file
@@ -0,0 +1,78 @@
|
||||
# Contributing to gocql
|
||||
|
||||
**TL;DR** - this manifesto sets out the bare minimum requirements for submitting a patch to gocql.
|
||||
|
||||
This guide outlines the process of landing patches in gocql and the general approach to maintaining the code base.
|
||||
|
||||
## Background
|
||||
|
||||
The goal of the gocql project is to provide a stable and robust CQL driver for Go. gocql is a community driven project that is coordinated by a small team of core developers.
|
||||
|
||||
## Minimum Requirement Checklist
|
||||
|
||||
The following is a check list of requirements that need to be satisfied in order for us to merge your patch:
|
||||
|
||||
* You should raise a pull request to gocql/gocql on Github
|
||||
* The pull request has a title that clearly summarizes the purpose of the patch
|
||||
* The motivation behind the patch is clearly defined in the pull request summary
|
||||
* Your name and email have been added to the `AUTHORS` file (for copyright purposes)
|
||||
* The patch will merge cleanly
|
||||
* The test coverage does not fall below the critical threshold (currently 64%)
|
||||
* The merge commit passes the regression test suite on Travis
|
||||
* `go fmt` has been applied to the submitted code
|
||||
* Notable changes (i.e. new features or changed behavior, bugfixes) are appropriately documented in CHANGELOG.md, functional changes also in godoc
|
||||
|
||||
If there are any requirements that can't be reasonably satisfied, please state this either on the pull request or as part of discussion on the mailing list. Where appropriate, the core team may apply discretion and make an exception to these requirements.
|
||||
|
||||
## Beyond The Checklist
|
||||
|
||||
In addition to stating the hard requirements, there are a bunch of things that we consider when assessing changes to the library. These soft requirements are helpful pointers of how to get a patch landed quicker and with less fuss.
|
||||
|
||||
### General QA Approach
|
||||
|
||||
The gocql team needs to consider the ongoing maintainability of the library at all times. Patches that look like they will introduce maintenance issues for the team will not be accepted.
|
||||
|
||||
Your patch will get merged quicker if you have decent test cases that provide test coverage for the new behavior you wish to introduce.
|
||||
|
||||
Unit tests are good, integration tests are even better. An example of a unit test is `marshal_test.go` - this tests the serialization code in isolation. `cassandra_test.go` is an integration test suite that is executed against every version of Cassandra that gocql supports as part of the CI process on Travis.
|
||||
|
||||
That said, the point of writing tests is to provide a safety net to catch regressions, so there is no need to go overboard with tests. Remember that the more tests you write, the more code we will have to maintain. So there's a balance to strike there.
|
||||
|
||||
### When It's Too Difficult To Automate Testing
|
||||
|
||||
There are legitimate examples of where it is infeasible to write a regression test for a change. Never fear, we will still consider the patch and quite possibly accept the change without a test. The gocql team takes a pragmatic approach to testing. At the end of the day, you could be addressing an issue that is too difficult to reproduce in a test suite, but still occurs in a real production app. In this case, your production app is the test case, and we will have to trust that your change is good.
|
||||
|
||||
Examples of pull requests that have been accepted without tests include:
|
||||
|
||||
* https://github.com/gocql/gocql/pull/181 - this patch would otherwise require a multi-node cluster to be booted as part of the CI build
|
||||
* https://github.com/gocql/gocql/pull/179 - this bug can only be reproduced under heavy load in certain circumstances
|
||||
|
||||
### Sign Off Procedure
|
||||
|
||||
Generally speaking, a pull request can get merged by any one of the core gocql team. If your change is minor, chances are that one team member will just go ahead and merge it there and then. As stated earlier, suitable test coverage will increase the likelihood that a single reviewer will assess and merge your change. If your change has no test coverage, or looks like it may have wider implications for the health and stability of the library, the reviewer may elect to refer the change to another team member to achieve consensus before proceeding. Therefore, the tighter and cleaner your patch is, the quicker it will go through the review process.
|
||||
|
||||
### Supported Features
|
||||
|
||||
gocql is a low level wire driver for Cassandra CQL. By and large, we would like to keep the functional scope of the library as narrow as possible. We think that gocql should be tight and focused, and we will be naturally skeptical of things that could just as easily be implemented in a higher layer. Inevitably you will come across something that could be implemented in a higher layer, save for a minor change to the core API. In this instance, please strike up a conversation with the gocql team. Chances are we will understand what you are trying to achieve and will try to accommodate this in a maintainable way.
|
||||
|
||||
### Longer Term Evolution
|
||||
|
||||
There are some long term plans for gocql that have to be taken into account when assessing changes. That said, gocql is ultimately a community driven project and we don't have a massive development budget, so sometimes the long term view might need to be de-prioritized ahead of short term changes.
|
||||
|
||||
## Officially Supported Server Versions
|
||||
|
||||
Currently, the officially supported versions of the Cassandra server include:
|
||||
|
||||
* 1.2.18
|
||||
* 2.0.9
|
||||
|
||||
Chances are that gocql will work with many other versions. If you would like us to support a particular version of Cassandra, please start a conversation about what version you'd like us to consider. We are more likely to accept a new version if you help out by extending the regression suite to cover the new version to be supported.
|
||||
|
||||
## The Core Dev Team
|
||||
|
||||
The core development team includes:
|
||||
|
||||
* tux21b
|
||||
* phillipCouto
|
||||
* Zariel
|
||||
* 0x6e6562
|
27
vendor/github.com/gocql/gocql/LICENSE
generated
vendored
Normal file
27
vendor/github.com/gocql/gocql/LICENSE
generated
vendored
Normal file
@@ -0,0 +1,27 @@
|
||||
Copyright (c) 2016, The Gocql authors
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
* Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
* Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
5
vendor/github.com/gocql/gocql/Makefile
generated
vendored
Normal file
5
vendor/github.com/gocql/gocql/Makefile
generated
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
# Makefile to run the Docker cleanup script
|
||||
|
||||
clean-old-temporary-docker-images:
|
||||
@echo "Running Docker Hub image cleanup script..."
|
||||
python ci/clean-old-temporary-docker-images.py
|
242
vendor/github.com/gocql/gocql/README.md
generated
vendored
Normal file
242
vendor/github.com/gocql/gocql/README.md
generated
vendored
Normal file
@@ -0,0 +1,242 @@
|
||||
<div align="center">
|
||||
|
||||

|
||||
[](https://pkg.go.dev/github.com/scylladb/gocql#section-documentation)
|
||||
[](https://github.com/scylladb/scylladb/blob/master/docs/dev/protocol-extensions.md)
|
||||
|
||||
</div>
|
||||
|
||||
<h1 align="center">
|
||||
|
||||
Scylla Shard-Aware Fork of [apache/cassandra-gocql-driver](https://github.com/apache/cassandra-gocql-driver)
|
||||
|
||||
</h1>
|
||||
|
||||
|
||||
<img src="./.github/assets/logo.svg" width="200" align="left" />
|
||||
|
||||
This is a fork of [apache/cassandra-gocql-driver](https://github.com/apache/cassandra-gocql-driver) package that we created at Scylla.
|
||||
It contains extensions to tokenAwareHostPolicy supported by the Scylla 2.3 and onwards.
|
||||
It allows driver to select a connection to a particular shard on a host based on the token.
|
||||
This eliminates passing data between shards and significantly reduces latency.
|
||||
|
||||
There are open pull requests to merge the functionality to the upstream project:
|
||||
|
||||
* [gocql/gocql#1210](https://github.com/gocql/gocql/pull/1210)
|
||||
* [gocql/gocql#1211](https://github.com/gocql/gocql/pull/1211).
|
||||
|
||||
It also provides support for shard aware ports, a faster way to connect to all shards, details available in [blogpost](https://www.scylladb.com/2021/04/27/connect-faster-to-scylla-with-a-shard-aware-port/).
|
||||
|
||||
---
|
||||
|
||||
### Table of Contents
|
||||
|
||||
- [1. Sunsetting Model](#1-sunsetting-model)
|
||||
- [2. Installation](#2-installation)
|
||||
- [3. Quick Start](#3-quick-start)
|
||||
- [4. Data Types](#4-data-types)
|
||||
- [5. Configuration](#5-configuration)
|
||||
- [5.1 Shard-aware port](#51-shard-aware-port)
|
||||
- [5.2 Iterator](#52-iterator)
|
||||
- [6. Contributing](#6-contributing)
|
||||
|
||||
## 1. Sunsetting Model
|
||||
|
||||
> [!WARNING]
|
||||
> In general, the gocql team will focus on supporting the current and previous versions of Go. gocql may still work with older versions of Go, but official support for these versions will have been sunset.
|
||||
|
||||
## 2. Installation
|
||||
|
||||
This is a drop-in replacement to gocql, it reuses the `github.com/gocql/gocql` import path.
|
||||
|
||||
Add the following line to your project `go.mod` file.
|
||||
|
||||
```mod
|
||||
replace github.com/gocql/gocql => github.com/scylladb/gocql latest
|
||||
```
|
||||
|
||||
and run
|
||||
|
||||
```sh
|
||||
go mod tidy
|
||||
```
|
||||
|
||||
to evaluate `latest` to a concrete tag.
|
||||
|
||||
Your project now uses the Scylla driver fork, make sure you are using the `TokenAwareHostPolicy` to enable the shard-awareness, continue reading for details.
|
||||
|
||||
## 3. Quick Start
|
||||
|
||||
Spawn a ScyllaDB Instance using Docker Run command:
|
||||
|
||||
```sh
|
||||
docker run --name node1 --network your-network -p "9042:9042" -d scylladb/scylla:6.1.2 \
|
||||
--overprovisioned 1 \
|
||||
--smp 1
|
||||
```
|
||||
|
||||
Then, create a new connection using ScyllaDB GoCQL following the example below:
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gocql/gocql"
|
||||
)
|
||||
|
||||
func main() {
|
||||
var cluster = gocql.NewCluster("localhost:9042")
|
||||
|
||||
var session, err = cluster.CreateSession()
|
||||
if err != nil {
|
||||
panic("Failed to connect to cluster")
|
||||
}
|
||||
|
||||
defer session.Close()
|
||||
|
||||
var query = session.Query("SELECT * FROM system.clients")
|
||||
|
||||
if rows, err := query.Iter().SliceMap(); err == nil {
|
||||
for _, row := range rows {
|
||||
fmt.Printf("%v\n", row)
|
||||
}
|
||||
} else {
|
||||
panic("Query error: " + err.Error())
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## 4. Data Types
|
||||
|
||||
Here's a list of all ScyllaDB types reflected in the GoCQL environment:
|
||||
|
||||
| ScyllaDB Type | Go Type |
|
||||
| ---------------- | ------------------ |
|
||||
| `ascii` | `string` |
|
||||
| `bigint` | `int64` |
|
||||
| `blob` | `[]byte` |
|
||||
| `boolean` | `bool` |
|
||||
| `date` | `time.Time` |
|
||||
| `decimal` | `inf.Dec` |
|
||||
| `double` | `float64` |
|
||||
| `duration` | `gocql.Duration` |
|
||||
| `float` | `float32` |
|
||||
| `uuid` | `gocql.UUID` |
|
||||
| `int` | `int32` |
|
||||
| `inet` | `string` |
|
||||
| `list<int>` | `[]int32` |
|
||||
| `map<int, text>` | `map[int32]string` |
|
||||
| `set<int>` | `[]int32` |
|
||||
| `smallint` | `int16` |
|
||||
| `text` | `string` |
|
||||
| `time` | `time.Duration` |
|
||||
| `timestamp` | `time.Time` |
|
||||
| `timeuuid` | `gocql.UUID` |
|
||||
| `tinyint` | `int8` |
|
||||
| `varchar` | `string` |
|
||||
| `varint` | `int64` |
|
||||
|
||||
## 5. Configuration
|
||||
|
||||
In order to make shard-awareness work, token aware host selection policy has to be enabled.
|
||||
Please make sure that the gocql configuration has `PoolConfig.HostSelectionPolicy` properly set like in the example below.
|
||||
|
||||
__When working with a Scylla cluster, `PoolConfig.NumConns` option has no effect - the driver opens one connection for each shard and completely ignores this option.__
|
||||
|
||||
```go
|
||||
c := gocql.NewCluster(hosts...)
|
||||
|
||||
// Enable token aware host selection policy, if using multi-dc cluster set a local DC.
|
||||
fallback := gocql.RoundRobinHostPolicy()
|
||||
if localDC != "" {
|
||||
fallback = gocql.DCAwareRoundRobinPolicy(localDC)
|
||||
}
|
||||
c.PoolConfig.HostSelectionPolicy = gocql.TokenAwareHostPolicy(fallback)
|
||||
|
||||
// If using multi-dc cluster use the "local" consistency levels.
|
||||
if localDC != "" {
|
||||
c.Consistency = gocql.LocalQuorum
|
||||
}
|
||||
|
||||
// When working with a Scylla cluster the driver always opens one connection per shard, so `NumConns` is ignored.
|
||||
// c.NumConns = 4
|
||||
```
|
||||
|
||||
### 5.1 Shard-aware port
|
||||
|
||||
This version of gocql supports a more robust method of establishing connection for each shard by using _shard aware port_ for native transport.
|
||||
It greatly reduces time and the number of connections needed to establish a connection per shard in some cases - ex. when many clients connect at once, or when there are non-shard-aware clients connected to the same cluster.
|
||||
|
||||
If you are using a custom Dialer and if your nodes expose the shard-aware port, it is highly recommended to update it so that it uses a specific source port when connecting.
|
||||
|
||||
* If you are using a custom `net.Dialer`, you can make your dialer honor the source port by wrapping it in a `gocql.ScyllaShardAwareDialer`:
|
||||
|
||||
```go
|
||||
oldDialer := net.Dialer{...}
|
||||
clusterConfig.Dialer = &gocql.ScyllaShardAwareDialer{oldDialer}
|
||||
```
|
||||
|
||||
* If you are using a custom type implementing `gocql.Dialer`, you can get the source port by using the `gocql.ScyllaGetSourcePort` function.
|
||||
An example:
|
||||
|
||||
```go
|
||||
func (d *myDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
sourcePort := gocql.ScyllaGetSourcePort(ctx)
|
||||
localAddr, err := net.ResolveTCPAddr(network, fmt.Sprintf(":%d", sourcePort))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d := &net.Dialer{LocalAddr: localAddr}
|
||||
return d.DialContext(ctx, network, addr)
|
||||
}
|
||||
```
|
||||
|
||||
The source port might be already bound by another connection on your system.
|
||||
In such case, you should return an appropriate error so that the driver can retry with a different port suitable for the shard it tries to connect to.
|
||||
|
||||
* If you are using `net.Dialer.DialContext`, this function will return an error in case the source port is unavailable, and you can just return that error from your custom `Dialer`.
|
||||
* Otherwise, if you detect that the source port is unavailable, you can either return `gocql.ErrScyllaSourcePortAlreadyInUse` or `syscall.EADDRINUSE`.
|
||||
|
||||
For this feature to work correctly, you need to make sure the following conditions are met:
|
||||
|
||||
* Your cluster nodes are configured to listen on the shard-aware port (`native_shard_aware_transport_port` option),
|
||||
* Your cluster nodes are not behind a NAT which changes source ports,
|
||||
* If you have a custom Dialer, it connects from the correct source port (see the guide above).
|
||||
|
||||
The feature is designed to gracefully fall back to the using the non-shard-aware port when it detects that some of the above conditions are not met.
|
||||
The driver will print a warning about misconfigured address translation if it detects it.
|
||||
Issues with shard-aware port not being reachable are not reported in non-debug mode, because there is no way to detect it without false positives.
|
||||
|
||||
If you suspect that this feature is causing you problems, you can completely disable it by setting the `ClusterConfig.DisableShardAwarePort` flag to true.
|
||||
|
||||
### 5.2 Iterator
|
||||
|
||||
Paging is a way to parse large result sets in smaller chunks.
|
||||
The driver provides an iterator to simplify this process.
|
||||
|
||||
Use `Query.Iter()` to obtain iterator:
|
||||
|
||||
```go
|
||||
iter := session.Query("SELECT id, value FROM my_table WHERE id > 100 AND id < 10000").Iter()
|
||||
var results []int
|
||||
|
||||
var id, value int
|
||||
for iter.Scan(&id, &value) {
|
||||
if id%2 == 0 {
|
||||
results = append(results, value)
|
||||
}
|
||||
}
|
||||
|
||||
if err := iter.Close(); err != nil {
|
||||
// handle error
|
||||
}
|
||||
```
|
||||
|
||||
In case of range and `ALLOW FILTERING` queries server can send empty responses for some pages.
|
||||
That is why you should never consider empty response as the end of the result set.
|
||||
Always check `iter.Scan()` result to know if there are more results, or `Iter.LastPage()` to know if the last page was reached.
|
||||
|
||||
## 6. Contributing
|
||||
|
||||
If you are interested in contributing to this GoCQL fork, please read the [CONTRIBUTING.md](CONTRIBUTING.md) before opening any issue or pull request.
|
26
vendor/github.com/gocql/gocql/address_translators.go
generated
vendored
Normal file
26
vendor/github.com/gocql/gocql/address_translators.go
generated
vendored
Normal file
@@ -0,0 +1,26 @@
|
||||
package gocql
|
||||
|
||||
import "net"
|
||||
|
||||
// AddressTranslator provides a way to translate node addresses (and ports) that are
// discovered or received as a node event. This can be useful in an ec2 environment,
// for instance, to translate public IPs to private IPs.
type AddressTranslator interface {
	// Translate will translate the provided address and/or port to another
	// address and/or port. If no translation is possible, Translate will return the
	// address and port provided to it.
	Translate(addr net.IP, port int) (net.IP, int)
}

// AddressTranslatorFunc adapts a plain function to the AddressTranslator
// interface, in the style of http.HandlerFunc.
type AddressTranslatorFunc func(addr net.IP, port int) (net.IP, int)

// Translate implements AddressTranslator by invoking the wrapped function.
func (fn AddressTranslatorFunc) Translate(addr net.IP, port int) (net.IP, int) {
	return fn(addr, port)
}

// IdentityTranslator will do nothing but return what it was provided. It is essentially a no-op.
func IdentityTranslator() AddressTranslator {
	return AddressTranslatorFunc(func(addr net.IP, port int) (net.IP, int) {
		return addr, port
	})
}
|
541
vendor/github.com/gocql/gocql/cluster.go
generated
vendored
Normal file
541
vendor/github.com/gocql/gocql/cluster.go
generated
vendored
Normal file
@@ -0,0 +1,541 @@
|
||||
// Copyright (c) 2012 The gocql Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package gocql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// defaultDriverName is the driver name reported to the server unless
// overridden via ClusterConfig.DriverName.
const defaultDriverName = "ScyllaDB GoCQL Driver"

// PoolConfig configures the connection pool used by the driver, it defaults to
// using a round-robin host selection policy and a round-robin connection selection
// policy for each host.
type PoolConfig struct {
	// HostSelectionPolicy sets the policy for selecting which host to use for a
	// given query (default: RoundRobinHostPolicy())
	// It is not supported to use a single HostSelectionPolicy in multiple sessions
	// (even if you close the old session before using in a new session).
	HostSelectionPolicy HostSelectionPolicy
}

// buildPool creates the session-wide connection pool. The pool reads its
// settings (port, number of connections, keyspace) from the session's config.
func (p PoolConfig) buildPool(session *Session) *policyConnPool {
	return newPolicyConnPool(session)
}
|
||||
|
||||
// ClusterConfig is a struct to configure the default cluster implementation
// of gocql. It has a variety of attributes that can be used to modify the
// behavior to fit the most common use cases. Applications that require a
// different setup must implement their own cluster.
type ClusterConfig struct {
	// Hosts holds the addresses for the initial connections. It is recommended to use the value set in
	// the Cassandra config for broadcast_address or listen_address, an IP address not
	// a domain name. This is because events from Cassandra will use the configured IP
	// address, which is used to index connected hosts. If the domain name specified
	// resolves to more than 1 IP address then the driver may connect multiple times to
	// the same host, and will not mark the node being down or up from events.
	Hosts []string

	// CQL version (default: 3.0.0)
	CQLVersion string

	// ProtoVersion sets the version of the native protocol to use, this will
	// enable features in the driver for specific protocol versions, generally this
	// should be set to a known version (2,3,4) for the cluster being connected to.
	//
	// If it is 0 or unset (the default) then the driver will attempt to discover the
	// highest supported protocol for the cluster. In clusters with nodes of different
	// versions the protocol selected is not defined (ie, it can be any of the supported in the cluster)
	ProtoVersion int

	// Timeout limits the time spent on the client side while executing a query.
	// Specifically, query or batch execution will return an error if the client does not receive a response
	// from the server within the Timeout period.
	// Timeout is also used to configure the read timeout on the underlying network connection.
	// Client Timeout should always be higher than the request timeouts configured on the server,
	// so that retries don't overload the server.
	// Timeout has a default value of 11 seconds, which is higher than default server timeout for most query types.
	// Timeout is not applied to requests during initial connection setup, see ConnectTimeout.
	Timeout time.Duration

	// ConnectTimeout limits the time spent during connection setup.
	// During initial connection setup, internal queries, AUTH requests will return an error if the client
	// does not receive a response within the ConnectTimeout period.
	// ConnectTimeout is applied to the connection setup queries independently.
	// ConnectTimeout also limits the duration of dialing a new TCP connection
	// in case there is no Dialer nor HostDialer configured.
	// ConnectTimeout has a default value of 11 seconds.
	ConnectTimeout time.Duration

	// WriteTimeout limits the time the driver waits to write a request to a network connection.
	// WriteTimeout should be lower than or equal to Timeout.
	// WriteTimeout defaults to the value of Timeout.
	WriteTimeout time.Duration

	// Port used when dialing.
	// Default: 9042
	Port int

	// Initial keyspace. Optional.
	Keyspace string

	// The size of the connection pool for each host.
	// The pool filling runs in a separate goroutine during the session initialization phase.
	// gocql will always try to get 1 connection on each host pool
	// during session initialization AND it will attempt
	// to fill each pool afterward asynchronously if NumConns > 1.
	// Notice: There is no guarantee that pool filling will be finished in the initialization phase.
	// Also, it describes a maximum number of connections at the same time.
	// Default: 2
	NumConns int

	// Maximum number of inflight requests allowed per connection.
	// Default: 32768 for CQL v3 and newer
	// Default: 128 for older CQL versions
	MaxRequestsPerConn int

	// Default consistency level.
	// Default: Quorum
	Consistency Consistency

	// Compression algorithm.
	// Default: nil
	Compressor Compressor

	// Authenticator used for password-style authentication.
	// Default: nil
	Authenticator Authenticator

	// WarningsHandlerBuilder constructs the handler invoked for server-side
	// warnings. Default: DefaultWarningHandlerBuilder (set by NewCluster).
	WarningsHandlerBuilder WarningHandlerBuilder

	// An Authenticator factory. Can be used to create alternative authenticators.
	// Default: nil
	AuthProvider func(h *HostInfo) (Authenticator, error)

	// Default retry policy to use for queries.
	// Default: no retries.
	RetryPolicy RetryPolicy

	// ConvictionPolicy decides whether to mark host as down based on the error and host info.
	// Default: SimpleConvictionPolicy
	ConvictionPolicy ConvictionPolicy

	// Default reconnection policy to use for reconnecting before trying to mark host as down.
	ReconnectionPolicy ReconnectionPolicy

	// A reconnection policy to use for reconnecting when connecting to the cluster first time.
	InitialReconnectionPolicy ReconnectionPolicy

	// The keepalive period to use, enabled if > 0 (default: 15 seconds)
	// SocketKeepalive is used to set up the default dialer and is ignored if Dialer or HostDialer is provided.
	SocketKeepalive time.Duration

	// Maximum cache size for prepared statements globally for gocql.
	// Default: 1000
	MaxPreparedStmts int

	// Maximum cache size for query info about statements for each session.
	// Default: 1000
	MaxRoutingKeyInfo int

	// Default page size to use for created sessions.
	// Default: 5000
	PageSize int

	// Consistency for the serial part of queries, values can be either SERIAL or LOCAL_SERIAL.
	// Default: unset
	SerialConsistency SerialConsistency

	// SslOpts configures TLS use when HostDialer is not set.
	// SslOpts is ignored if HostDialer is set.
	SslOpts *SslOptions

	// actualSslOpts caches the *tls.Config built from SslOpts by
	// ValidateAndInitSSL; read via getActualTLSConfig.
	actualSslOpts atomic.Value

	// Sends a client side timestamp for all requests which overrides the timestamp at which it arrives at the server.
	// Default: true, only enabled for protocol 3 and above.
	DefaultTimestamp bool

	// The name of the driver that is going to be reported to the server.
	// Default: defaultDriverName ("ScyllaDB GoCQL Driver")
	DriverName string

	// The version of the driver that is going to be reported to the server.
	// Defaulted to current library version
	DriverVersion string

	// PoolConfig configures the underlying connection pool, allowing the
	// configuration of host selection and connection selection policies.
	PoolConfig PoolConfig

	// If not zero, gocql attempts to reconnect known DOWN nodes every ReconnectInterval.
	ReconnectInterval time.Duration

	// The maximum amount of time to wait for schema agreement in a cluster after
	// receiving a schema change frame. (default: 60s)
	MaxWaitSchemaAgreement time.Duration

	// HostFilter will filter all incoming events for host, any which don't pass
	// the filter will be ignored. If set will take precedence over any options set
	// via Discovery
	HostFilter HostFilter

	// AddressTranslator will translate addresses found on peer discovery and/or
	// node change events.
	AddressTranslator AddressTranslator

	// If IgnorePeerAddr is true and the address in system.peers does not match
	// the supplied host by either initial hosts or discovered via events then the
	// host will be replaced with the supplied address.
	//
	// For example if an event comes in with host=10.0.0.1 but when looking up that
	// address in system.local or system.peers returns 127.0.0.1, the peer will be
	// set to 10.0.0.1 which is what will be used to connect to.
	IgnorePeerAddr bool

	// If DisableInitialHostLookup then the driver will not attempt to get host info
	// from the system.peers table, this will mean that the driver will connect to
	// hosts supplied and will not attempt to lookup the hosts information, this will
	// mean that data_centre, rack and token information will not be available and as
	// such host filtering and token aware query routing will not be available.
	DisableInitialHostLookup bool

	// Configure events the driver will register for
	Events struct {
		// disable registering for status events (node up/down)
		DisableNodeStatusEvents bool
		// disable registering for topology events (node added/removed/moved)
		DisableTopologyEvents bool
		// disable registering for schema events (keyspace/table/function removed/created/updated)
		DisableSchemaEvents bool
	}

	// DisableSkipMetadata will override the internal result metadata cache so that the driver does not
	// send skip_metadata for queries, this means that the result will always contain
	// the metadata to parse the rows and will not reuse the metadata from the prepared
	// statement.
	//
	// See https://issues.apache.org/jira/browse/CASSANDRA-10786
	// See https://github.com/scylladb/scylladb/issues/20860
	//
	// Default: true
	DisableSkipMetadata bool

	// QueryObserver will set the provided query observer on all queries created from this session.
	// Use it to collect metrics / stats from queries by providing an implementation of QueryObserver.
	QueryObserver QueryObserver

	// BatchObserver will set the provided batch observer on all queries created from this session.
	// Use it to collect metrics / stats from batch queries by providing an implementation of BatchObserver.
	BatchObserver BatchObserver

	// ConnectObserver will set the provided connect observer on all queries
	// created from this session.
	ConnectObserver ConnectObserver

	// FrameHeaderObserver will set the provided frame header observer on all frames' headers created from this session.
	// Use it to collect metrics / stats from frames by providing an implementation of FrameHeaderObserver.
	FrameHeaderObserver FrameHeaderObserver

	// StreamObserver will be notified of stream state changes.
	// This can be used to track in-flight protocol requests and responses.
	StreamObserver StreamObserver

	// Default idempotence for queries
	DefaultIdempotence bool

	// The time to wait for frames before flushing the frames connection to Cassandra.
	// Can help reduce syscall overhead by making less calls to write. Set to 0 to
	// disable.
	//
	// (default: 200 microseconds)
	WriteCoalesceWaitTime time.Duration

	// Dialer will be used to establish all connections created for this Cluster.
	// If not provided, a default dialer configured with ConnectTimeout will be used.
	// Dialer is ignored if HostDialer is provided.
	Dialer Dialer

	// HostDialer will be used to establish all connections for this Cluster.
	// Unlike Dialer, HostDialer is responsible for setting up the entire connection, including the TLS session.
	// To support shard-aware port, HostDialer should implement ShardDialer.
	// If not provided, Dialer will be used instead.
	HostDialer HostDialer

	// DisableShardAwarePort will prevent the driver from connecting to Scylla's shard-aware port,
	// even if there are nodes in the cluster that support it.
	//
	// It is generally recommended to leave this option turned off because gocql can use
	// the shard-aware port to make the process of establishing more robust.
	// However, if you have a cluster with nodes which expose shard-aware port
	// but the port is unreachable due to network configuration issues, you can use
	// this option to work around the issue. Set it to true only if you neither can fix
	// your network nor disable shard-aware port on your nodes.
	DisableShardAwarePort bool

	// Logger for this ClusterConfig.
	// If not specified, defaults to the global gocql.Logger.
	Logger StdLogger

	// The timeout for the requests to the schema tables. (default: 60s)
	MetadataSchemaRequestTimeout time.Duration

	// internal config for testing
	disableControlConn bool
	disableInit        bool
}
|
||||
|
||||
// Dialer is the interface used to establish raw network connections to
// cluster nodes. The standard library's net.Dialer satisfies it; a custom
// implementation can be supplied via ClusterConfig.Dialer.
type Dialer interface {
	DialContext(ctx context.Context, network, addr string) (net.Conn, error)
}
|
||||
|
||||
// NewCluster generates a new config for the default cluster implementation.
|
||||
//
|
||||
// The supplied hosts are used to initially connect to the cluster then the rest of
|
||||
// the ring will be automatically discovered. It is recommended to use the value set in
|
||||
// the Cassandra config for broadcast_address or listen_address, an IP address not
|
||||
// a domain name. This is because events from Cassandra will use the configured IP
|
||||
// address, which is used to index connected hosts. If the domain name specified
|
||||
// resolves to more than 1 IP address then the driver may connect multiple times to
|
||||
// the same host, and will not mark the node being down or up from events.
|
||||
func NewCluster(hosts ...string) *ClusterConfig {
|
||||
cfg := &ClusterConfig{
|
||||
Hosts: hosts,
|
||||
CQLVersion: "3.0.0",
|
||||
Timeout: 11 * time.Second,
|
||||
ConnectTimeout: 11 * time.Second,
|
||||
Port: 9042,
|
||||
NumConns: 2,
|
||||
Consistency: Quorum,
|
||||
MaxPreparedStmts: defaultMaxPreparedStmts,
|
||||
MaxRoutingKeyInfo: 1000,
|
||||
PageSize: 5000,
|
||||
DefaultTimestamp: true,
|
||||
DriverName: defaultDriverName,
|
||||
DriverVersion: defaultDriverVersion,
|
||||
MaxWaitSchemaAgreement: 60 * time.Second,
|
||||
ReconnectInterval: 60 * time.Second,
|
||||
ConvictionPolicy: &SimpleConvictionPolicy{},
|
||||
ReconnectionPolicy: &ConstantReconnectionPolicy{MaxRetries: 3, Interval: 1 * time.Second},
|
||||
InitialReconnectionPolicy: &NoReconnectionPolicy{},
|
||||
SocketKeepalive: 15 * time.Second,
|
||||
WriteCoalesceWaitTime: 200 * time.Microsecond,
|
||||
MetadataSchemaRequestTimeout: 60 * time.Second,
|
||||
DisableSkipMetadata: true,
|
||||
WarningsHandlerBuilder: DefaultWarningHandlerBuilder,
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
func (cfg *ClusterConfig) logger() StdLogger {
|
||||
if cfg.Logger == nil {
|
||||
return Logger
|
||||
}
|
||||
return cfg.Logger
|
||||
}
|
||||
|
||||
// CreateSession initializes the cluster based on this config and returns a
// session object that can be used to interact with the database.
// Note the config is passed by value, so later mutations of cfg do not
// affect the created session.
func (cfg *ClusterConfig) CreateSession() (*Session, error) {
	return NewSession(*cfg)
}

// CreateSessionNonBlocking initializes the cluster based on this config via
// NewSessionNonBlocking. NOTE(review): presumably it returns before all host
// connection pools are fully established — confirm against NewSessionNonBlocking.
func (cfg *ClusterConfig) CreateSessionNonBlocking() (*Session, error) {
	return NewSessionNonBlocking(*cfg)
}
|
||||
|
||||
// translateAddressPort is a helper method that will use the given AddressTranslator
|
||||
// if defined, to translate the given address and port into a possibly new address
|
||||
// and port, If no AddressTranslator or if an error occurs, the given address and
|
||||
// port will be returned.
|
||||
func (cfg *ClusterConfig) translateAddressPort(addr net.IP, port int) (net.IP, int) {
|
||||
if cfg.AddressTranslator == nil || len(addr) == 0 {
|
||||
return addr, port
|
||||
}
|
||||
newAddr, newPort := cfg.AddressTranslator.Translate(addr, port)
|
||||
if gocqlDebug {
|
||||
cfg.logger().Printf("gocql: translating address '%v:%d' to '%v:%d'", addr, port, newAddr, newPort)
|
||||
}
|
||||
return newAddr, newPort
|
||||
}
|
||||
|
||||
func (cfg *ClusterConfig) filterHost(host *HostInfo) bool {
|
||||
return !(cfg.HostFilter == nil || cfg.HostFilter.Accept(host))
|
||||
}
|
||||
|
||||
func (cfg *ClusterConfig) ValidateAndInitSSL() error {
|
||||
if cfg.SslOpts == nil {
|
||||
return nil
|
||||
}
|
||||
actualTLSConfig, err := setupTLSConfig(cfg.SslOpts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize ssl configuration: %s", err.Error())
|
||||
}
|
||||
|
||||
cfg.actualSslOpts.Store(actualTLSConfig)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cfg *ClusterConfig) getActualTLSConfig() *tls.Config {
|
||||
val, ok := cfg.actualSslOpts.Load().(*tls.Config)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return val.Clone()
|
||||
}
|
||||
|
||||
func (cfg *ClusterConfig) Validate() error {
|
||||
if len(cfg.Hosts) == 0 {
|
||||
return ErrNoHosts
|
||||
}
|
||||
|
||||
if cfg.Authenticator != nil && cfg.AuthProvider != nil {
|
||||
return errors.New("Can't use both Authenticator and AuthProvider in cluster config.")
|
||||
}
|
||||
|
||||
if cfg.InitialReconnectionPolicy == nil {
|
||||
return errors.New("InitialReconnectionPolicy is nil")
|
||||
}
|
||||
|
||||
if cfg.InitialReconnectionPolicy.GetMaxRetries() <= 0 {
|
||||
return errors.New("InitialReconnectionPolicy.GetMaxRetries returns negative number")
|
||||
}
|
||||
|
||||
if cfg.ReconnectionPolicy == nil {
|
||||
return errors.New("ReconnectionPolicy is nil")
|
||||
}
|
||||
|
||||
if cfg.InitialReconnectionPolicy.GetMaxRetries() <= 0 {
|
||||
return errors.New("ReconnectionPolicy.GetMaxRetries returns negative number")
|
||||
}
|
||||
|
||||
if cfg.PageSize < 0 {
|
||||
return errors.New("PageSize should be positive number or zero")
|
||||
}
|
||||
|
||||
if cfg.MaxRoutingKeyInfo < 0 {
|
||||
return errors.New("MaxRoutingKeyInfo should be positive number or zero")
|
||||
}
|
||||
|
||||
if cfg.MaxPreparedStmts < 0 {
|
||||
return errors.New("MaxPreparedStmts should be positive number or zero")
|
||||
}
|
||||
|
||||
if cfg.SocketKeepalive < 0 {
|
||||
return errors.New("SocketKeepalive should be positive time.Duration or zero")
|
||||
}
|
||||
|
||||
if cfg.MaxRequestsPerConn < 0 {
|
||||
return errors.New("MaxRequestsPerConn should be positive number or zero")
|
||||
}
|
||||
|
||||
if cfg.NumConns < 0 {
|
||||
return errors.New("NumConns should be positive non-zero number or zero")
|
||||
}
|
||||
|
||||
if cfg.Port <= 0 || cfg.Port > 65535 {
|
||||
return errors.New("Port should be a valid port number: a number between 1 and 65535")
|
||||
}
|
||||
|
||||
if cfg.WriteTimeout < 0 {
|
||||
return errors.New("WriteTimeout should be positive time.Duration or zero")
|
||||
}
|
||||
|
||||
if cfg.Timeout < 0 {
|
||||
return errors.New("Timeout should be positive time.Duration or zero")
|
||||
}
|
||||
|
||||
if cfg.ConnectTimeout < 0 {
|
||||
return errors.New("ConnectTimeout should be positive time.Duration or zero")
|
||||
}
|
||||
|
||||
if cfg.MetadataSchemaRequestTimeout < 0 {
|
||||
return errors.New("MetadataSchemaRequestTimeout should be positive time.Duration or zero")
|
||||
}
|
||||
|
||||
if cfg.WriteCoalesceWaitTime < 0 {
|
||||
return errors.New("WriteCoalesceWaitTime should be positive time.Duration or zero")
|
||||
}
|
||||
|
||||
if cfg.ReconnectInterval < 0 {
|
||||
return errors.New("ReconnectInterval should be positive time.Duration or zero")
|
||||
}
|
||||
|
||||
if cfg.MaxWaitSchemaAgreement < 0 {
|
||||
return errors.New("MaxWaitSchemaAgreement should be positive time.Duration or zero")
|
||||
}
|
||||
|
||||
if cfg.ProtoVersion < 0 {
|
||||
return errors.New("ProtoVersion should be positive number or zero")
|
||||
}
|
||||
|
||||
if !cfg.DisableSkipMetadata {
|
||||
Logger.Println("warning: enabling skipping metadata can lead to unpredictible results when executing query and altering columns involved in the query.")
|
||||
}
|
||||
|
||||
return cfg.ValidateAndInitSSL()
|
||||
}
|
||||
|
||||
var (
	// ErrNoHosts is returned by Validate when ClusterConfig.Hosts is empty.
	ErrNoHosts = errors.New("no hosts provided")
	// ErrNoConnectionsStarted indicates that no connection could be opened
	// while creating the session.
	ErrNoConnectionsStarted = errors.New("no connections were made when creating the session")
	// ErrHostQueryFailed indicates that host discovery could not populate Hosts.
	ErrHostQueryFailed = errors.New("unable to populate Hosts")
)
|
||||
|
||||
func setupTLSConfig(sslOpts *SslOptions) (*tls.Config, error) {
|
||||
// Config.InsecureSkipVerify | EnableHostVerification | Result
|
||||
// Config is nil | true | verify host
|
||||
// Config is nil | false | do not verify host
|
||||
// false | false | verify host
|
||||
// true | false | do not verify host
|
||||
// false | true | verify host
|
||||
// true | true | verify host
|
||||
var tlsConfig *tls.Config
|
||||
if sslOpts.Config == nil {
|
||||
tlsConfig = &tls.Config{
|
||||
InsecureSkipVerify: !sslOpts.EnableHostVerification,
|
||||
}
|
||||
} else {
|
||||
// use clone to avoid race.
|
||||
tlsConfig = sslOpts.Config.Clone()
|
||||
}
|
||||
|
||||
if tlsConfig.InsecureSkipVerify && sslOpts.EnableHostVerification {
|
||||
tlsConfig.InsecureSkipVerify = false
|
||||
}
|
||||
|
||||
// ca cert is optional
|
||||
if sslOpts.CaPath != "" {
|
||||
if tlsConfig.RootCAs == nil {
|
||||
tlsConfig.RootCAs = x509.NewCertPool()
|
||||
}
|
||||
|
||||
pem, err := ioutil.ReadFile(sslOpts.CaPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to open CA certs: %v", err)
|
||||
}
|
||||
|
||||
if !tlsConfig.RootCAs.AppendCertsFromPEM(pem) {
|
||||
return nil, errors.New("failed parsing or CA certs")
|
||||
}
|
||||
}
|
||||
|
||||
if sslOpts.CertPath != "" || sslOpts.KeyPath != "" {
|
||||
mycert, err := tls.LoadX509KeyPair(sslOpts.CertPath, sslOpts.KeyPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to load X509 key pair: %v", err)
|
||||
}
|
||||
tlsConfig.Certificates = append(tlsConfig.Certificates, mycert)
|
||||
}
|
||||
|
||||
return tlsConfig, nil
|
||||
}
|
29
vendor/github.com/gocql/gocql/compressor.go
generated
vendored
Normal file
29
vendor/github.com/gocql/gocql/compressor.go
generated
vendored
Normal file
@@ -0,0 +1,29 @@
|
||||
package gocql
|
||||
|
||||
import (
|
||||
"github.com/klauspost/compress/s2"
|
||||
)
|
||||
|
||||
// Compressor defines a frame-body compression algorithm: Name identifies the
// algorithm, Encode compresses a payload and Decode reverses it.
type Compressor interface {
	Name() string
	Encode(data []byte) ([]byte, error)
	Decode(data []byte) ([]byte, error)
}

// SnappyCompressor implements the Compressor interface and can be used to
// compress incoming and outgoing frames. It uses S2 compression algorithm
// that is compatible with snappy and aims for high throughput, which is why
// it features concurrent compression for bigger payloads.
type SnappyCompressor struct{}

// Name returns the identifier of this compression algorithm ("snappy").
func (s SnappyCompressor) Name() string {
	return "snappy"
}

// Encode compresses data into a snappy-compatible format. The s2 encoder
// never returns an error here.
func (s SnappyCompressor) Encode(data []byte) ([]byte, error) {
	return s2.EncodeSnappy(nil, data), nil
}

// Decode decompresses a snappy-compressed payload.
func (s SnappyCompressor) Decode(data []byte) ([]byte, error) {
	return s2.Decode(nil, data)
}
|
1964
vendor/github.com/gocql/gocql/conn.go
generated
vendored
Normal file
1964
vendor/github.com/gocql/gocql/conn.go
generated
vendored
Normal file
File diff suppressed because it is too large
Load Diff
562
vendor/github.com/gocql/gocql/connectionpool.go
generated
vendored
Normal file
562
vendor/github.com/gocql/gocql/connectionpool.go
generated
vendored
Normal file
@@ -0,0 +1,562 @@
|
||||
// Copyright (c) 2012 The gocql Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package gocql
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gocql/gocql/debounce"
|
||||
)
|
||||
|
||||
// SetHosts is the interface to implement to receive the host information.
type SetHosts interface {
	SetHosts(hosts []*HostInfo)
}

// SetPartitioner is the interface to implement to receive the partitioner value.
type SetPartitioner interface {
	SetPartitioner(partitioner string)
}

// SetTablets is the interface to implement to receive the tablets value.
type SetTablets interface {
	SetTablets(tablets TabletInfoList)
}
|
||||
|
||||
// policyConnPool is the session-wide connection pool: it owns one
// hostConnPool per host, keyed by host ID.
type policyConnPool struct {
	session *Session

	// Settings copied from the session's ClusterConfig at construction time.
	port     int
	numConns int
	keyspace string

	// mu guards hostConnPools.
	mu            sync.RWMutex
	hostConnPools map[string]*hostConnPool
}
|
||||
|
||||
// connConfig derives the per-connection configuration from the cluster
// config. When no HostDialer is supplied it builds the default dial chain:
// a net.Dialer (honoring ConnectTimeout and SocketKeepalive) wrapped in a
// ScyllaShardAwareDialer, wrapped in a scyllaDialer.
func connConfig(cfg *ClusterConfig) (*ConnConfig, error) {
	hostDialer := cfg.HostDialer

	if hostDialer == nil {
		dialer := cfg.Dialer
		if dialer == nil {
			d := net.Dialer{
				Timeout: cfg.ConnectTimeout,
			}
			if cfg.SocketKeepalive > 0 {
				d.KeepAlive = cfg.SocketKeepalive
			}
			dialer = &ScyllaShardAwareDialer{d}
		}

		hostDialer = &scyllaDialer{
			dialer:    dialer,
			logger:    cfg.logger(),
			tlsConfig: cfg.getActualTLSConfig(),
			cfg:       cfg,
		}
	}

	return &ConnConfig{
		ProtoVersion:   cfg.ProtoVersion,
		CQLVersion:     cfg.CQLVersion,
		Timeout:        cfg.Timeout,
		WriteTimeout:   cfg.WriteTimeout,
		ConnectTimeout: cfg.ConnectTimeout,
		Dialer:         cfg.Dialer,
		HostDialer:     hostDialer,
		Compressor:     cfg.Compressor,
		Authenticator:  cfg.Authenticator,
		AuthProvider:   cfg.AuthProvider,
		Keepalive:      cfg.SocketKeepalive,
		Logger:         cfg.logger(),
		tlsConfig:      cfg.getActualTLSConfig(),
	}, nil
}
|
||||
|
||||
func newPolicyConnPool(session *Session) *policyConnPool {
|
||||
// create the pool
|
||||
pool := &policyConnPool{
|
||||
session: session,
|
||||
port: session.cfg.Port,
|
||||
numConns: session.cfg.NumConns,
|
||||
keyspace: session.cfg.Keyspace,
|
||||
hostConnPools: map[string]*hostConnPool{},
|
||||
}
|
||||
|
||||
return pool
|
||||
}
|
||||
|
||||
// SetHosts reconciles the per-host pools with the given host list: pools are
// created concurrently for new UP hosts, kept for hosts still present, and
// closed for hosts that disappeared. Runs under the write lock for the whole
// reconciliation.
func (p *policyConnPool) SetHosts(hosts []*HostInfo) {
	p.mu.Lock()
	defer p.mu.Unlock()

	// Assume every existing pool is stale until its host shows up again.
	toRemove := make(map[string]struct{})
	for hostID := range p.hostConnPools {
		toRemove[hostID] = struct{}{}
	}

	pools := make(chan *hostConnPool)
	createCount := 0
	for _, host := range hosts {
		if !host.IsUp() {
			// don't create a connection pool for a down host
			continue
		}
		hostID := host.HostID()
		if _, exists := p.hostConnPools[hostID]; exists {
			// still have this host, so don't remove it
			delete(toRemove, hostID)
			continue
		}

		createCount++
		go func(host *HostInfo) {
			// create a connection pool for the host
			pools <- newHostConnPool(
				p.session,
				host,
				p.port,
				p.numConns,
				p.keyspace,
			)
		}(host)
	}

	// add created pools; drain exactly as many results as goroutines started
	for createCount > 0 {
		pool := <-pools
		createCount--
		if pool.Size() > 0 {
			// add pool only if there are connections available
			p.hostConnPools[pool.host.HostID()] = pool
		}
	}

	// close pools for hosts that are gone; Close runs outside this goroutine
	for addr := range toRemove {
		pool := p.hostConnPools[addr]
		delete(p.hostConnPools, addr)
		go pool.Close()
	}
}
|
||||
|
||||
func (p *policyConnPool) InFlight() int {
|
||||
p.mu.RLock()
|
||||
count := 0
|
||||
for _, pool := range p.hostConnPools {
|
||||
count += pool.InFlight()
|
||||
}
|
||||
p.mu.RUnlock()
|
||||
|
||||
return count
|
||||
}
|
||||
|
||||
func (p *policyConnPool) Size() int {
|
||||
p.mu.RLock()
|
||||
count := 0
|
||||
for _, pool := range p.hostConnPools {
|
||||
count += pool.Size()
|
||||
}
|
||||
p.mu.RUnlock()
|
||||
|
||||
return count
|
||||
}
|
||||
|
||||
func (p *policyConnPool) getPool(host *HostInfo) (pool *hostConnPool, ok bool) {
|
||||
hostID := host.HostID()
|
||||
p.mu.RLock()
|
||||
pool, ok = p.hostConnPools[hostID]
|
||||
p.mu.RUnlock()
|
||||
return
|
||||
}
|
||||
|
||||
func (p *policyConnPool) Close() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
// close the pools
|
||||
for addr, pool := range p.hostConnPools {
|
||||
delete(p.hostConnPools, addr)
|
||||
pool.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (p *policyConnPool) addHost(host *HostInfo) {
|
||||
hostID := host.HostID()
|
||||
p.mu.Lock()
|
||||
pool, ok := p.hostConnPools[hostID]
|
||||
if !ok {
|
||||
pool = newHostConnPool(
|
||||
p.session,
|
||||
host,
|
||||
host.Port(), // TODO: if port == 0 use pool.port?
|
||||
p.numConns,
|
||||
p.keyspace,
|
||||
)
|
||||
|
||||
p.hostConnPools[hostID] = pool
|
||||
}
|
||||
p.mu.Unlock()
|
||||
|
||||
pool.fill_debounce()
|
||||
}
|
||||
|
||||
func (p *policyConnPool) removeHost(hostID string) {
|
||||
p.mu.Lock()
|
||||
pool, ok := p.hostConnPools[hostID]
|
||||
if !ok {
|
||||
p.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
delete(p.hostConnPools, hostID)
|
||||
p.mu.Unlock()
|
||||
|
||||
go pool.Close()
|
||||
}
|
||||
|
||||
// hostConnPool is a connection pool for a single host.
// Connection selection is based on a provided ConnSelectionPolicy
type hostConnPool struct {
	session *Session
	host    *HostInfo
	// size is the target connection count for this pool (NumConns at creation).
	size     int
	keyspace string
	// protection for connPicker, closed, filling
	mu         sync.RWMutex
	connPicker ConnPicker
	closed     bool
	filling    bool
	// debouncer coalesces concurrent fill requests (see fill_debounce).
	debouncer *debounce.SimpleDebouncer

	logger StdLogger
}
|
||||
|
||||
func (h *hostConnPool) String() string {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
size, _ := h.connPicker.Size()
|
||||
return fmt.Sprintf("[filling=%v closed=%v conns=%v size=%v host=%v]",
|
||||
h.filling, h.closed, size, h.size, h.host)
|
||||
}
|
||||
|
||||
func newHostConnPool(session *Session, host *HostInfo, port, size int,
|
||||
keyspace string) *hostConnPool {
|
||||
|
||||
pool := &hostConnPool{
|
||||
session: session,
|
||||
host: host,
|
||||
size: size,
|
||||
keyspace: keyspace,
|
||||
connPicker: nopConnPicker{},
|
||||
filling: false,
|
||||
closed: false,
|
||||
logger: session.logger,
|
||||
debouncer: debounce.NewSimpleDebouncer(),
|
||||
}
|
||||
|
||||
// the pool is not filled or connected
|
||||
return pool
|
||||
}
|
||||
|
||||
// Pick a connection from this connection pool for the given query.
|
||||
func (pool *hostConnPool) Pick(token Token, qry ExecutableQuery) *Conn {
|
||||
pool.mu.RLock()
|
||||
defer pool.mu.RUnlock()
|
||||
|
||||
if pool.closed {
|
||||
return nil
|
||||
}
|
||||
|
||||
size, missing := pool.connPicker.Size()
|
||||
if missing > 0 {
|
||||
// try to fill the pool
|
||||
go pool.fill_debounce()
|
||||
|
||||
if size == 0 {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return pool.connPicker.Pick(token, qry)
|
||||
}
|
||||
|
||||
// Size returns the number of connections currently active in the pool
|
||||
func (pool *hostConnPool) Size() int {
|
||||
pool.mu.RLock()
|
||||
defer pool.mu.RUnlock()
|
||||
|
||||
size, _ := pool.connPicker.Size()
|
||||
return size
|
||||
}
|
||||
|
||||
// InFlight returns the number of in-flight requests on this pool's
// connections. (The previous comment incorrectly described Size.)
func (pool *hostConnPool) InFlight() int {
	pool.mu.RLock()
	defer pool.mu.RUnlock()

	size := pool.connPicker.InFlight()
	return size
}
|
||||
|
||||
// Close the connection pool
|
||||
func (pool *hostConnPool) Close() {
|
||||
pool.mu.Lock()
|
||||
defer pool.mu.Unlock()
|
||||
|
||||
if !pool.closed {
|
||||
pool.connPicker.Close()
|
||||
}
|
||||
pool.closed = true
|
||||
}
|
||||
|
||||
// Fill the connection pool up to its target size. At most one fill runs
// at a time (guarded by pool.filling). The first connection is made
// synchronously so the caller learns quickly whether the host is
// reachable; the remainder are opened in a background goroutine which
// ends the fill via fillingStopped.
func (pool *hostConnPool) fill() {
	pool.mu.RLock()
	// avoid filling a closed pool, or concurrent filling
	if pool.closed || pool.filling {
		pool.mu.RUnlock()
		return
	}

	// determine the filling work to be done
	startCount, fillCount := pool.connPicker.Size()

	// avoid filling a full (or overfull) pool
	if fillCount <= 0 {
		pool.mu.RUnlock()
		return
	}

	// switch from read to write lock
	pool.mu.RUnlock()
	pool.mu.Lock()

	// re-check everything under the write lock: state may have changed
	// in the window between RUnlock and Lock above
	startCount, fillCount = pool.connPicker.Size()
	if pool.closed || pool.filling || fillCount <= 0 {
		// looks like another goroutine already beat this
		// goroutine to the filling
		pool.mu.Unlock()
		return
	}

	// ok fill the pool
	pool.filling = true

	// allow others to access the pool while filling
	pool.mu.Unlock()
	// only this goroutine should make calls to fill/empty the pool at this
	// point until after this routine or its subordinates calls
	// fillingStopped

	// fill only the first connection synchronously
	if startCount == 0 {
		err := pool.connect()
		pool.logConnectErr(err)

		if err != nil {
			// probably unreachable host
			pool.fillingStopped(err)
			return
		}
		// notify the session that this node is connected
		go pool.session.handleNodeConnected(pool.host)

		// filled one, let's reload it to see if it has changed
		pool.mu.RLock()
		_, fillCount = pool.connPicker.Size()
		pool.mu.RUnlock()
	}

	// fill the rest of the pool asynchronously
	go func() {
		err := pool.connectMany(fillCount)

		// mark the end of filling
		pool.fillingStopped(err)

		if err == nil && startCount > 0 {
			// notify the session that this node is connected again
			go pool.session.handleNodeConnected(pool.host)
		}
	}()
}
|
||||
|
||||
// fill_debounce runs fill through the debouncer so that many concurrent
// requests collapse into at most one running and one queued fill.
func (pool *hostConnPool) fill_debounce() {
	pool.debouncer.Debounce(pool.fill)
}
|
||||
|
||||
func (pool *hostConnPool) logConnectErr(err error) {
|
||||
if opErr, ok := err.(*net.OpError); ok && (opErr.Op == "dial" || opErr.Op == "read") {
|
||||
// connection refused
|
||||
// these are typical during a node outage so avoid log spam.
|
||||
if gocqlDebug {
|
||||
pool.logger.Printf("unable to dial %q: %v\n", pool.host, err)
|
||||
}
|
||||
} else if err != nil {
|
||||
// unexpected error
|
||||
pool.logger.Printf("error: failed to connect to %q due to error: %v", pool.host, err)
|
||||
}
|
||||
}
|
||||
|
||||
// transition back to a not-filling state.
// On error it sleeps a small random amount first (backoff between
// failed fill attempts), and if the pool ended up empty it reports the
// failure to the conviction policy, possibly marking the host down.
func (pool *hostConnPool) fillingStopped(err error) {
	if err != nil {
		if gocqlDebug {
			pool.logger.Printf("gocql: filling stopped %q: %v\n", pool.host.ConnectAddress(), err)
		}
		// wait for some time to avoid back-to-back filling
		// this provides some time between failed attempts
		// to fill the pool for the host to recover
		time.Sleep(time.Duration(rand.Int31n(100)+31) * time.Millisecond)
	}

	pool.mu.Lock()
	pool.filling = false
	count, _ := pool.connPicker.Size()
	host := pool.host
	port := pool.host.Port()
	pool.mu.Unlock()

	// if we errored and the size is now zero, make sure the host is marked as down
	// see https://github.com/gocql/gocql/issues/1614
	if gocqlDebug {
		pool.logger.Printf("gocql: conns of pool after stopped %q: %v\n", host.ConnectAddress(), count)
	}
	if err != nil && count == 0 {
		if pool.session.cfg.ConvictionPolicy.AddFailure(err, host) {
			pool.session.handleNodeDown(host.ConnectAddress(), port)
		}
	}
}
|
||||
|
||||
// connectMany creates new connections concurrent.
// It blocks until all count attempts finish and returns one of the
// observed errors (the last writer wins), or nil when every attempt
// succeeded. Individual failures are logged as they happen.
func (pool *hostConnPool) connectMany(count int) error {
	if count == 0 {
		return nil
	}
	var (
		wg         sync.WaitGroup
		mu         sync.Mutex // guards connectErr across the worker goroutines
		connectErr error
	)
	wg.Add(count)
	for i := 0; i < count; i++ {
		go func() {
			defer wg.Done()
			err := pool.connect()
			pool.logConnectErr(err)
			if err != nil {
				mu.Lock()
				connectErr = err
				mu.Unlock()
			}
		}()
	}
	// wait for all connections are done
	wg.Wait()

	return connectErr
}
|
||||
|
||||
// create a new connection to the host and add it to the pool.
// Retries per the session's ReconnectionPolicy, stopping early on
// non-temporary network errors; applies the pool keyspace to the new
// connection; and finally installs the connection in the picker (or
// closes it again if the pool was closed in the meantime).
func (pool *hostConnPool) connect() (err error) {
	pool.mu.Lock()
	shardID, nrShards := pool.connPicker.NextShard()
	pool.mu.Unlock()

	// TODO: provide a more robust connection retry mechanism, we should also
	// be able to detect hosts that come up by trying to connect to downed ones.
	// try to connect
	var conn *Conn
	reconnectionPolicy := pool.session.cfg.ReconnectionPolicy
	for i := 0; i < reconnectionPolicy.GetMaxRetries(); i++ {
		conn, err = pool.session.connectShard(pool.session.ctx, pool.host, pool, shardID, nrShards)
		if err == nil {
			break
		}
		if opErr, isOpErr := err.(*net.OpError); isOpErr {
			// if the error is not a temporary error (ex: network unreachable) don't
			// retry
			if !opErr.Temporary() {
				break
			}
		}
		if gocqlDebug {
			pool.logger.Printf("gocql: connection failed %q: %v, reconnecting with %T\n",
				pool.host.ConnectAddress(), err, reconnectionPolicy)
		}
		time.Sleep(reconnectionPolicy.GetInterval(i))
	}

	if err != nil {
		return err
	}

	if pool.keyspace != "" {
		// set the keyspace
		if err = conn.UseKeyspace(pool.keyspace); err != nil {
			conn.Close()
			return err
		}
	}

	// add the Conn to the pool
	pool.mu.Lock()
	defer pool.mu.Unlock()

	if pool.closed {
		// pool was closed while we were connecting; drop the conn
		conn.Close()
		return nil
	}

	// lazily initialize the connPicker when we know the required type
	pool.initConnPicker(conn)
	pool.connPicker.Put(conn)

	return nil
}
|
||||
|
||||
func (pool *hostConnPool) initConnPicker(conn *Conn) {
|
||||
if _, ok := pool.connPicker.(nopConnPicker); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if conn.isScyllaConn() {
|
||||
pool.connPicker = newScyllaConnPicker(conn)
|
||||
return
|
||||
}
|
||||
|
||||
pool.connPicker = newDefaultConnPicker(pool.size)
|
||||
}
|
||||
|
||||
// handle any error from a Conn.
// Only acts when the connection is already closed: removes it from the
// picker and schedules a debounced refill. Open connections keep being
// used.
func (pool *hostConnPool) HandleError(conn *Conn, err error, closed bool) {
	if !closed {
		// still an open connection, so continue using it
		return
	}

	// TODO: track the number of errors per host and detect when a host is dead,
	// then also have something which can detect when a host comes back.
	pool.mu.Lock()
	defer pool.mu.Unlock()

	if pool.closed {
		// pool closed
		return
	}

	if gocqlDebug {
		pool.logger.Printf("gocql: pool connection error %q: %v\n", conn.addr, err)
	}

	pool.connPicker.Remove(conn)
	go pool.fill_debounce()
}
|
140
vendor/github.com/gocql/gocql/connpicker.go
generated
vendored
Normal file
140
vendor/github.com/gocql/gocql/connpicker.go
generated
vendored
Normal file
@@ -0,0 +1,140 @@
|
||||
package gocql
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// ConnPicker abstracts the strategy used to choose a connection from a
// host's pool (round-robin by default; Scylla installs a shard-aware
// picker).
type ConnPicker interface {
	// Pick returns a connection to run the query on, or nil if none is usable.
	Pick(Token, ExecutableQuery) *Conn
	// Put adds a newly established connection.
	Put(*Conn)
	// Remove discards a connection from the picker.
	Remove(conn *Conn)
	// InFlight reports in-flight work; semantics are implementation-defined.
	InFlight() int
	// Size returns (current connection count, connections still missing
	// to reach the target).
	Size() (int, int)
	// Close closes all held connections.
	Close()

	// NextShard returns the shardID to connect to.
	// nrShard specifies how many shards the host has.
	// If nrShards is zero, the caller shouldn't use shard-aware port.
	NextShard() (shardID, nrShards int)
}
|
||||
|
||||
// defaultConnPicker is the non-shard-aware picker: it scans from a
// round-robin position and picks the connection with the most available
// streams (see Pick).
type defaultConnPicker struct {
	conns []*Conn
	pos   uint32 // round-robin cursor, advanced atomically in Pick
	size  int    // target number of connections
	mu    sync.RWMutex
}
|
||||
|
||||
func newDefaultConnPicker(size int) *defaultConnPicker {
|
||||
if size <= 0 {
|
||||
panic(fmt.Sprintf("invalid pool size %d", size))
|
||||
}
|
||||
return &defaultConnPicker{
|
||||
size: size,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *defaultConnPicker) Remove(conn *Conn) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
for i, candidate := range p.conns {
|
||||
if candidate == conn {
|
||||
last := len(p.conns) - 1
|
||||
p.conns[i], p.conns = p.conns[last], p.conns[:last]
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *defaultConnPicker) Close() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
conns := p.conns
|
||||
p.conns = nil
|
||||
for _, conn := range conns {
|
||||
if conn != nil {
|
||||
conn.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// InFlight reports activity for this picker.
// NOTE(review): despite the name, this returns the connection count,
// not a count of in-flight requests. p.conns is also read without
// taking p.mu — presumably safe because all calls go through
// hostConnPool, which serializes on its own mutex; confirm.
func (p *defaultConnPicker) InFlight() int {
	size := len(p.conns)
	return size
}
|
||||
|
||||
// Size returns the current connection count and how many connections
// are still missing to reach the target size.
// NOTE(review): p.conns is read without p.mu — presumably serialized by
// hostConnPool.mu at the call sites; confirm.
func (p *defaultConnPicker) Size() (int, int) {
	size := len(p.conns)
	return size, p.size - size
}
|
||||
|
||||
// Pick scans all connections starting from a round-robin position and
// returns the one with the most available streams, or nil if none has
// any. The scan deliberately runs without the lock (see in-line note).
func (p *defaultConnPicker) Pick(Token, ExecutableQuery) *Conn {
	// advance the round-robin cursor; only the starting offset matters
	pos := int(atomic.AddUint32(&p.pos, 1) - 1)
	size := len(p.conns)

	var (
		leastBusyConn    *Conn
		streamsAvailable int
	)

	// find the conn which has the most available streams, this is racy
	for i := 0; i < size; i++ {
		conn := p.conns[(pos+i)%size]
		if conn == nil {
			continue
		}
		if streams := conn.AvailableStreams(); streams > streamsAvailable {
			leastBusyConn = conn
			streamsAvailable = streams
		}
	}

	return leastBusyConn
}
|
||||
|
||||
func (p *defaultConnPicker) Put(conn *Conn) {
|
||||
p.mu.Lock()
|
||||
p.conns = append(p.conns, conn)
|
||||
p.mu.Unlock()
|
||||
}
|
||||
|
||||
// NextShard reports no shard awareness: nrShards == 0 tells the caller
// not to use the shard-aware port.
func (*defaultConnPicker) NextShard() (shardID, nrShards int) {
	return 0, 0
}
|
||||
|
||||
// nopConnPicker is a no-operation implementation of ConnPicker, it's used when
// hostConnPool is created to allow deferring creation of the actual ConnPicker
// to the point where we have first connection.
type nopConnPicker struct{}

// Pick always returns nil: there are no connections yet.
func (nopConnPicker) Pick(Token, ExecutableQuery) *Conn {
	return nil
}

// Put is a no-op; a real picker is installed before connections are stored.
func (nopConnPicker) Put(*Conn) {
}

// Remove is a no-op: nothing is ever held.
func (nopConnPicker) Remove(conn *Conn) {
}

// InFlight reports no activity.
func (nopConnPicker) InFlight() int {
	return 0
}

func (nopConnPicker) Size() (int, int) {
	// Return 1 to make hostConnPool to try to establish a connection.
	// When first connection is established hostConnPool replaces nopConnPicker
	// with a different ConnPicker implementation.
	return 0, 1
}

// Close has no resources to release.
func (nopConnPicker) Close() {
}

// NextShard reports no shard awareness (nrShards == 0).
func (nopConnPicker) NextShard() (shardID, nrShards int) {
	return 0, 0
}
|
517
vendor/github.com/gocql/gocql/control.go
generated
vendored
Normal file
517
vendor/github.com/gocql/gocql/control.go
generated
vendored
Normal file
@@ -0,0 +1,517 @@
|
||||
package gocql
|
||||
|
||||
import (
|
||||
"context"
|
||||
crand "crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"os"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// randr is a package-private PRNG (seeded in init from crypto/rand),
// guarded by mutRandr because *rand.Rand is not safe for concurrent use.
var (
	randr    *rand.Rand
	mutRandr sync.Mutex
)
|
||||
|
||||
// init seeds the package PRNG with 4 bytes of cryptographically secure
// entropy; panics if the system entropy source fails.
func init() {
	b := make([]byte, 4)
	if _, err := crand.Read(b); err != nil {
		panic(fmt.Sprintf("unable to seed random number generator: %v", err))
	}

	randr = rand.New(rand.NewSource(int64(readInt(b))))
}
|
||||
|
||||
// Lifecycle states stored in controlConn.state (manipulated atomically;
// see heartBeat and close).
const (
	controlConnStarting = 0
	controlConnStarted  = 1
	controlConnClosing  = -1
)
|
||||
|
||||
// controlConnection is the interface satisfied by controlConn so the
// control connection can be substituted by another implementation.
type controlConnection interface {
	getConn() *connHost
	awaitSchemaAgreement() error
	query(statement string, values ...interface{}) (iter *Iter)
	discoverProtocol(hosts []*HostInfo) (int, error)
	connect(hosts []*HostInfo) error
	close()
	getSession() *Session
}
|
||||
|
||||
// Ensure that the atomic variable is aligned to a 64bit boundary
// so that atomic operations can be applied on 32bit architectures.
//
// controlConn maintains the driver's single control connection, used
// for cluster events, metadata queries and heartbeats.
type controlConn struct {
	state        int32 // controlConnStarting/Started/Closing, accessed atomically
	reconnecting int32 // CAS guard so only one reconnect runs at a time

	session *Session
	conn    atomic.Value // always holds a (possibly nil) *connHost

	retry RetryPolicy

	quit chan struct{} // signals heartBeat to stop
}
|
||||
|
||||
// getSession returns the owning session (part of controlConnection).
func (c *controlConn) getSession() *Session {
	return c.session
}
|
||||
|
||||
func createControlConn(session *Session) *controlConn {
|
||||
|
||||
control := &controlConn{
|
||||
session: session,
|
||||
quit: make(chan struct{}),
|
||||
retry: &SimpleRetryPolicy{NumRetries: 3},
|
||||
}
|
||||
|
||||
control.conn.Store((*connHost)(nil))
|
||||
|
||||
return control
|
||||
}
|
||||
|
||||
// heartBeat pings the control connection with an OPTIONS frame,
// normally every 30 seconds; after a failure it reconnects and retries
// on a 1-second cadence. The CAS ensures only one heartbeat loop ever
// starts; close() flips state to controlConnClosing and signals quit.
func (c *controlConn) heartBeat() {
	if !atomic.CompareAndSwapInt32(&c.state, controlConnStarting, controlConnStarted) {
		return
	}

	sleepTime := 1 * time.Second
	timer := time.NewTimer(sleepTime)
	defer timer.Stop()

	for {
		timer.Reset(sleepTime)

		select {
		case <-c.quit:
			return
		case <-timer.C:
		}

		resp, err := c.writeFrame(&writeOptionsFrame{})
		if err != nil {
			goto reconn
		}

		switch resp.(type) {
		case *supportedFrame:
			// Everything ok
			sleepTime = 30 * time.Second
			continue
		case error:
			goto reconn
		default:
			panic(fmt.Sprintf("gocql: unknown frame in response to options: %T", resp))
		}

	reconn:
		// try to connect a bit faster
		sleepTime = 1 * time.Second
		c.reconnect()
		continue
	}
}
|
||||
|
||||
// hostLookupPreferV4 makes hostInfo prefer IPv4 DNS results when the
// GOCQL_HOST_LOOKUP_PREFER_V4 environment variable is "true".
var hostLookupPreferV4 = os.Getenv("GOCQL_HOST_LOOKUP_PREFER_V4") == "true"
|
||||
|
||||
// hostInfo resolves addr ("host", "host:port", literal IP with or
// without port) into one HostInfo per resolved IP, using defaultPort
// when addr carries no port. DNS results may be filtered to IPv4 when
// hostLookupPreferV4 is set.
func hostInfo(addr string, defaultPort int) ([]*HostInfo, error) {
	var port int
	host, portStr, err := net.SplitHostPort(addr)
	if err != nil {
		// no port in addr; treat the whole string as the host
		host = addr
		port = defaultPort
	} else {
		port, err = strconv.Atoi(portStr)
		if err != nil {
			return nil, err
		}
	}

	var hosts []*HostInfo

	// Check if host is a literal IP address
	if ip := net.ParseIP(host); ip != nil {
		hosts = append(hosts, &HostInfo{hostname: host, connectAddress: ip, port: port})
		return hosts, nil
	}

	// Look up host in DNS
	ips, err := LookupIP(host)
	if err != nil {
		return nil, err
	} else if len(ips) == 0 {
		return nil, fmt.Errorf("no IP's returned from DNS lookup for %q", addr)
	}

	// Filter to v4 addresses if any present
	if hostLookupPreferV4 {
		var preferredIPs []net.IP
		for _, v := range ips {
			if v4 := v.To4(); v4 != nil {
				preferredIPs = append(preferredIPs, v4)
			}
		}
		// fall back to the unfiltered list when no v4 address exists
		if len(preferredIPs) != 0 {
			ips = preferredIPs
		}
	}

	for _, ip := range ips {
		hosts = append(hosts, &HostInfo{hostname: host, connectAddress: ip, port: port})
	}

	return hosts, nil
}
|
||||
|
||||
func shuffleHosts(hosts []*HostInfo) []*HostInfo {
|
||||
shuffled := make([]*HostInfo, len(hosts))
|
||||
copy(shuffled, hosts)
|
||||
|
||||
mutRandr.Lock()
|
||||
randr.Shuffle(len(hosts), func(i, j int) {
|
||||
shuffled[i], shuffled[j] = shuffled[j], shuffled[i]
|
||||
})
|
||||
mutRandr.Unlock()
|
||||
|
||||
return shuffled
|
||||
}
|
||||
|
||||
// this is going to be version dependant and a nightmare to maintain :(
// Captures the greatest protocol version number advertised in the
// server's unsupported-protocol error message.
var protocolSupportRe = regexp.MustCompile(`the lowest supported version is \d+ and the greatest is (\d+)$`)
|
||||
|
||||
// parseProtocolFromError extracts the server's maximum supported
// protocol version from an unsupported-protocol error, falling back to
// the version carried in the error frame header; returns 0 when neither
// is available.
func parseProtocolFromError(err error) int {
	// I really wish this had the actual info in the error frame...
	matches := protocolSupportRe.FindAllStringSubmatch(err.Error(), -1)
	if len(matches) != 1 || len(matches[0]) != 2 {
		if verr, ok := err.(*protocolError); ok {
			return int(verr.frame.Header().version.version())
		}
		return 0
	}

	max, err := strconv.Atoi(matches[0][1])
	if err != nil {
		return 0
	}

	return max
}
|
||||
|
||||
// discoverProtocol probes the (shuffled) hosts with protocol version 4
// and returns either that version on a successful dial or the version
// the server reports in its rejection error; returns (0, err) if no
// host yields either.
func (c *controlConn) discoverProtocol(hosts []*HostInfo) (int, error) {
	hosts = shuffleHosts(hosts)

	connCfg := *c.session.connCfg
	connCfg.ProtoVersion = 4 // TODO: define maxProtocol

	handler := connErrorHandlerFn(func(c *Conn, err error, closed bool) {
		// we should never get here, but if we do it means we connected to a
		// host successfully which means our attempted protocol version worked
		if !closed {
			c.Close()
		}
	})

	var err error
	for _, host := range hosts {
		var conn *Conn
		conn, err = c.session.dial(c.session.ctx, host, &connCfg, handler)
		if conn != nil {
			// probe connection only; always close it
			conn.Close()
		}

		if err == nil {
			return connCfg.ProtoVersion, nil
		}

		if proto := parseProtocolFromError(err); proto > 0 {
			return proto, nil
		}
	}

	return 0, err
}
|
||||
|
||||
// connect establishes the control connection against the first host (in
// shuffled order) that both dials and completes setupConn, then starts
// the heartbeat loop. Returns an error only when every host fails.
func (c *controlConn) connect(hosts []*HostInfo) error {
	if len(hosts) == 0 {
		return errors.New("control: no endpoints specified")
	}

	// shuffle endpoints so not all drivers will connect to the same initial
	// node.
	hosts = shuffleHosts(hosts)

	cfg := *c.session.connCfg
	cfg.disableCoalesce = true

	var conn *Conn
	var err error
	for _, host := range hosts {
		conn, err = c.session.dial(c.session.ctx, host, &cfg, c)
		if err != nil {
			c.session.logger.Printf("gocql: unable to dial control conn %v:%v: %v\n", host.ConnectAddress(), host.Port(), err)
			continue
		}
		err = c.setupConn(conn)
		if err == nil {
			break
		}
		c.session.logger.Printf("gocql: unable setup control conn %v:%v: %v\n", host.ConnectAddress(), host.Port(), err)
		conn.Close()
		conn = nil
	}
	if conn == nil {
		return fmt.Errorf("unable to connect to initial hosts: %v", err)
	}

	// we could fetch the initial ring here and update initial host data. So that
	// when we return from here we have a ring topology ready to go.

	go c.heartBeat()

	return nil
}
|
||||
|
||||
// connHost pairs the control connection with the host it was
// established against; a *connHost is what controlConn.conn stores.
type connHost struct {
	conn ConnInterface
	host *HostInfo
}
|
||||
|
||||
// setupConn finishes control-connection setup on a freshly dialed conn:
// refreshes the host's metadata from system.local, applies the host
// filter, registers for server events, and publishes the connection via
// the atomic c.conn.
func (c *controlConn) setupConn(conn *Conn) error {
	// we need up-to-date host info for the filterHost call below
	iter := conn.querySystem(context.TODO(), qrySystemLocal)
	defaultPort := 9042
	if tcpAddr, ok := conn.conn.RemoteAddr().(*net.TCPAddr); ok {
		defaultPort = tcpAddr.Port
	}
	host, err := hostInfoFromIter(iter, conn.host.connectAddress, defaultPort, c.session.cfg.translateAddressPort)
	if err != nil {
		return err
	}

	host = c.session.hostSource.addOrUpdate(host)

	if c.session.cfg.filterHost(host) {
		return fmt.Errorf("host was filtered: %v", host.ConnectAddress())
	}

	if err := c.registerEvents(conn); err != nil {
		return fmt.Errorf("register events: %v", err)
	}

	ch := &connHost{
		conn: conn,
		host: host,
	}

	c.conn.Store(ch)
	if c.session.initialized() {
		// We connected to control conn, so add the connect the host in pool as well.
		// Notify session we can start trying to connect to the node.
		// We can't start the fill before the session is initialized, otherwise the fill would interfere
		// with the fill called by Session.init. Session.init needs to wait for its fill to finish and that
		// would return immediately if we started the fill here.
		// TODO(martin-sucha): Trigger pool refill for all hosts, like in reconnectDownedHosts?
		go c.session.startPoolFill(host)
	}
	return nil
}
|
||||
|
||||
func (c *controlConn) registerEvents(conn *Conn) error {
|
||||
var events []string
|
||||
|
||||
if !c.session.cfg.Events.DisableTopologyEvents {
|
||||
events = append(events, "TOPOLOGY_CHANGE")
|
||||
}
|
||||
if !c.session.cfg.Events.DisableNodeStatusEvents {
|
||||
events = append(events, "STATUS_CHANGE")
|
||||
}
|
||||
if !c.session.cfg.Events.DisableSchemaEvents {
|
||||
events = append(events, "SCHEMA_CHANGE")
|
||||
}
|
||||
|
||||
if len(events) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
framer, err := conn.exec(context.Background(),
|
||||
&writeRegisterFrame{
|
||||
events: events,
|
||||
}, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
frame, err := framer.parseFrame()
|
||||
if err != nil {
|
||||
return err
|
||||
} else if _, ok := frame.(*readyFrame); !ok {
|
||||
return fmt.Errorf("unexpected frame in response to register: got %T: %v\n", frame, frame)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// reconnect re-establishes the control connection and then refreshes
// the ring and schema metadata. A CAS on c.reconnecting ensures only
// one reconnect runs at a time; it is a no-op while the control
// connection is closing.
func (c *controlConn) reconnect() {
	if atomic.LoadInt32(&c.state) == controlConnClosing {
		return
	}
	if !atomic.CompareAndSwapInt32(&c.reconnecting, 0, 1) {
		return
	}
	defer atomic.StoreInt32(&c.reconnecting, 0)

	conn, err := c.attemptReconnect()

	if conn == nil {
		c.session.logger.Printf("gocql: unable to reconnect control connection: %v\n", err)
		return
	}

	err = c.session.refreshRingNow()
	if err != nil {
		c.session.logger.Printf("gocql: unable to refresh ring: %v\n", err)
	}

	err = c.session.metadataDescriber.refreshAllSchema()
	if err != nil {
		c.session.logger.Printf("gocql: unable to refresh the schema: %v\n", err)
	}
}
|
||||
|
||||
// attemptReconnect tries every known ring node (previous control host
// first), and if all fail, falls back to re-resolving the configured
// contact points — their IPs may have changed behind stable hostnames.
func (c *controlConn) attemptReconnect() (*Conn, error) {
	hosts := c.session.hostSource.getHostsList()
	hosts = shuffleHosts(hosts)

	// keep the old behavior of connecting to the old host first by moving it to
	// the front of the slice
	ch := c.getConn()
	if ch != nil {
		for i := range hosts {
			if hosts[i].Equal(ch.host) {
				hosts[0], hosts[i] = hosts[i], hosts[0]
				break
			}
		}
		ch.conn.Close()
	}

	conn, err := c.attemptReconnectToAnyOfHosts(hosts)

	if conn != nil {
		return conn, err
	}

	c.session.logger.Printf("gocql: unable to connect to any ring node: %v\n", err)
	c.session.logger.Printf("gocql: control falling back to initial contact points.\n")
	// Fallback to initial contact points, as it may be the case that all known initialHosts
	// changed their IPs while keeping the same hostname(s).
	initialHosts, resolvErr := addrsToHosts(c.session.cfg.Hosts, c.session.cfg.Port, c.session.logger)
	if resolvErr != nil {
		return nil, fmt.Errorf("resolve contact points' hostnames: %v", resolvErr)
	}

	return c.attemptReconnectToAnyOfHosts(initialHosts)
}
|
||||
|
||||
// attemptReconnectToAnyOfHosts dials hosts in order until one connects
// and completes setupConn. Each dial failure is reported to the
// conviction policy (possibly marking the host down). Returns the last
// error with a nil conn when every host fails.
func (c *controlConn) attemptReconnectToAnyOfHosts(hosts []*HostInfo) (*Conn, error) {
	var conn *Conn
	var err error
	for _, host := range hosts {
		conn, err = c.session.connect(c.session.ctx, host, c)
		if err != nil {
			if c.session.cfg.ConvictionPolicy.AddFailure(err, host) {
				c.session.handleNodeDown(host.ConnectAddress(), host.Port())
			}
			c.session.logger.Printf("gocql: unable to dial control conn %v:%v: %v\n", host.ConnectAddress(), host.Port(), err)
			continue
		}
		err = c.setupConn(conn)
		if err == nil {
			break
		}
		c.session.logger.Printf("gocql: unable setup control conn %v:%v: %v\n", host.ConnectAddress(), host.Port(), err)
		conn.Close()
		conn = nil
	}
	return conn, err
}
|
||||
|
||||
// HandleError reacts to a closed control connection by reconnecting.
// Errors for connections that are still open, or for a connection that
// is no longer the current control conn, are ignored.
func (c *controlConn) HandleError(conn *Conn, err error, closed bool) {
	if !closed {
		return
	}

	oldConn := c.getConn()

	// If connection has long gone, and not been attempted for awhile,
	// it's possible to have oldConn as nil here (#1297).
	if oldConn != nil && oldConn.conn != conn {
		return
	}

	c.reconnect()
}
|
||||
|
||||
// getConn returns the current control connection. The stored value is
// always a typed *connHost (seeded in createControlConn), so the type
// assertion is safe — but the returned pointer may be nil.
func (c *controlConn) getConn() *connHost {
	return c.conn.Load().(*connHost)
}
|
||||
|
||||
// writeFrame sends w on the control connection and parses the response
// frame; returns errNoControl when no control connection is available.
func (c *controlConn) writeFrame(w frameBuilder) (frame, error) {
	ch := c.getConn()
	if ch == nil {
		return nil, errNoControl
	}

	framer, err := ch.conn.exec(context.Background(), w, nil)
	if err != nil {
		return nil, err
	}

	return framer.parseFrame()
}
|
||||
|
||||
// query will return nil if the connection is closed or nil.
// Executes statement on the control connection at consistency One,
// retrying per c.retry until the iterator carries no error or the retry
// policy gives up.
// NOTE(review): unlike writeFrame, this dereferences getConn() without
// a nil check; presumably callers only run queries while a control
// connection exists — confirm, otherwise this can panic.
func (c *controlConn) query(statement string, values ...interface{}) (iter *Iter) {
	q := c.session.Query(statement, values...).Consistency(One).RoutingKey([]byte{}).Trace(nil)

	for {
		ch := c.getConn()
		q.conn = ch.conn.(*Conn)
		iter = ch.conn.executeQuery(context.TODO(), q)

		if gocqlDebug && iter.err != nil {
			c.session.logger.Printf("control: error executing %q: %v\n", statement, iter.err)
		}

		q.AddAttempts(1, c.getConn().host)
		if iter.err == nil || !c.retry.Attempt(q) {
			break
		}
	}

	return
}
|
||||
|
||||
// awaitSchemaAgreement blocks until the cluster reports schema
// agreement on the control connection.
// NOTE(review): no nil check on getConn() here (cf. writeFrame) — this
// panics if called with no control connection; confirm callers.
func (c *controlConn) awaitSchemaAgreement() error {
	ch := c.getConn()
	return (&Iter{err: ch.conn.awaitSchemaAgreement(context.TODO())}).err
}
|
||||
|
||||
// close stops the heartbeat loop (if it was started) and closes the
// current control connection, if any.
func (c *controlConn) close() {
	// only a started heartbeat loop is listening on quit; the CAS also
	// flips the state so reconnect() becomes a no-op
	if atomic.CompareAndSwapInt32(&c.state, controlConnStarted, controlConnClosing) {
		c.quit <- struct{}{}
	}

	ch := c.getConn()
	if ch != nil {
		ch.conn.Close()
	}
}
|
||||
|
||||
// errNoControl is returned when an operation needs the control
// connection but none is currently established.
var errNoControl = errors.New("gocql: no control connection available")
|
11
vendor/github.com/gocql/gocql/cqltypes.go
generated
vendored
Normal file
11
vendor/github.com/gocql/gocql/cqltypes.go
generated
vendored
Normal file
@@ -0,0 +1,11 @@
|
||||
// Copyright (c) 2012 The gocql Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package gocql
|
||||
|
||||
// Duration mirrors the CQL duration type: months, days and nanoseconds
// are kept as separate components (a month or day has no fixed length
// in nanoseconds).
type Duration struct {
	Months      int32
	Days        int32
	Nanoseconds int64
}
|
164
vendor/github.com/gocql/gocql/debounce/refresh_deboucer.go
generated
vendored
Normal file
164
vendor/github.com/gocql/gocql/debounce/refresh_deboucer.go
generated
vendored
Normal file
@@ -0,0 +1,164 @@
|
||||
package debounce
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// RingRefreshDebounceTime is the default debounce interval for ring
// refreshes.
const (
	RingRefreshDebounceTime = 1 * time.Second
)
|
||||
|
||||
// debounces requests to call a refresh function (currently used for ring refresh). It also supports triggering a refresh immediately.
type RefreshDebouncer struct {
	mu          sync.Mutex // guards stopped and broadcaster
	stopped     bool
	broadcaster *errorBroadcaster // fan-out of the next refresh result to RefreshNow callers
	interval    time.Duration     // debounce delay applied by Debounce
	timer       *time.Timer       // fires a debounced refresh
	refreshNowCh chan struct{}    // capacity 1; signals an immediate refresh
	quit        chan struct{}     // wakes the flusher for shutdown
	refreshFn   func() error      // the work being debounced
}
|
||||
|
||||
func NewRefreshDebouncer(interval time.Duration, refreshFn func() error) *RefreshDebouncer {
|
||||
d := &RefreshDebouncer{
|
||||
stopped: false,
|
||||
broadcaster: nil,
|
||||
refreshNowCh: make(chan struct{}, 1),
|
||||
quit: make(chan struct{}),
|
||||
interval: interval,
|
||||
timer: time.NewTimer(interval),
|
||||
refreshFn: refreshFn,
|
||||
}
|
||||
d.timer.Stop()
|
||||
go d.flusher()
|
||||
return d
|
||||
}
|
||||
|
||||
// debounces a request to call the refresh function
|
||||
func (d *RefreshDebouncer) Debounce() {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
if d.stopped {
|
||||
return
|
||||
}
|
||||
d.timer.Reset(d.interval)
|
||||
}
|
||||
|
||||
// requests an immediate refresh which will cancel pending refresh requests.
// Returns a channel that receives the refresh's error result (and is
// then closed); concurrent callers before the refresh runs share one
// broadcaster and each get their own listener channel.
func (d *RefreshDebouncer) RefreshNow() <-chan error {
	d.mu.Lock()
	defer d.mu.Unlock()
	if d.broadcaster == nil {
		d.broadcaster = newErrorBroadcaster()
		select {
		case d.refreshNowCh <- struct{}{}:
		default:
			// already a refresh pending
		}
	}
	return d.broadcaster.newListener()
}
|
||||
|
||||
// flusher is the background loop: it waits for a debounced timer fire,
// an immediate-refresh signal, or shutdown; runs refreshFn outside the
// lock; and broadcasts the result to any RefreshNow listeners.
func (d *RefreshDebouncer) flusher() {
	for {
		select {
		case <-d.refreshNowCh:
		case <-d.timer.C:
		case <-d.quit:
		}
		d.mu.Lock()
		if d.stopped {
			// shutdown: close out any waiting listeners and exit
			if d.broadcaster != nil {
				d.broadcaster.stop()
				d.broadcaster = nil
			}
			d.timer.Stop()
			d.mu.Unlock()
			return
		}

		// make sure both request channels are cleared before we refresh
		select {
		case <-d.refreshNowCh:
		default:
		}

		d.timer.Stop()
		select {
		case <-d.timer.C:
		default:
		}

		// detach the broadcaster so later RefreshNow calls attach to the
		// *next* refresh, then run the refresh without holding the lock
		curBroadcaster := d.broadcaster
		d.broadcaster = nil
		d.mu.Unlock()

		err := d.refreshFn()
		if curBroadcaster != nil {
			curBroadcaster.broadcast(err)
		}
	}
}
|
||||
|
||||
func (d *RefreshDebouncer) Stop() {
|
||||
d.mu.Lock()
|
||||
if d.stopped {
|
||||
d.mu.Unlock()
|
||||
return
|
||||
}
|
||||
d.stopped = true
|
||||
d.mu.Unlock()
|
||||
d.quit <- struct{}{} // sync with flusher
|
||||
close(d.quit)
|
||||
}
|
||||
|
||||
// broadcasts an error to multiple channels (listeners)
type errorBroadcaster struct {
	listeners []chan<- error // each has capacity 1 and receives at most one value
	mu        sync.Mutex
}
|
||||
|
||||
func newErrorBroadcaster() *errorBroadcaster {
|
||||
return &errorBroadcaster{
|
||||
listeners: nil,
|
||||
mu: sync.Mutex{},
|
||||
}
|
||||
}
|
||||
|
||||
func (b *errorBroadcaster) newListener() <-chan error {
|
||||
ch := make(chan error, 1)
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
b.listeners = append(b.listeners, ch)
|
||||
return ch
|
||||
}
|
||||
|
||||
func (b *errorBroadcaster) broadcast(err error) {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
curListeners := b.listeners
|
||||
if len(curListeners) > 0 {
|
||||
b.listeners = nil
|
||||
} else {
|
||||
return
|
||||
}
|
||||
|
||||
for _, listener := range curListeners {
|
||||
listener <- err
|
||||
close(listener)
|
||||
}
|
||||
}
|
||||
|
||||
func (b *errorBroadcaster) stop() {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
if len(b.listeners) == 0 {
|
||||
return
|
||||
}
|
||||
for _, listener := range b.listeners {
|
||||
close(listener)
|
||||
}
|
||||
b.listeners = nil
|
||||
}
|
34
vendor/github.com/gocql/gocql/debounce/simple_debouncer.go
generated
vendored
Normal file
34
vendor/github.com/gocql/gocql/debounce/simple_debouncer.go
generated
vendored
Normal file
@@ -0,0 +1,34 @@
|
||||
package debounce
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// SimpleDebouncer queues calls to a function, allowing at most one running
// and one waiting call at any time:
//  1. Simultaneous calls are serialized, never run concurrently.
//  2. If there is no running call and no waiting call, the current call runs.
//  3. If there is a running call and no waiting call, the current call waits.
//  4. If there is a running call and a waiting call, the current call is dropped.
type SimpleDebouncer struct {
	m     sync.Mutex   // serializes execution of the debounced function
	count atomic.Int32 // in-flight calls (running + waiting); capped at 2
}
|
||||
|
||||
// NewSimpleDebouncer creates a new SimpleDebouncer.
|
||||
func NewSimpleDebouncer() *SimpleDebouncer {
|
||||
return &SimpleDebouncer{}
|
||||
}
|
||||
|
||||
// Debounce attempts to execute the function if the logic of the SimpleDebouncer allows it.
// It returns true if fn was executed, and false if the call was dropped
// because one call was already running and another was already waiting.
func (d *SimpleDebouncer) Debounce(fn func()) bool {
	// Reserve a slot. A post-increment value above 2 means a running call
	// and a waiting call already exist, so this call is voided.
	if d.count.Add(1) > 2 {
		d.count.Add(-1)
		return false
	}
	// Serialize with any running call; a waiting caller blocks here until
	// the running call releases the mutex.
	d.m.Lock()
	fn()
	// Release the reserved slot, then the mutex, letting any waiter proceed.
	d.count.Add(-1)
	d.m.Unlock()
	return true
}
|
6
vendor/github.com/gocql/gocql/debug_off.go
generated
vendored
Normal file
6
vendor/github.com/gocql/gocql/debug_off.go
generated
vendored
Normal file
@@ -0,0 +1,6 @@
|
||||
//go:build !gocql_debug
|
||||
// +build !gocql_debug
|
||||
|
||||
package gocql
|
||||
|
||||
// gocqlDebug is compiled in as false when the gocql_debug build tag is
// absent; the counterpart file debug_on.go supplies true.
const gocqlDebug = false
|
6
vendor/github.com/gocql/gocql/debug_on.go
generated
vendored
Normal file
6
vendor/github.com/gocql/gocql/debug_on.go
generated
vendored
Normal file
@@ -0,0 +1,6 @@
|
||||
//go:build gocql_debug
|
||||
// +build gocql_debug
|
||||
|
||||
package gocql
|
||||
|
||||
// gocqlDebug is compiled in as true when the gocql_debug build tag is set;
// the counterpart file debug_off.go supplies false for normal builds.
const gocqlDebug = true
|
91
vendor/github.com/gocql/gocql/dial.go
generated
vendored
Normal file
91
vendor/github.com/gocql/gocql/dial.go
generated
vendored
Normal file
@@ -0,0 +1,91 @@
|
||||
package gocql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// HostDialer allows customizing connection to cluster nodes.
// The driver's default implementation is defaultHostDialer below.
type HostDialer interface {
	// DialHost establishes a connection to the host.
	// The returned connection must be directly usable for CQL protocol,
	// specifically DialHost is responsible also for setting up the TLS session if needed.
	// DialHost should disable write coalescing if the returned net.Conn does not support writev.
	// As of Go 1.18, only plain TCP connections support writev, TLS sessions should disable coalescing.
	// You can use WrapTLS helper function if you don't need to override the TLS setup.
	DialHost(ctx context.Context, host *HostInfo) (*DialedHost, error)
}
|
||||
|
||||
// DialedHost contains information about established connection to a host.
// It is produced by HostDialer.DialHost implementations.
type DialedHost struct {
	// Conn used to communicate with the server.
	Conn net.Conn

	// DisableCoalesce disables write coalescing for the Conn.
	// If true, the effect is the same as if WriteCoalesceWaitTime was configured to 0.
	DisableCoalesce bool
}
|
||||
|
||||
// defaultHostDialer dials host in a default way.
type defaultHostDialer struct {
	dialer    Dialer      // underlying dialer used to open the TCP connection
	tlsConfig *tls.Config // optional TLS config; nil means no TLS wrapping (see WrapTLS)
}
|
||||
|
||||
func (hd *defaultHostDialer) DialHost(ctx context.Context, host *HostInfo) (*DialedHost, error) {
|
||||
ip := host.ConnectAddress()
|
||||
port := host.Port()
|
||||
|
||||
if !validIpAddr(ip) {
|
||||
return nil, fmt.Errorf("host missing connect ip address: %v", ip)
|
||||
} else if port == 0 {
|
||||
return nil, fmt.Errorf("host missing port: %v", port)
|
||||
}
|
||||
|
||||
connAddr := host.ConnectAddressAndPort()
|
||||
conn, err := hd.dialer.DialContext(ctx, "tcp", connAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
addr := host.HostnameAndPort()
|
||||
return WrapTLS(ctx, conn, addr, hd.tlsConfig)
|
||||
}
|
||||
|
||||
func tlsConfigForAddr(tlsConfig *tls.Config, addr string) *tls.Config {
|
||||
// the TLS config is safe to be reused by connections but it must not
|
||||
// be modified after being used.
|
||||
if !tlsConfig.InsecureSkipVerify && tlsConfig.ServerName == "" {
|
||||
colonPos := strings.LastIndex(addr, ":")
|
||||
if colonPos == -1 {
|
||||
colonPos = len(addr)
|
||||
}
|
||||
hostname := addr[:colonPos]
|
||||
// clone config to avoid modifying the shared one.
|
||||
tlsConfig = tlsConfig.Clone()
|
||||
tlsConfig.ServerName = hostname
|
||||
}
|
||||
return tlsConfig
|
||||
}
|
||||
|
||||
// WrapTLS optionally wraps a net.Conn connected to addr with the given tlsConfig.
|
||||
// If the tlsConfig is nil, conn is not wrapped into a TLS session, so is insecure.
|
||||
// If the tlsConfig does not have server name set, it is updated based on the default gocql rules.
|
||||
func WrapTLS(ctx context.Context, conn net.Conn, addr string, tlsConfig *tls.Config) (*DialedHost, error) {
|
||||
if tlsConfig != nil {
|
||||
tlsConfig := tlsConfigForAddr(tlsConfig, addr)
|
||||
tconn := tls.Client(conn, tlsConfig)
|
||||
if err := tconn.HandshakeContext(ctx); err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
conn = tconn
|
||||
}
|
||||
|
||||
return &DialedHost{
|
||||
Conn: conn,
|
||||
DisableCoalesce: tlsConfig != nil, // write coalescing can't use writev when the connection is wrapped.
|
||||
}, nil
|
||||
}
|
375
vendor/github.com/gocql/gocql/doc.go
generated
vendored
Normal file
375
vendor/github.com/gocql/gocql/doc.go
generated
vendored
Normal file
@@ -0,0 +1,375 @@
|
||||
// Copyright (c) 2012-2015 The gocql Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package gocql implements a fast and robust Cassandra driver for the
|
||||
// Go programming language.
|
||||
//
|
||||
// # Connecting to the cluster
|
||||
//
|
||||
// Pass a list of initial node IP addresses to NewCluster to create a new cluster configuration:
|
||||
//
|
||||
// cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3")
|
||||
//
|
||||
// Port can be specified as part of the address, the above is equivalent to:
|
||||
//
|
||||
// cluster := gocql.NewCluster("192.168.1.1:9042", "192.168.1.2:9042", "192.168.1.3:9042")
|
||||
//
|
||||
// It is recommended to use the value set in the Cassandra config for broadcast_address or listen_address,
|
||||
// an IP address not a domain name. This is because events from Cassandra will use the configured IP
|
||||
// address, which is used to index connected hosts. If the domain name specified resolves to more than 1 IP address
|
||||
// then the driver may connect multiple times to the same host, and will not mark the node being down or up from events.
|
||||
//
|
||||
// Then you can customize more options (see ClusterConfig):
|
||||
//
|
||||
// cluster.Keyspace = "example"
|
||||
// cluster.Consistency = gocql.Quorum
|
||||
// cluster.ProtoVersion = 4
|
||||
//
|
||||
// The driver tries to automatically detect the protocol version to use if not set, but you might want to set the
|
||||
// protocol version explicitly, as it's not defined which version will be used in certain situations (for example
|
||||
// during upgrade of the cluster when some of the nodes support different set of protocol versions than other nodes).
|
||||
//
|
||||
// The driver advertises the module name and version in the STARTUP message, so servers are able to detect the version.
|
||||
// If you use replace directive in go.mod, the driver will send information about the replacement module instead.
|
||||
//
|
||||
// When ready, create a session from the configuration. Don't forget to Close the session once you are done with it:
|
||||
//
|
||||
// session, err := cluster.CreateSession()
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// defer session.Close()
|
||||
//
|
||||
// # Authentication
|
||||
//
|
||||
// CQL protocol uses a SASL-based authentication mechanism and so consists of an exchange of server challenges and
|
||||
// client response pairs. The details of the exchanged messages depend on the authenticator used.
|
||||
//
|
||||
// To use authentication, set ClusterConfig.Authenticator or ClusterConfig.AuthProvider.
|
||||
//
|
||||
// PasswordAuthenticator is provided to use for username/password authentication:
|
||||
//
|
||||
// cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3")
|
||||
// cluster.Authenticator = gocql.PasswordAuthenticator{
|
||||
// Username: "user",
|
||||
// Password: "password"
|
||||
// }
|
||||
// session, err := cluster.CreateSession()
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// defer session.Close()
|
||||
//
|
||||
// By default, PasswordAuthenticator will attempt to authenticate regardless of what implementation the server returns
|
||||
// in its AUTHENTICATE message as its authenticator, (e.g. org.apache.cassandra.auth.PasswordAuthenticator). If you
|
||||
// wish to restrict this you may use PasswordAuthenticator.AllowedAuthenticators:
|
||||
//
|
||||
// cluster.Authenticator = gocql.PasswordAuthenticator {
|
||||
// Username: "user",
|
||||
// Password: "password"
|
||||
// AllowedAuthenticators: []string{"org.apache.cassandra.auth.PasswordAuthenticator"},
|
||||
// }
|
||||
//
|
||||
// # Transport layer security
|
||||
//
|
||||
// It is possible to secure traffic between the client and server with TLS.
|
||||
//
|
||||
// To use TLS, set the ClusterConfig.SslOpts field. SslOptions embeds *tls.Config so you can set that directly.
|
||||
// There are also helpers to load keys/certificates from files.
|
||||
//
|
||||
// Warning: Due to historical reasons, the SslOptions is insecure by default, so you need to set EnableHostVerification
|
||||
// to true if no Config is set. Most users should set SslOptions.Config to a *tls.Config.
|
||||
// SslOptions and Config.InsecureSkipVerify interact as follows:
|
||||
//
|
||||
// Config.InsecureSkipVerify | EnableHostVerification | Result
|
||||
// Config is nil | false | do not verify host
|
||||
// Config is nil | true | verify host
|
||||
// false | false | verify host
|
||||
// true | false | do not verify host
|
||||
// false | true | verify host
|
||||
// true | true | verify host
|
||||
//
|
||||
// For example:
|
||||
//
|
||||
// cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3")
|
||||
// cluster.SslOpts = &gocql.SslOptions{
|
||||
// EnableHostVerification: true,
|
||||
// }
|
||||
// session, err := cluster.CreateSession()
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// defer session.Close()
|
||||
//
|
||||
// # Data-center awareness and query routing
|
||||
//
|
||||
// To route queries to local DC first, use DCAwareRoundRobinPolicy. For example, if the datacenter you
|
||||
// want to primarily connect is called dc1 (as configured in the database):
|
||||
//
|
||||
// cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3")
|
||||
// cluster.PoolConfig.HostSelectionPolicy = gocql.DCAwareRoundRobinPolicy("dc1")
|
||||
//
|
||||
// The driver can route queries to nodes that hold data replicas based on partition key (preferring local DC).
|
||||
//
|
||||
// cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3")
|
||||
// cluster.PoolConfig.HostSelectionPolicy = gocql.TokenAwareHostPolicy(gocql.DCAwareRoundRobinPolicy("dc1"))
|
||||
//
|
||||
// Note that TokenAwareHostPolicy can take options such as gocql.ShuffleReplicas and gocql.NonLocalReplicasFallback.
|
||||
//
|
||||
// We recommend running with a token aware host policy in production for maximum performance.
|
||||
//
|
||||
// The driver can only use token-aware routing for queries where all partition key columns are query parameters.
|
||||
// For example, instead of
|
||||
//
|
||||
// session.Query("select value from mytable where pk1 = 'abc' AND pk2 = ?", "def")
|
||||
//
|
||||
// use
|
||||
//
|
||||
// session.Query("select value from mytable where pk1 = ? AND pk2 = ?", "abc", "def")
|
||||
//
|
||||
// # Rack-level awareness
|
||||
//
|
||||
// The DCAwareRoundRobinPolicy can be replaced with RackAwareRoundRobinPolicy, which takes two parameters, datacenter and rack.
|
||||
//
|
||||
// Instead of dividing hosts with two tiers (local datacenter and remote datacenters) it divides hosts into three
|
||||
// (the local rack, the rest of the local datacenter, and everything else).
|
||||
//
|
||||
// RackAwareRoundRobinPolicy can be combined with TokenAwareHostPolicy in the same way as DCAwareRoundRobinPolicy.
|
||||
//
|
||||
// # Executing queries
|
||||
//
|
||||
// Create queries with Session.Query. Query values must not be reused between different executions and must not be
|
||||
// modified after starting execution of the query.
|
||||
//
|
||||
// To execute a query without reading results, use Query.Exec:
|
||||
//
|
||||
// err := session.Query(`INSERT INTO tweet (timeline, id, text) VALUES (?, ?, ?)`,
|
||||
// "me", gocql.TimeUUID(), "hello world").WithContext(ctx).Exec()
|
||||
//
|
||||
// Single row can be read by calling Query.Scan:
|
||||
//
|
||||
// err := session.Query(`SELECT id, text FROM tweet WHERE timeline = ? LIMIT 1`,
|
||||
// "me").WithContext(ctx).Consistency(gocql.One).Scan(&id, &text)
|
||||
//
|
||||
// Multiple rows can be read using Iter.Scanner:
|
||||
//
|
||||
// scanner := session.Query(`SELECT id, text FROM tweet WHERE timeline = ?`,
|
||||
// "me").WithContext(ctx).Iter().Scanner()
|
||||
// for scanner.Next() {
|
||||
// var (
|
||||
// id gocql.UUID
|
||||
// text string
|
||||
// )
|
||||
// err = scanner.Scan(&id, &text)
|
||||
// if err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
// fmt.Println("Tweet:", id, text)
|
||||
// }
|
||||
// // scanner.Err() closes the iterator, so scanner nor iter should be used afterwards.
|
||||
// if err := scanner.Err(); err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
//
|
||||
// See Example for complete example.
|
||||
//
|
||||
// # Prepared statements
|
||||
//
|
||||
// The driver automatically prepares DML queries (SELECT/INSERT/UPDATE/DELETE/BATCH statements) and maintains a cache
|
||||
// of prepared statements.
|
||||
// CQL protocol does not support preparing other query types.
|
||||
//
|
||||
// When using CQL protocol >= 4, it is possible to use gocql.UnsetValue as the bound value of a column.
|
||||
// This will cause the database to ignore writing the column.
|
||||
// The main advantage is the ability to keep the same prepared statement even when you don't
|
||||
// want to update some fields, where before you needed to make another prepared statement.
|
||||
//
|
||||
// # Executing multiple queries concurrently
|
||||
//
|
||||
// Session is safe to use from multiple goroutines, so to execute multiple concurrent queries, just execute them
|
||||
// from several worker goroutines. Gocql provides synchronously-looking API (as recommended for Go APIs) and the queries
|
||||
// are executed asynchronously at the protocol level.
|
||||
//
|
||||
// results := make(chan error, 2)
|
||||
// go func() {
|
||||
// results <- session.Query(`INSERT INTO tweet (timeline, id, text) VALUES (?, ?, ?)`,
|
||||
// "me", gocql.TimeUUID(), "hello world 1").Exec()
|
||||
// }()
|
||||
// go func() {
|
||||
// results <- session.Query(`INSERT INTO tweet (timeline, id, text) VALUES (?, ?, ?)`,
|
||||
// "me", gocql.TimeUUID(), "hello world 2").Exec()
|
||||
// }()
|
||||
//
|
||||
// # Nulls
|
||||
//
|
||||
// Null values are unmarshalled as the zero value of the type. If you need to distinguish for example between text
|
||||
// column being null and empty string, you can unmarshal into *string variable instead of string.
|
||||
//
|
||||
// var text *string
|
||||
// err := scanner.Scan(&text)
|
||||
// if err != nil {
|
||||
// // handle error
|
||||
// }
|
||||
//	if text != nil {
//		// not null
//	} else {
//		// null
//	}
|
||||
//
|
||||
// See Example_nulls for full example.
|
||||
//
|
||||
// # Reusing slices
|
||||
//
|
||||
// The driver reuses backing memory of slices when unmarshalling. This is an optimization so that a buffer does not
|
||||
// need to be allocated for every processed row. However, you need to be careful when storing the slices to other
|
||||
// memory structures.
|
||||
//
|
||||
// scanner := session.Query(`SELECT myints FROM table WHERE pk = ?`, "key").WithContext(ctx).Iter().Scanner()
|
||||
// var myInts []int
|
||||
// for scanner.Next() {
|
||||
// // This scan reuses backing store of myInts for each row.
|
||||
// err = scanner.Scan(&myInts)
|
||||
// if err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// When you want to save the data for later use, pass a new slice every time. A common pattern is to declare the
|
||||
// slice variable within the scanner loop:
|
||||
//
|
||||
// scanner := session.Query(`SELECT myints FROM table WHERE pk = ?`, "key").WithContext(ctx).Iter().Scanner()
|
||||
// for scanner.Next() {
|
||||
// var myInts []int
|
||||
// // This scan always gets pointer to fresh myInts slice, so does not reuse memory.
|
||||
// err = scanner.Scan(&myInts)
|
||||
// if err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// # Paging
|
||||
//
|
||||
// The driver supports paging of results with automatic prefetch, see ClusterConfig.PageSize, Session.SetPrefetch,
|
||||
// Query.PageSize, and Query.Prefetch.
|
||||
//
|
||||
// It is also possible to control the paging manually with Query.PageState (this disables automatic prefetch).
|
||||
// Manual paging is useful if you want to store the page state externally, for example in a URL to allow users
|
||||
// browse pages in a result. You might want to sign/encrypt the paging state when exposing it externally since
|
||||
// it contains data from primary keys.
|
||||
//
|
||||
// Paging state is specific to the CQL protocol version and the exact query used. It is meant as opaque state that
|
||||
// should not be modified. If you send paging state from different query or protocol version, then the behaviour
|
||||
// is not defined (you might get unexpected results or an error from the server). For example, do not send paging state
|
||||
// returned by node using protocol version 3 to a node using protocol version 4. Also, when using protocol version 4,
|
||||
// paging state between Cassandra 2.2 and 3.0 is incompatible (https://issues.apache.org/jira/browse/CASSANDRA-10880).
|
||||
//
|
||||
// The driver does not check whether the paging state is from the same protocol version/statement.
|
||||
// You might want to validate yourself as this could be a problem if you store paging state externally.
|
||||
// For example, if you store paging state in a URL, the URLs might become broken when you upgrade your cluster.
|
||||
//
|
||||
// Call Query.PageState(nil) to fetch just the first page of the query results. Pass the page state returned by
|
||||
// Iter.PageState to Query.PageState of a subsequent query to get the next page. If the length of slice returned
|
||||
// by Iter.PageState is zero, there are no more pages available (or an error occurred).
|
||||
//
|
||||
// Using too low values of PageSize will negatively affect performance, a value below 100 is probably too low.
|
||||
// While Cassandra returns exactly PageSize items (except for last page) in a page currently, the protocol authors
|
||||
// explicitly reserved the right to return smaller or larger amount of items in a page for performance reasons, so don't
|
||||
// rely on the page having the exact count of items.
|
||||
//
|
||||
// See Example_paging for an example of manual paging.
|
||||
//
|
||||
// # Dynamic list of columns
|
||||
//
|
||||
// There are certain situations when you don't know the list of columns in advance, mainly when the query is supplied
|
||||
// by the user. Iter.Columns, Iter.RowData, Iter.MapScan and Iter.SliceMap can be used to handle this case.
|
||||
//
|
||||
// See Example_dynamicColumns.
|
||||
//
|
||||
// # Batches
|
||||
//
|
||||
// The CQL protocol supports sending batches of DML statements (INSERT/UPDATE/DELETE) and so does gocql.
|
||||
// Use Session.Batch to create a new batch and then fill-in details of individual queries.
|
||||
// Then execute the batch with Session.ExecuteBatch.
|
||||
//
|
||||
// Logged batches ensure atomicity, either all or none of the operations in the batch will succeed, but they have
|
||||
// overhead to ensure this property.
|
||||
// Unlogged batches don't have the overhead of logged batches, but don't guarantee atomicity.
|
||||
// Updates of counters are handled specially by Cassandra so batches of counter updates have to use CounterBatch type.
|
||||
// A counter batch can only contain statements to update counters.
|
||||
//
|
||||
// For unlogged batches it is recommended to send only single-partition batches (i.e. all statements in the batch should
|
||||
// involve only a single partition).
|
||||
// Multi-partition batch needs to be split by the coordinator node and re-sent to
|
||||
// correct nodes.
|
||||
// With single-partition batches you can send the batch directly to the node for the partition without incurring the
|
||||
// additional network hop.
|
||||
//
|
||||
// It is also possible to pass entire BEGIN BATCH .. APPLY BATCH statement to Query.Exec.
|
||||
// There are differences how those are executed.
|
||||
// BEGIN BATCH statement passed to Query.Exec is prepared as a whole in a single statement.
|
||||
// Session.ExecuteBatch prepares individual statements in the batch.
|
||||
// If you have variable-length batches using the same statement, using Session.ExecuteBatch is more efficient.
|
||||
//
|
||||
// See Example_batch for an example.
|
||||
//
|
||||
// # Lightweight transactions
|
||||
//
|
||||
// Query.ScanCAS or Query.MapScanCAS can be used to execute a single-statement lightweight transaction (an
|
||||
// INSERT/UPDATE .. IF statement) and reading its result. See example for Query.MapScanCAS.
|
||||
//
|
||||
// Multiple-statement lightweight transactions can be executed as a logged batch that contains at least one conditional
|
||||
// statement. All the conditions must return true for the batch to be applied. You can use Session.ExecuteBatchCAS and
|
||||
// Session.MapExecuteBatchCAS when executing the batch to learn about the result of the LWT. See example for
|
||||
// Session.MapExecuteBatchCAS.
|
||||
//
|
||||
// # Retries and speculative execution
|
||||
//
|
||||
// Queries can be marked as idempotent. Marking the query as idempotent tells the driver that the query can be executed
|
||||
// multiple times without affecting its result. Non-idempotent queries are not eligible for retrying nor speculative
|
||||
// execution.
|
||||
//
|
||||
// Idempotent queries are retried in case of errors based on the configured RetryPolicy.
|
||||
// If the query is LWT and the configured RetryPolicy additionally implements LWTRetryPolicy
|
||||
// interface, then the policy will be cast to LWTRetryPolicy and used this way.
|
||||
//
|
||||
// Queries can be retried even before they fail by setting a SpeculativeExecutionPolicy. The policy can
|
||||
// cause the driver to retry on a different node if the query is taking longer than a specified delay even before the
|
||||
// driver receives an error or timeout from the server. When a query is speculatively executed, the original execution
|
||||
// is still executing. The two parallel executions of the query race to return a result, the first received result will
|
||||
// be returned.
|
||||
//
|
||||
// # User-defined types
|
||||
//
|
||||
// UDTs can be mapped (un)marshaled from/to map[string]interface{} a Go struct (or a type implementing
|
||||
// UDTUnmarshaler, UDTMarshaler, Unmarshaler or Marshaler interfaces).
|
||||
//
|
||||
// For structs, cql tag can be used to specify the CQL field name to be mapped to a struct field:
|
||||
//
|
||||
// type MyUDT struct {
|
||||
// FieldA int32 `cql:"a"`
|
||||
// FieldB string `cql:"b"`
|
||||
// }
|
||||
//
|
||||
// See Example_userDefinedTypesMap, Example_userDefinedTypesStruct, ExampleUDTMarshaler, ExampleUDTUnmarshaler.
|
||||
//
|
||||
// # Metrics and tracing
|
||||
//
|
||||
// It is possible to provide observer implementations that could be used to gather metrics:
|
||||
//
|
||||
// - QueryObserver for monitoring individual queries.
|
||||
// - BatchObserver for monitoring batch queries.
|
||||
// - ConnectObserver for monitoring new connections from the driver to the database.
|
||||
// - FrameHeaderObserver for monitoring individual protocol frames.
|
||||
//
|
||||
// CQL protocol also supports tracing of queries. When enabled, the database will write information about
|
||||
// internal events that happened during execution of the query. You can use Query.Trace to request tracing and receive
|
||||
// the session ID that the database used to store the trace information in system_traces.sessions and
|
||||
// system_traces.events tables. NewTraceWriter returns an implementation of Tracer that writes the events to a writer.
|
||||
// Gathering trace information might be essential for debugging and optimizing queries, but writing traces has overhead,
|
||||
// so this feature should not be used on production systems with very high load unless you know what you are doing.
|
||||
// There is also a new implementation of Tracer - TracerEnhanced, that is intended to be more reliable and convenient to use.
|
||||
// It has functionality to check if the trace is ready to be extracted, and only actually gets it if requested, which makes
|
||||
// the impact on a performance smaller.
|
||||
package gocql // import "github.com/gocql/gocql"
|
90
vendor/github.com/gocql/gocql/docker-compose.yml
generated
vendored
Normal file
90
vendor/github.com/gocql/gocql/docker-compose.yml
generated
vendored
Normal file
@@ -0,0 +1,90 @@
|
||||
version: "3.7"
|
||||
|
||||
services:
|
||||
node_1:
|
||||
image: ${SCYLLA_IMAGE}
|
||||
privileged: true
|
||||
command: |
|
||||
--smp 2
|
||||
--memory 768M
|
||||
--seeds 192.168.100.11
|
||||
--overprovisioned 1
|
||||
--experimental-features udf
|
||||
--enable-user-defined-functions true
|
||||
networks:
|
||||
public:
|
||||
ipv4_address: 192.168.100.11
|
||||
volumes:
|
||||
- /tmp/scylla:/var/lib/scylla/
|
||||
- type: bind
|
||||
source: ./testdata/config/scylla.yaml
|
||||
target: /etc/scylla/scylla.yaml
|
||||
- type: bind
|
||||
source: ./testdata/pki/ca.crt
|
||||
target: /etc/scylla/ca.crt
|
||||
- type: bind
|
||||
source: ./testdata/pki/cassandra.crt
|
||||
target: /etc/scylla/db.crt
|
||||
- type: bind
|
||||
source: ./testdata/pki/cassandra.key
|
||||
target: /etc/scylla/db.key
|
||||
healthcheck:
|
||||
test: [ "CMD", "cqlsh", "-e", "select * from system.local" ]
|
||||
interval: 5s
|
||||
timeout: 5s
|
||||
retries: 18
|
||||
node_2:
|
||||
image: ${SCYLLA_IMAGE}
|
||||
command: |
|
||||
--smp 2
|
||||
--memory 1G
|
||||
--seeds 192.168.100.12
|
||||
networks:
|
||||
public:
|
||||
ipv4_address: 192.168.100.12
|
||||
healthcheck:
|
||||
test: [ "CMD", "cqlsh", "192.168.100.12", "-e", "select * from system.local" ]
|
||||
interval: 5s
|
||||
timeout: 5s
|
||||
retries: 18
|
||||
node_3:
|
||||
image: ${SCYLLA_IMAGE}
|
||||
command: |
|
||||
--smp 2
|
||||
--memory 1G
|
||||
--seeds 192.168.100.12
|
||||
networks:
|
||||
public:
|
||||
ipv4_address: 192.168.100.13
|
||||
healthcheck:
|
||||
test: [ "CMD", "cqlsh", "192.168.100.13", "-e", "select * from system.local" ]
|
||||
interval: 5s
|
||||
timeout: 5s
|
||||
retries: 18
|
||||
depends_on:
|
||||
node_2:
|
||||
condition: service_healthy
|
||||
node_4:
|
||||
image: ${SCYLLA_IMAGE}
|
||||
command: |
|
||||
--smp 2
|
||||
--memory 1G
|
||||
--seeds 192.168.100.12
|
||||
networks:
|
||||
public:
|
||||
ipv4_address: 192.168.100.14
|
||||
healthcheck:
|
||||
test: [ "CMD", "cqlsh", "192.168.100.14", "-e", "select * from system.local" ]
|
||||
interval: 5s
|
||||
timeout: 5s
|
||||
retries: 18
|
||||
depends_on:
|
||||
node_3:
|
||||
condition: service_healthy
|
||||
networks:
|
||||
public:
|
||||
driver: bridge
|
||||
ipam:
|
||||
driver: default
|
||||
config:
|
||||
- subnet: 192.168.100.0/24
|
227
vendor/github.com/gocql/gocql/errors.go
generated
vendored
Normal file
227
vendor/github.com/gocql/gocql/errors.go
generated
vendored
Normal file
@@ -0,0 +1,227 @@
|
||||
package gocql
|
||||
|
||||
import "fmt"
|
||||
|
||||
// Error codes carried in the [error] body of a server ERROR response.
// Each ErrCode* constant mirrors the numeric code defined by the protocol.
//
// See CQL Binary Protocol v5, section 8 for more details.
// https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec
const (
	// ErrCodeServer indicates unexpected error on server-side.
	//
	// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1246-L1247
	ErrCodeServer = 0x0000
	// ErrCodeProtocol indicates a protocol violation by some client message.
	//
	// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1248-L1250
	ErrCodeProtocol = 0x000A
	// ErrCodeCredentials indicates missing required authentication.
	//
	// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1251-L1254
	ErrCodeCredentials = 0x0100
	// ErrCodeUnavailable indicates unavailable error.
	//
	// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1255-L1265
	ErrCodeUnavailable = 0x1000
	// ErrCodeOverloaded returned in case of request on overloaded node coordinator.
	//
	// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1266-L1267
	ErrCodeOverloaded = 0x1001
	// ErrCodeBootstrapping returned from the coordinator node in bootstrapping phase.
	//
	// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1268-L1269
	ErrCodeBootstrapping = 0x1002
	// ErrCodeTruncate indicates truncation exception.
	//
	// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1270
	ErrCodeTruncate = 0x1003
	// ErrCodeWriteTimeout returned in case of timeout during the request write.
	//
	// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1271-L1304
	ErrCodeWriteTimeout = 0x1100
	// ErrCodeReadTimeout returned in case of timeout during the request read.
	//
	// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1305-L1321
	ErrCodeReadTimeout = 0x1200
	// ErrCodeReadFailure indicates request read error which is not covered by ErrCodeReadTimeout.
	//
	// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1322-L1340
	ErrCodeReadFailure = 0x1300
	// ErrCodeFunctionFailure indicates an error in user-defined function.
	//
	// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1341-L1347
	ErrCodeFunctionFailure = 0x1400
	// ErrCodeWriteFailure indicates request write error which is not covered by ErrCodeWriteTimeout.
	//
	// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1348-L1385
	ErrCodeWriteFailure = 0x1500
	// ErrCodeCDCWriteFailure is defined, but not yet documented in CQLv5 protocol.
	//
	// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1386
	ErrCodeCDCWriteFailure = 0x1600
	// ErrCodeCASWriteUnknown indicates only partially completed CAS operation.
	//
	// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1387-L1397
	ErrCodeCASWriteUnknown = 0x1700
	// ErrCodeSyntax indicates the syntax error in the query.
	//
	// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1399
	ErrCodeSyntax = 0x2000
	// ErrCodeUnauthorized indicates access rights violation by user on performed operation.
	//
	// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1400-L1401
	ErrCodeUnauthorized = 0x2100
	// ErrCodeInvalid indicates invalid query error which is not covered by ErrCodeSyntax.
	//
	// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1402
	ErrCodeInvalid = 0x2200
	// ErrCodeConfig indicates the configuration error.
	//
	// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1403
	ErrCodeConfig = 0x2300
	// ErrCodeAlreadyExists is returned for the requests creating the existing keyspace/table.
	//
	// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1404-L1413
	ErrCodeAlreadyExists = 0x2400
	// ErrCodeUnprepared returned from the host for prepared statement which is unknown.
	//
	// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1414-L1417
	ErrCodeUnprepared = 0x2500
)
|
||||
|
||||
// RequestError is implemented by all error frames the server returns in
// response to a request. It exposes the numeric error code and the
// server-supplied message, and satisfies the standard error interface.
type RequestError interface {
	// Code returns the ErrCode* protocol error code.
	Code() int
	// Message returns the human-readable error text from the server.
	Message() string
	// Error makes the value usable as a plain Go error.
	Error() string
}
|
||||
|
||||
// errorFrame is the common representation of a server ERROR response:
// the frame header plus the error code and message parsed from the body.
// It is embedded by the concrete RequestErr* types below.
type errorFrame struct {
	frameHeader

	// code holds the ErrCode* value reported by the server.
	code int
	// message holds the server-supplied error text.
	message string
}
|
||||
|
||||
// Code returns the numeric protocol error code of this error frame.
func (e errorFrame) Code() int {
	return e.code
}
|
||||
|
||||
// Message returns the server-supplied error text of this error frame.
func (e errorFrame) Message() string {
	return e.message
}
|
||||
|
||||
// Error implements the error interface; it returns the server message only
// (the numeric code is available via Code or the String representation).
func (e errorFrame) Error() string {
	return e.Message()
}
|
||||
|
||||
// String implements fmt.Stringer, rendering both the hexadecimal error
// code and the quoted message for diagnostics.
func (e errorFrame) String() string {
	return fmt.Sprintf("[error code=%x message=%q]", e.code, e.message)
}
|
||||
|
||||
// RequestErrUnavailable is returned when too few replicas were alive to
// satisfy the requested consistency level.
type RequestErrUnavailable struct {
	errorFrame
	// Consistency is the consistency level of the failed request.
	Consistency Consistency
	// Required is the number of replica acknowledgements required.
	Required int
	// Alive is the number of replicas known to be alive.
	Alive int
}
|
||||
|
||||
// String implements fmt.Stringer, rendering the consistency level and
// replica counts for diagnostics.
func (e *RequestErrUnavailable) String() string {
	return fmt.Sprintf("[request_error_unavailable consistency=%s required=%d alive=%d]", e.Consistency, e.Required, e.Alive)
}
|
||||
|
||||
// ErrorMap maps replica addresses to the numeric failure code each one
// reported; used by the *Failure error types below.
type ErrorMap map[string]uint16
|
||||
|
||||
// RequestErrWriteTimeout is returned when a write request timed out
// (ErrCodeWriteTimeout).
type RequestErrWriteTimeout struct {
	errorFrame
	// Consistency is the consistency level of the failed write.
	Consistency Consistency
	// Received is the number of replica acknowledgements received.
	Received int
	// BlockFor is the number of acknowledgements that were required.
	BlockFor int
	// WriteType describes the kind of write that timed out.
	WriteType string
}
|
||||
|
||||
// RequestErrWriteFailure is returned when replicas reported a non-timeout
// failure while executing a write (ErrCodeWriteFailure).
type RequestErrWriteFailure struct {
	errorFrame
	// Consistency is the consistency level of the failed write.
	Consistency Consistency
	// Received is the number of replica acknowledgements received.
	Received int
	// BlockFor is the number of acknowledgements that were required.
	BlockFor int
	// NumFailures is the number of replicas that reported a failure.
	NumFailures int
	// WriteType describes the kind of write that failed.
	WriteType string
	// ErrorMap maps failing replica addresses to their failure codes.
	ErrorMap ErrorMap
}
|
||||
|
||||
// RequestErrCDCWriteFailure corresponds to ErrCodeCDCWriteFailure; it
// carries no detail beyond the common code and message.
type RequestErrCDCWriteFailure struct {
	errorFrame
}
|
||||
|
||||
// RequestErrReadTimeout is returned when a read request timed out
// (ErrCodeReadTimeout).
type RequestErrReadTimeout struct {
	errorFrame
	// Consistency is the consistency level of the failed read.
	Consistency Consistency
	// Received is the number of replica responses received.
	Received int
	// BlockFor is the number of responses that were required.
	BlockFor int
	// DataPresent is non-zero when the replica asked for data did respond.
	DataPresent byte
}
|
||||
|
||||
// RequestErrAlreadyExists is returned for DDL requests that attempt to
// create an existing keyspace or table (ErrCodeAlreadyExists).
type RequestErrAlreadyExists struct {
	errorFrame
	// Keyspace is the keyspace that already exists.
	Keyspace string
	// Table is the table that already exists; empty for keyspace conflicts.
	Table string
}
|
||||
|
||||
// RequestErrUnprepared is returned when the host does not know the
// prepared statement being executed (ErrCodeUnprepared).
type RequestErrUnprepared struct {
	errorFrame
	// StatementId is the id of the unknown prepared statement.
	StatementId []byte
}
|
||||
|
||||
// RequestErrReadFailure is returned when replicas reported a non-timeout
// failure while executing a read (ErrCodeReadFailure).
type RequestErrReadFailure struct {
	errorFrame
	// Consistency is the consistency level of the failed read.
	Consistency Consistency
	// Received is the number of replica responses received.
	Received int
	// BlockFor is the number of responses that were required.
	BlockFor int
	// NumFailures is the number of replicas that reported a failure.
	NumFailures int
	// DataPresent reports whether the replica asked for data responded.
	DataPresent bool
	// ErrorMap maps failing replica addresses to their failure codes.
	ErrorMap ErrorMap
}
|
||||
|
||||
// RequestErrFunctionFailure is returned when a user-defined function
// raised an error during execution (ErrCodeFunctionFailure).
type RequestErrFunctionFailure struct {
	errorFrame
	// Keyspace is the keyspace containing the failing function.
	Keyspace string
	// Function is the name of the failing function.
	Function string
	// ArgTypes lists the CQL types of the function's arguments.
	ArgTypes []string
}
|
||||
|
||||
// RequestErrCASWriteUnknown is distinct error for ErrCodeCasWriteUnknown.
// It indicates a compare-and-set write whose outcome is unknown (only
// partially completed).
//
// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1387-L1397
type RequestErrCASWriteUnknown struct {
	errorFrame
	// Consistency is the consistency level of the CAS write.
	Consistency Consistency
	// Received is the number of replica acknowledgements received.
	Received int
	// BlockFor is the number of acknowledgements that were required.
	BlockFor int
}
|
||||
|
||||
// UnknownServerError wraps an error frame whose code does not match any of
// the known ErrCode* values.
type UnknownServerError struct {
	errorFrame
}
|
||||
|
||||
// OpType distinguishes the kind of operation a rate-limit error applies to.
type OpType uint8
|
||||
|
||||
const (
	// OpTypeRead marks a rate-limited read operation.
	OpTypeRead OpType = 0
	// OpTypeWrite marks a rate-limited write operation.
	OpTypeWrite OpType = 1
)
|
||||
|
||||
// RequestErrRateLimitReached is returned when the server rejected the
// request due to rate limiting.
type RequestErrRateLimitReached struct {
	errorFrame
	// OpType is the kind of operation that was rate limited.
	OpType OpType
	// RejectedByCoordinator reports whether the coordinator (rather than
	// a replica) rejected the request.
	RejectedByCoordinator bool
}
|
||||
|
||||
func (e *RequestErrRateLimitReached) String() string {
|
||||
var opType string
|
||||
if e.OpType == OpTypeRead {
|
||||
opType = "Read"
|
||||
} else if e.OpType == OpTypeWrite {
|
||||
opType = "Write"
|
||||
} else {
|
||||
opType = "Other"
|
||||
}
|
||||
return fmt.Sprintf("[request_error_rate_limit_reached OpType=%s RejectedByCoordinator=%t]", opType, e.RejectedByCoordinator)
|
||||
}
|
256
vendor/github.com/gocql/gocql/events.go
generated
vendored
Normal file
256
vendor/github.com/gocql/gocql/events.go
generated
vendored
Normal file
@@ -0,0 +1,256 @@
|
||||
package gocql
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// eventDebouncer batches incoming event frames and delivers them to a
// callback after a quiet period (eventDebounceTime), so bursts of server
// events result in a single callback invocation.
type eventDebouncer struct {
	// name identifies this debouncer in log messages.
	name string
	// timer triggers the flush after eventDebounceTime of inactivity.
	timer *time.Timer
	// mu guards events.
	mu     sync.Mutex
	events []frame

	// callback receives the batched frames; invoked from a new goroutine.
	callback func([]frame)
	// quit stops the flusher goroutine.
	quit chan struct{}

	logger StdLogger
}
|
||||
|
||||
// newEventDebouncer creates a debouncer with a stopped timer (armed on the
// first debounce call) and starts its background flusher goroutine.
// Stop the debouncer with stop() to terminate that goroutine.
func newEventDebouncer(name string, eventHandler func([]frame), logger StdLogger) *eventDebouncer {
	e := &eventDebouncer{
		name:     name,
		quit:     make(chan struct{}),
		timer:    time.NewTimer(eventDebounceTime),
		callback: eventHandler,
		logger:   logger,
	}
	// Timer starts disarmed; debounce() resets it when an event arrives.
	e.timer.Stop()
	go e.flusher()

	return e
}
|
||||
|
||||
// stop terminates the flusher goroutine. The initial send blocks until the
// flusher receives it, guaranteeing the goroutine has observed shutdown
// before stop returns; the subsequent close releases any other waiters.
func (e *eventDebouncer) stop() {
	e.quit <- struct{}{} // sync with flusher
	close(e.quit)
}
|
||||
|
||||
// flusher is the background loop: each timer expiry flushes the buffered
// events; a receive on quit terminates the loop.
func (e *eventDebouncer) flusher() {
	for {
		select {
		case <-e.timer.C:
			e.mu.Lock()
			e.flush()
			e.mu.Unlock()
		case <-e.quit:
			return
		}
	}
}
|
||||
|
||||
const (
	// eventBufferSize caps buffered event frames; beyond it frames are dropped.
	eventBufferSize = 1000
	// eventDebounceTime is the quiet period before buffered events are flushed.
	eventDebounceTime = 1 * time.Second
)
|
||||
|
||||
// flush must be called with mu locked.
// It hands the buffered frames to the callback on a fresh goroutine and
// replaces the buffer, so the callback owns the slice it receives.
func (e *eventDebouncer) flush() {
	if len(e.events) == 0 {
		return
	}

	// if the flush interval is faster than the callback then we will end up calling
	// the callback multiple times, probably a bad idea. In this case we could drop
	// frames?
	go e.callback(e.events)
	e.events = make([]frame, 0, eventBufferSize)
}
|
||||
|
||||
func (e *eventDebouncer) debounce(frame frame) {
|
||||
e.mu.Lock()
|
||||
e.timer.Reset(eventDebounceTime)
|
||||
|
||||
// TODO: probably need a warning to track if this threshold is too low
|
||||
if len(e.events) < eventBufferSize {
|
||||
e.events = append(e.events, frame)
|
||||
} else {
|
||||
e.logger.Printf("%s: buffer full, dropping event frame: %s", e.name, frame)
|
||||
}
|
||||
|
||||
e.mu.Unlock()
|
||||
}
|
||||
|
||||
// handleEvent parses a pushed EVENT frame and routes it to the matching
// debouncer: schema-change frames to schemaEvents, topology/status frames
// to nodeEvents. Unparseable or unrecognized frames are logged and dropped.
func (s *Session) handleEvent(framer *framer) {
	frame, err := framer.parseFrame()
	if err != nil {
		s.logger.Printf("gocql: unable to parse event frame: %v\n", err)
		return
	}

	if gocqlDebug {
		s.logger.Printf("gocql: handling frame: %v\n", frame)
	}

	switch f := frame.(type) {
	case *schemaChangeKeyspace, *schemaChangeFunction,
		*schemaChangeTable, *schemaChangeAggregate, *schemaChangeType:

		s.schemaEvents.debounce(frame)
	case *topologyChangeEventFrame, *statusChangeEventFrame:
		s.nodeEvents.debounce(frame)
	default:
		s.logger.Printf("gocql: invalid event frame (%T): %v\n", f, f)
	}
}
|
||||
|
||||
// handleSchemaEvent processes a batch of debounced schema-change frames.
// Every frame invalidates the cached schema for its keyspace; keyspace and
// table changes additionally trigger the corresponding change handlers.
func (s *Session) handleSchemaEvent(frames []frame) {
	// TODO: debounce events
	for _, frame := range frames {
		switch f := frame.(type) {
		case *schemaChangeKeyspace:
			s.metadataDescriber.clearSchema(f.keyspace)
			s.handleKeyspaceChange(f.keyspace, f.change)
		case *schemaChangeTable:
			s.metadataDescriber.clearSchema(f.keyspace)
			s.handleTableChange(f.keyspace, f.object, f.change)
		case *schemaChangeAggregate:
			s.metadataDescriber.clearSchema(f.keyspace)
		case *schemaChangeFunction:
			s.metadataDescriber.clearSchema(f.keyspace)
		case *schemaChangeType:
			s.metadataDescriber.clearSchema(f.keyspace)
		}
	}
}
|
||||
|
||||
// handleKeyspaceChange reacts to a keyspace schema change: it waits for
// schema agreement, drops cached tablet metadata for dropped/updated
// keyspaces, and notifies the host selection policy.
func (s *Session) handleKeyspaceChange(keyspace, change string) {
	s.control.awaitSchemaAgreement()
	if change == "DROPPED" || change == "UPDATED" {
		s.metadataDescriber.removeTabletsWithKeyspace(keyspace)
	}
	s.policy.KeyspaceChanged(KeyspaceUpdateEvent{Keyspace: keyspace, Change: change})
}
|
||||
|
||||
// handleTableChange drops cached tablet metadata for a table that was
// dropped or updated; other change kinds are ignored.
func (s *Session) handleTableChange(keyspace, table, change string) {
	if change == "DROPPED" || change == "UPDATED" {
		s.metadataDescriber.removeTabletsWithTable(keyspace, table)
	}
}
|
||||
|
||||
// handleNodeEvent handles inbound status and topology change events.
//
// Status events are debounced by host IP; only the latest event is processed.
//
// Topology events are debounced by performing a single full topology refresh
// whenever any topology event comes in.
//
// Processing topology change events before status change events ensures
// that a NEW_NODE event is not dropped in favor of a newer UP event (which
// would itself be dropped/ignored, as the node is not yet known).
func (s *Session) handleNodeEvent(frames []frame) {
	// nodeEvent holds the latest status change observed for one host.
	type nodeEvent struct {
		change string
		host   net.IP
		port   int
	}

	topologyEventReceived := false
	// status change events
	sEvents := make(map[string]*nodeEvent)

	for _, frame := range frames {
		switch f := frame.(type) {
		case *topologyChangeEventFrame:
			topologyEventReceived = true
		case *statusChangeEventFrame:
			// Keep one entry per host; later frames overwrite the change
			// so only the most recent status is acted upon.
			event, ok := sEvents[f.host.String()]
			if !ok {
				event = &nodeEvent{change: f.change, host: f.host, port: f.port}
				sEvents[f.host.String()] = event
			}
			event.change = f.change
		}
	}

	if topologyEventReceived && !s.cfg.Events.DisableTopologyEvents {
		s.debounceRingRefresh()
	}

	for _, f := range sEvents {
		if gocqlDebug {
			s.logger.Printf("gocql: dispatching status change event: %+v\n", f)
		}

		// ignore events we received if they were disabled
		// see https://github.com/gocql/gocql/issues/1591
		switch f.change {
		case "UP":
			if !s.cfg.Events.DisableNodeStatusEvents {
				s.handleNodeUp(f.host, f.port)
			}
		case "DOWN":
			if !s.cfg.Events.DisableNodeStatusEvents {
				s.handleNodeDown(f.host, f.port)
			}
		}
	}
}
|
||||
|
||||
// handleNodeUp reacts to an UP status event. Unknown hosts trigger a ring
// refresh instead of a direct pool fill; filtered hosts are ignored. Some
// server versions require a delay before the node accepts connections,
// which is honored via nodeUpDelay.
func (s *Session) handleNodeUp(eventIp net.IP, eventPort int) {
	if gocqlDebug {
		s.logger.Printf("gocql: Session.handleNodeUp: %s:%d\n", eventIp.String(), eventPort)
	}

	host, ok := s.hostSource.getHostByIP(eventIp.String())
	if !ok {
		// Host not yet known: refresh the ring so it gets discovered.
		s.debounceRingRefresh()
		return
	}

	if s.cfg.filterHost(host) {
		return
	}

	if d := host.Version().nodeUpDelay(); d > 0 {
		time.Sleep(d)
	}
	s.startPoolFill(host)
}
|
||||
|
||||
// startPoolFill registers the host with the connection pool and the host
// selection policy.
func (s *Session) startPoolFill(host *HostInfo) {
	// we let the pool call handleNodeConnected to change the host state
	s.pool.addHost(host)
	s.policy.AddHost(host)
}
|
||||
|
||||
// handleNodeConnected is invoked by the pool once a connection to the host
// has been established: it marks the host up and, unless the host is
// filtered out, notifies the policy.
func (s *Session) handleNodeConnected(host *HostInfo) {
	if gocqlDebug {
		s.logger.Printf("gocql: Session.handleNodeConnected: %s:%d\n", host.ConnectAddress(), host.Port())
	}

	host.setState(NodeUp)

	if !s.cfg.filterHost(host) {
		s.policy.HostUp(host)
	}
}
|
||||
|
||||
// handleNodeDown reacts to a DOWN status event: the host (if known) is
// marked down, the policy is notified, and its connections are removed
// from the pool. Filtered hosts only get the state change.
func (s *Session) handleNodeDown(ip net.IP, port int) {
	if gocqlDebug {
		s.logger.Printf("gocql: Session.handleNodeDown: %s:%d\n", ip.String(), port)
	}

	host, ok := s.hostSource.getHostByIP(ip.String())
	if ok {
		host.setState(NodeDown)
		if s.cfg.filterHost(host) {
			return
		}

		s.policy.HostDown(host)
		hostID := host.HostID()
		s.pool.removeHost(hostID)
	}
}
|
110
vendor/github.com/gocql/gocql/exec.go
generated
vendored
Normal file
110
vendor/github.com/gocql/gocql/exec.go
generated
vendored
Normal file
@@ -0,0 +1,110 @@
|
||||
package gocql
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// SingleHostQueryExecutor allows to quickly execute diagnostic queries while
// connected to only a single node.
// The executor opens only a single connection to a node and does not use
// connection pools.
// Consistency level used is ONE.
// Retry policy is applied, attempts are visible in query metrics but query
// observer is not notified.
type SingleHostQueryExecutor struct {
	// session is an uninitialised Session used only for configuration.
	session *Session
	// control is the single control connection all queries go through.
	control *controlConn
}
|
||||
|
||||
// Exec executes the query without returning any rows.
// Any values are bound to the statement; the iterator is closed
// immediately and its error (if any) is returned.
func (e SingleHostQueryExecutor) Exec(stmt string, values ...interface{}) error {
	return e.control.query(stmt, values...).Close()
}
|
||||
|
||||
// Iter executes the query and returns an iterator capable of iterating
// over all results. The caller is responsible for closing the iterator.
func (e SingleHostQueryExecutor) Iter(stmt string, values ...interface{}) *Iter {
	return e.control.query(stmt, values...)
}
|
||||
|
||||
// Close releases the control connection and the backing session. It is
// safe to call on a partially-constructed executor (nil fields are
// skipped), which NewSingleHostQueryExecutor relies on for error cleanup.
func (e SingleHostQueryExecutor) Close() {
	if e.control != nil {
		e.control.close()
	}
	if e.session != nil {
		e.session.Close()
	}
}
|
||||
|
||||
// NewSingleHostQueryExecutor creates a SingleHostQueryExecutor by connecting
// to one of the hosts specified in the ClusterConfig.
// If ProtoVersion is not specified version 4 is used.
// Caller is responsible for closing the executor after use.
//
// Hosts are tried in shuffled order; the executor uses the first host for
// which both dialing and connection setup succeed. If all hosts fail, the
// last error is returned (wrapped) and any partial state is closed.
func NewSingleHostQueryExecutor(cfg *ClusterConfig) (e SingleHostQueryExecutor, err error) {
	// Check that hosts in the ClusterConfig is not empty
	if len(cfg.Hosts) < 1 {
		err = ErrNoHosts
		return
	}

	// Work on a copy so the caller's config is never mutated.
	c := *cfg

	// If protocol version not set assume 4 and skip discovery
	if c.ProtoVersion == 0 {
		c.ProtoVersion = 4
	}

	// Close in case of error
	defer func() {
		if err != nil {
			e.Close()
		}
	}()

	// Create uninitialised session
	c.disableInit = true
	if e.session, err = NewSession(c); err != nil {
		err = fmt.Errorf("new session: %w", err)
		return
	}

	var hosts []*HostInfo
	if hosts, err = addrsToHosts(c.Hosts, c.Port, c.Logger); err != nil {
		err = fmt.Errorf("addrs to hosts: %w", err)
		return
	}

	// Create control connection to one of the hosts
	e.control = createControlConn(e.session)

	// shuffle endpoints so not all drivers will connect to the same initial
	// node.
	hosts = shuffleHosts(hosts)

	conncfg := *e.control.session.connCfg
	conncfg.disableCoalesce = true

	var conn *Conn

	for _, host := range hosts {
		conn, err = e.control.session.dial(e.control.session.ctx, host, &conncfg, e.control)
		if err != nil {
			e.control.session.logger.Printf("gocql: unable to dial control conn %v:%v: %v\n", host.ConnectAddress(), host.Port(), err)
			continue
		}
		err = e.control.setupConn(conn)
		if err == nil {
			break
		}
		e.control.session.logger.Printf("gocql: unable setup control conn %v:%v: %v\n", host.ConnectAddress(), host.Port(), err)
		conn.Close()
		conn = nil
	}

	// conn is nil when every host failed to dial or set up.
	if conn == nil {
		err = fmt.Errorf("setup: %w", err)
		return
	}

	return
}
|
57
vendor/github.com/gocql/gocql/filters.go
generated
vendored
Normal file
57
vendor/github.com/gocql/gocql/filters.go
generated
vendored
Normal file
@@ -0,0 +1,57 @@
|
||||
package gocql
|
||||
|
||||
import "fmt"
|
||||
|
||||
// HostFilter interface is used when a host is discovered via server sent events.
type HostFilter interface {
	// Called when a new host is discovered, returning true will cause the host
	// to be added to the pools.
	Accept(host *HostInfo) bool
}
|
||||
|
||||
// HostFilterFunc converts a func(host HostInfo) bool into a HostFilter.
type HostFilterFunc func(host *HostInfo) bool

// Accept implements HostFilter by invoking the wrapped function.
func (fn HostFilterFunc) Accept(host *HostInfo) bool {
	return fn(host)
}
|
||||
|
||||
// AcceptAllFilter will accept all hosts.
func AcceptAllFilter() HostFilter {
	return HostFilterFunc(func(host *HostInfo) bool {
		return true
	})
}
|
||||
|
||||
// DenyAllFilter will reject all hosts.
func DenyAllFilter() HostFilter {
	return HostFilterFunc(func(host *HostInfo) bool {
		return false
	})
}
|
||||
|
||||
// DataCentreHostFilter filters all hosts such that they are in the same data centre
// as the supplied data centre.
func DataCentreHostFilter(dataCentre string) HostFilter {
	return HostFilterFunc(func(host *HostInfo) bool {
		return host.DataCenter() == dataCentre
	})
}
|
||||
|
||||
// WhiteListHostFilter filters incoming hosts by checking that their address is
// in the initial hosts whitelist.
//
// NOTE(review): this resolves the supplied addresses eagerly (port 9042 is
// assumed) and panics when resolution fails, to keep the return type a
// plain HostFilter.
func WhiteListHostFilter(hosts ...string) HostFilter {
	hostInfos, err := addrsToHosts(hosts, 9042, nopLogger{})
	if err != nil {
		// dont want to panic here, but rather not break the API
		panic(fmt.Errorf("unable to lookup host info from address: %v", err))
	}

	// Membership set keyed by resolved connect address.
	m := make(map[string]bool, len(hostInfos))
	for _, host := range hostInfos {
		m[host.ConnectAddress().String()] = true
	}

	return HostFilterFunc(func(host *HostInfo) bool {
		return m[host.ConnectAddress().String()]
	})
}
|
2119
vendor/github.com/gocql/gocql/frame.go
generated
vendored
Normal file
2119
vendor/github.com/gocql/gocql/frame.go
generated
vendored
Normal file
File diff suppressed because it is too large
Load Diff
34
vendor/github.com/gocql/gocql/fuzz.go
generated
vendored
Normal file
34
vendor/github.com/gocql/gocql/fuzz.go
generated
vendored
Normal file
@@ -0,0 +1,34 @@
|
||||
//go:build gofuzz
|
||||
// +build gofuzz
|
||||
|
||||
package gocql
|
||||
|
||||
import "bytes"
|
||||
|
||||
// Fuzz is the go-fuzz entry point. It feeds arbitrary bytes through the
// frame header reader, frame reader and frame parser; return values follow
// the go-fuzz convention (0 = uninteresting, 1/2 = interesting input).
func Fuzz(data []byte) int {
	var bw bytes.Buffer

	r := bytes.NewReader(data)

	// A native-protocol frame header is at most 9 bytes.
	head, err := readHeader(r, make([]byte, 9))
	if err != nil {
		return 0
	}

	framer := newFramer(r, &bw, nil, byte(head.version))
	err = framer.readFrame(&head)
	if err != nil {
		return 0
	}

	frame, err := framer.parseFrame()
	if err != nil {
		return 0
	}

	if frame != nil {
		return 1
	}

	return 2
}
|
448
vendor/github.com/gocql/gocql/helpers.go
generated
vendored
Normal file
448
vendor/github.com/gocql/gocql/helpers.go
generated
vendored
Normal file
@@ -0,0 +1,448 @@
|
||||
// Copyright (c) 2012 The gocql Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package gocql
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gopkg.in/inf.v0"
|
||||
)
|
||||
|
||||
// RowData pairs column names with destination values suitable for passing
// to Iter.Scan; Values[i] receives the value of Columns[i].
type RowData struct {
	Columns []string
	Values  []interface{}
}
|
||||
|
||||
// goType maps a CQL TypeInfo to the default Go type used when scanning
// values of that CQL type. Collection, map and tuple types are resolved
// recursively; an error is returned for CQL types with no known mapping.
func goType(t TypeInfo) (reflect.Type, error) {
	switch t.Type() {
	case TypeVarchar, TypeAscii, TypeInet, TypeText:
		return reflect.TypeOf(*new(string)), nil
	case TypeBigInt, TypeCounter:
		return reflect.TypeOf(*new(int64)), nil
	case TypeTime:
		return reflect.TypeOf(*new(time.Duration)), nil
	case TypeTimestamp:
		return reflect.TypeOf(*new(time.Time)), nil
	case TypeBlob:
		return reflect.TypeOf(*new([]byte)), nil
	case TypeBoolean:
		return reflect.TypeOf(*new(bool)), nil
	case TypeFloat:
		return reflect.TypeOf(*new(float32)), nil
	case TypeDouble:
		return reflect.TypeOf(*new(float64)), nil
	case TypeInt:
		return reflect.TypeOf(*new(int)), nil
	case TypeSmallInt:
		return reflect.TypeOf(*new(int16)), nil
	case TypeTinyInt:
		return reflect.TypeOf(*new(int8)), nil
	case TypeDecimal:
		return reflect.TypeOf(*new(*inf.Dec)), nil
	case TypeUUID, TypeTimeUUID:
		return reflect.TypeOf(*new(UUID)), nil
	case TypeList, TypeSet:
		elemType, err := goType(t.(CollectionType).Elem)
		if err != nil {
			return nil, err
		}
		return reflect.SliceOf(elemType), nil
	case TypeMap:
		keyType, err := goType(t.(CollectionType).Key)
		if err != nil {
			return nil, err
		}
		valueType, err := goType(t.(CollectionType).Elem)
		if err != nil {
			return nil, err
		}
		return reflect.MapOf(keyType, valueType), nil
	case TypeVarint:
		return reflect.TypeOf(*new(*big.Int)), nil
	case TypeTuple:
		// what can we do here? all there is to do is to make a list of interface{}
		tuple := t.(TupleTypeInfo)
		return reflect.TypeOf(make([]interface{}, len(tuple.Elems))), nil
	case TypeUDT:
		return reflect.TypeOf(make(map[string]interface{})), nil
	case TypeDate:
		return reflect.TypeOf(*new(time.Time)), nil
	case TypeDuration:
		return reflect.TypeOf(*new(Duration)), nil
	default:
		return nil, fmt.Errorf("cannot create Go type for unknown CQL type %s", t)
	}
}
|
||||
|
||||
// dereference unwraps a pointer-valued interface, returning the pointed-to
// value (non-pointers are returned unchanged by reflect.Indirect).
func dereference(i interface{}) interface{} {
	return reflect.Indirect(reflect.ValueOf(i)).Interface()
}
|
||||
|
||||
// getCassandraBaseType maps a simple (non-composite) CQL type name to its
// Type constant; unrecognized names map to TypeCustom.
func getCassandraBaseType(name string) Type {
	switch name {
	case "ascii":
		return TypeAscii
	case "bigint":
		return TypeBigInt
	case "blob":
		return TypeBlob
	case "boolean":
		return TypeBoolean
	case "counter":
		return TypeCounter
	case "date":
		return TypeDate
	case "decimal":
		return TypeDecimal
	case "double":
		return TypeDouble
	case "duration":
		return TypeDuration
	case "float":
		return TypeFloat
	case "int":
		return TypeInt
	case "smallint":
		return TypeSmallInt
	case "tinyint":
		return TypeTinyInt
	case "time":
		return TypeTime
	case "timestamp":
		return TypeTimestamp
	case "uuid":
		return TypeUUID
	case "varchar":
		return TypeVarchar
	case "text":
		return TypeText
	case "varint":
		return TypeVarint
	case "timeuuid":
		return TypeTimeUUID
	case "inet":
		return TypeInet
	case "MapType":
		return TypeMap
	case "ListType":
		return TypeList
	case "SetType":
		return TypeSet
	case "TupleType":
		return TypeTuple
	default:
		return TypeCustom
	}
}
|
||||
|
||||
// getCassandraType parses a CQL type string (possibly composite, e.g.
// "map<int, frozen<list<text>>>") into a TypeInfo, recursing into
// frozen/set/list/map/tuple wrappers. Malformed map declarations are
// logged and degrade to TypeCustom.
func getCassandraType(name string, logger StdLogger) TypeInfo {
	if strings.HasPrefix(name, "frozen<") {
		// frozen<> only affects storage; parse the inner type directly.
		return getCassandraType(strings.TrimPrefix(name[:len(name)-1], "frozen<"), logger)
	} else if strings.HasPrefix(name, "set<") {
		return CollectionType{
			NativeType: NativeType{typ: TypeSet},
			Elem:       getCassandraType(strings.TrimPrefix(name[:len(name)-1], "set<"), logger),
		}
	} else if strings.HasPrefix(name, "list<") {
		return CollectionType{
			NativeType: NativeType{typ: TypeList},
			Elem:       getCassandraType(strings.TrimPrefix(name[:len(name)-1], "list<"), logger),
		}
	} else if strings.HasPrefix(name, "map<") {
		names := splitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "map<"))
		if len(names) != 2 {
			logger.Printf("Error parsing map type, it has %d subelements, expecting 2\n", len(names))
			return NativeType{
				typ: TypeCustom,
			}
		}
		return CollectionType{
			NativeType: NativeType{typ: TypeMap},
			Key:        getCassandraType(names[0], logger),
			Elem:       getCassandraType(names[1], logger),
		}
	} else if strings.HasPrefix(name, "tuple<") {
		names := splitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "tuple<"))
		types := make([]TypeInfo, len(names))

		for i, name := range names {
			types[i] = getCassandraType(name, logger)
		}

		return TupleTypeInfo{
			NativeType: NativeType{typ: TypeTuple},
			Elems:      types,
		}
	} else {
		return NativeType{
			typ: getCassandraBaseType(name),
		}
	}
}
|
||||
|
||||
// splitCompositeTypes splits a comma-separated list of CQL type names into
// its top-level elements, leaving commas nested inside angle brackets
// (e.g. "map<int, text>") intact. Each element is whitespace-trimmed.
func splitCompositeTypes(name string) []string {
	// Fast path: no generics involved, a plain split suffices.
	if !strings.Contains(name, "<") {
		return strings.Split(name, ", ")
	}

	var (
		parts []string
		depth int
		buf   strings.Builder
	)
	flush := func() {
		if buf.Len() > 0 {
			parts = append(parts, strings.TrimSpace(buf.String()))
		}
		buf.Reset()
	}
	for _, r := range name {
		switch {
		case r == ',' && depth == 0:
			// Top-level comma terminates the current element.
			flush()
			continue
		case r == '<':
			depth++
		case r == '>':
			depth--
		}
		buf.WriteRune(r)
	}
	flush()
	return parts
}
|
||||
|
||||
// apacheToCassandraType rewrites a Java class style type string (e.g.
// "org.apache.cassandra.db.marshal.MapType(...)") into the equivalent CQL
// type string, converting parentheses to angle brackets and each class
// name to its CQL name.
func apacheToCassandraType(t string) string {
	t = strings.Replace(t, apacheCassandraTypePrefix, "", -1)
	t = strings.Replace(t, "(", "<", -1)
	t = strings.Replace(t, ")", ">", -1)
	types := strings.FieldsFunc(t, func(r rune) bool {
		return r == '<' || r == '>' || r == ','
	})
	for _, typ := range types {
		t = strings.Replace(t, typ, getApacheCassandraType(typ).String(), -1)
	}
	// This is done so it exactly matches what Cassandra returns
	return strings.Replace(t, ",", ", ", -1)
}
|
||||
|
||||
// getApacheCassandraType maps an org.apache.cassandra.db.marshal class
// name (with or without the package prefix) to its Type constant;
// unrecognized classes map to TypeCustom.
func getApacheCassandraType(class string) Type {
	switch strings.TrimPrefix(class, apacheCassandraTypePrefix) {
	case "AsciiType":
		return TypeAscii
	case "LongType":
		return TypeBigInt
	case "BytesType":
		return TypeBlob
	case "BooleanType":
		return TypeBoolean
	case "CounterColumnType":
		return TypeCounter
	case "DecimalType":
		return TypeDecimal
	case "DoubleType":
		return TypeDouble
	case "FloatType":
		return TypeFloat
	case "Int32Type":
		return TypeInt
	case "ShortType":
		return TypeSmallInt
	case "ByteType":
		return TypeTinyInt
	case "TimeType":
		return TypeTime
	case "DateType", "TimestampType":
		return TypeTimestamp
	case "UUIDType", "LexicalUUIDType":
		return TypeUUID
	case "UTF8Type":
		return TypeVarchar
	case "IntegerType":
		return TypeVarint
	case "TimeUUIDType":
		return TypeTimeUUID
	case "InetAddressType":
		return TypeInet
	case "MapType":
		return TypeMap
	case "ListType":
		return TypeList
	case "SetType":
		return TypeSet
	case "TupleType":
		return TypeTuple
	case "DurationType":
		return TypeDuration
	default:
		return TypeCustom
	}
}
|
||||
|
||||
// rowMap copies the scanned values into m keyed by column name. Pointer
// destinations are dereferenced, and slice values are deep-copied so the
// map does not alias buffers reused by subsequent scans.
func (r *RowData) rowMap(m map[string]interface{}) {
	for i, column := range r.Columns {
		val := dereference(r.Values[i])
		if valVal := reflect.ValueOf(val); valVal.Kind() == reflect.Slice {
			valCopy := reflect.MakeSlice(valVal.Type(), valVal.Len(), valVal.Cap())
			reflect.Copy(valCopy, valVal)
			m[column] = valCopy.Interface()
		} else {
			m[column] = val
		}
	}
}
|
||||
|
||||
// TupleColumnName will return the column name of a tuple value in a column named
// c at index n. It should be used if a specific element within a tuple is needed
// to be extracted from a map returned from SliceMap or MapScan.
func TupleColumnName(c string, n int) string {
	return c + "[" + fmt.Sprint(n) + "]"
}
|
||||
|
||||
// RowData builds a RowData with one freshly allocated destination value
// per result column, suitable for passing to Scan. Tuple columns are
// flattened into one entry per element, named via TupleColumnName. Any
// pending iterator error is returned unchanged.
func (iter *Iter) RowData() (RowData, error) {
	if iter.err != nil {
		return RowData{}, iter.err
	}

	columns := make([]string, 0, len(iter.Columns()))
	values := make([]interface{}, 0, len(iter.Columns()))

	for _, column := range iter.Columns() {
		if c, ok := column.TypeInfo.(TupleTypeInfo); !ok {
			val, err := column.TypeInfo.NewWithError()
			if err != nil {
				return RowData{}, err
			}
			columns = append(columns, column.Name)
			values = append(values, val)
		} else {
			// Tuples expand to one column/value pair per element.
			for i, elem := range c.Elems {
				columns = append(columns, TupleColumnName(column.Name, i))
				val, err := elem.NewWithError()
				if err != nil {
					return RowData{}, err
				}
				values = append(values, val)
			}
		}
	}

	rowData := RowData{
		Columns: columns,
		Values:  values,
	}

	return rowData, nil
}
|
||||
|
||||
// TODO(zariel): is it worth exporting this?
// rowMap scans the next row into a freshly built column-name-keyed map.
// NOTE(review): the RowData error is discarded here; iter.err was checked
// just above, but NewWithError failures would be silently ignored.
func (iter *Iter) rowMap() (map[string]interface{}, error) {
	if iter.err != nil {
		return nil, iter.err
	}

	rowData, _ := iter.RowData()
	iter.Scan(rowData.Values...)
	m := make(map[string]interface{}, len(rowData.Columns))
	rowData.rowMap(m)
	return m, nil
}
|
||||
|
||||
// SliceMap is a helper function to make the API easier to use
// returns the data from the query in the form of []map[string]interface{}.
// All remaining rows are consumed; any error encountered while iterating
// is returned instead of partial results.
func (iter *Iter) SliceMap() ([]map[string]interface{}, error) {
	if iter.err != nil {
		return nil, iter.err
	}

	// Not checking for the error because we just did
	rowData, _ := iter.RowData()
	dataToReturn := make([]map[string]interface{}, 0)
	for iter.Scan(rowData.Values...) {
		// Each row gets its own map; rowMap deep-copies slice values so the
		// maps do not alias the reused scan destinations.
		m := make(map[string]interface{}, len(rowData.Columns))
		rowData.rowMap(m)
		dataToReturn = append(dataToReturn, m)
	}
	if iter.err != nil {
		return nil, iter.err
	}
	return dataToReturn, nil
}
|
||||
|
||||
// MapScan takes a map[string]interface{} and populates it with a row
// that is returned from cassandra.
//
// Each call to MapScan() must be called with a new map object.
// During the call to MapScan() any pointers in the existing map
// are replaced with non pointer types before the call returns
//
//	iter := session.Query(`SELECT * FROM mytable`).Iter()
//	for {
//		// New map each iteration
//		row := make(map[string]interface{})
//		if !iter.MapScan(row) {
//			break
//		}
//		// Do things with row
//		if fullname, ok := row["fullname"]; ok {
//			fmt.Printf("Full Name: %s\n", fullname)
//		}
//	}
//
// You can also pass pointers in the map before each call
//
//	var fullName FullName // Implements gocql.Unmarshaler and gocql.Marshaler interfaces
//	var address net.IP
//	var age int
//	iter := session.Query(`SELECT * FROM scan_map_table`).Iter()
//	for {
//		// New map each iteration
//		row := map[string]interface{}{
//			"fullname": &fullName,
//			"age":      &age,
//			"address":  &address,
//		}
//		if !iter.MapScan(row) {
//			break
//		}
//		fmt.Printf("First: %s Age: %d Address: %q\n", fullName.FirstName, age, address)
//	}
func (iter *Iter) MapScan(m map[string]interface{}) bool {
	if iter.err != nil {
		return false
	}

	// Not checking for the error because we just did
	rowData, _ := iter.RowData()

	// Caller-supplied destinations override the default ones.
	for i, col := range rowData.Columns {
		if dest, ok := m[col]; ok {
			rowData.Values[i] = dest
		}
	}

	if iter.Scan(rowData.Values...) {
		rowData.rowMap(m)
		return true
	}
	return false
}
|
||||
|
||||
// copyBytes returns a fresh copy of p that shares no backing storage with
// the input.
func copyBytes(p []byte) []byte {
	dup := make([]byte, len(p))
	copy(dup, p)
	return dup
}
|
||||
|
||||
// failDNS, when set (by tests), forces LookupIP to fail with a DNS error
// so DNS-failure code paths can be exercised without touching the network.
var failDNS = false

// LookupIP resolves host via the standard resolver, or returns an empty
// *net.DNSError when failDNS is set.
func LookupIP(host string) ([]net.IP, error) {
	if failDNS {
		return nil, &net.DNSError{}
	}
	return net.LookupIP(host)

}
|
722
vendor/github.com/gocql/gocql/host_source.go
generated
vendored
Normal file
722
vendor/github.com/gocql/gocql/host_source.go
generated
vendored
Normal file
@@ -0,0 +1,722 @@
|
||||
package gocql
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ErrCannotFindHost is returned when a host referenced during a ring refresh
// is no longer present in the session's host map.
var ErrCannotFindHost = errors.New("cannot find host")

// ErrHostAlreadyExists is returned when attempting to re-add a host that is
// already registered.
var ErrHostAlreadyExists = errors.New("host already exists")
||||
|
||||
// nodeState tracks whether a node is currently considered up or down.
type nodeState int32

// String renders the state for logs; unrecognized values are shown as
// UNKNOWN_<n>.
func (n nodeState) String() string {
	switch n {
	case NodeUp:
		return "UP"
	case NodeDown:
		return "DOWN"
	default:
		return fmt.Sprintf("UNKNOWN_%d", n)
	}
}

const (
	NodeUp nodeState = iota
	NodeDown
)
|
||||
|
||||
// cassVersion is a parsed Cassandra release version, e.g. "3.11.4".
type cassVersion struct {
	Major, Minor, Patch int
}

// Set parses v into c. An empty string is a no-op and leaves c at its
// current (zero) value.
func (c *cassVersion) Set(v string) error {
	if v == "" {
		return nil
	}

	return c.UnmarshalCQL(nil, []byte(v))
}

// UnmarshalCQL allows a version column to be scanned directly into a
// cassVersion. The TypeInfo argument is ignored.
func (c *cassVersion) UnmarshalCQL(info TypeInfo, data []byte) error {
	return c.unmarshal(data)
}
|
||||
|
||||
func (c *cassVersion) unmarshal(data []byte) error {
|
||||
version := strings.TrimSuffix(string(data), "-SNAPSHOT")
|
||||
version = strings.TrimPrefix(version, "v")
|
||||
v := strings.Split(version, ".")
|
||||
|
||||
if len(v) < 2 {
|
||||
return fmt.Errorf("invalid version string: %s", data)
|
||||
}
|
||||
|
||||
var err error
|
||||
c.Major, err = strconv.Atoi(v[0])
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid major version %v: %v", v[0], err)
|
||||
}
|
||||
|
||||
c.Minor, err = strconv.Atoi(v[1])
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid minor version %v: %v", v[1], err)
|
||||
}
|
||||
|
||||
if len(v) > 2 {
|
||||
c.Patch, err = strconv.Atoi(v[2])
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid patch version %v: %v", v[2], err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c cassVersion) Before(major, minor, patch int) bool {
|
||||
// We're comparing us (cassVersion) with the provided version (major, minor, patch)
|
||||
// We return true if our version is lower (comes before) than the provided one.
|
||||
if c.Major < major {
|
||||
return true
|
||||
} else if c.Major == major {
|
||||
if c.Minor < minor {
|
||||
return true
|
||||
} else if c.Minor == minor && c.Patch < patch {
|
||||
return true
|
||||
}
|
||||
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// AtLeast reports whether version c is major.minor.patch or newer.
func (c cassVersion) AtLeast(major, minor, patch int) bool {
	return !c.Before(major, minor, patch)
}

// String renders the version in "vX.Y.Z" form.
func (c cassVersion) String() string {
	return fmt.Sprintf("v%d.%d.%d", c.Major, c.Minor, c.Patch)
}
|
||||
|
||||
func (c cassVersion) nodeUpDelay() time.Duration {
|
||||
if c.Major >= 2 && c.Minor >= 2 {
|
||||
// CASSANDRA-8236
|
||||
return 0
|
||||
}
|
||||
|
||||
return 10 * time.Second
|
||||
}
|
||||
|
||||
// HostInfo describes a single node in the cluster, populated from
// system.local / system.peers rows and from runtime topology events.
type HostInfo struct {
	// TODO(zariel): reduce locking maybe, not all values will change, but to ensure
	// that we are thread safe use a mutex to access all fields.
	mu               sync.RWMutex
	hostname         string
	peer             net.IP
	broadcastAddress net.IP
	listenAddress    net.IP
	rpcAddress       net.IP
	preferredIP      net.IP
	connectAddress   net.IP
	// untranslatedConnectAddress is the connect address before any
	// AddressTranslator was applied.
	untranslatedConnectAddress net.IP
	port                       int
	dataCenter                 string
	rack                       string
	hostId                     string
	workload                   string
	graph                      bool
	dseVersion                 string
	partitioner                string
	clusterName                string
	version                    cassVersion
	state                      nodeState
	schemaVersion              string
	tokens                     []string

	// Scylla-specific shard-aware ports; zero means unknown.
	scyllaShardAwarePort    uint16
	scyllaShardAwarePortTLS uint16
}
|
||||
|
||||
// Equal reports whether h and host identify the same node: identical host
// ID and identical connect address/port.
func (h *HostInfo) Equal(host *HostInfo) bool {
	if h == host {
		// prevent rlock reentry
		return true
	}

	return h.HostID() == host.HostID() && h.ConnectAddressAndPort() == host.ConnectAddressAndPort()
}
|
||||
|
||||
// Peer returns the peer address as reported in system.peers.
func (h *HostInfo) Peer() net.IP {
	h.mu.RLock()
	defer h.mu.RUnlock()
	return h.peer
}

// invalidConnectAddr reports whether no usable connect address could be
// derived for this host.
func (h *HostInfo) invalidConnectAddr() bool {
	h.mu.RLock()
	defer h.mu.RUnlock()
	addr, _ := h.connectAddressLocked()
	return !validIpAddr(addr)
}
|
||||
|
||||
func validIpAddr(addr net.IP) bool {
|
||||
return addr != nil && !addr.IsUnspecified()
|
||||
}
|
||||
|
||||
func (h *HostInfo) connectAddressLocked() (net.IP, string) {
|
||||
if validIpAddr(h.connectAddress) {
|
||||
return h.connectAddress, "connect_address"
|
||||
} else if validIpAddr(h.rpcAddress) {
|
||||
return h.rpcAddress, "rpc_adress"
|
||||
} else if validIpAddr(h.preferredIP) {
|
||||
// where does perferred_ip get set?
|
||||
return h.preferredIP, "preferred_ip"
|
||||
} else if validIpAddr(h.broadcastAddress) {
|
||||
return h.broadcastAddress, "broadcast_address"
|
||||
} else if validIpAddr(h.peer) {
|
||||
return h.peer, "peer"
|
||||
}
|
||||
return net.IPv4zero, "invalid"
|
||||
}
|
||||
|
||||
// nodeToNodeAddress returns address broadcasted between node to nodes.
// It's either `broadcast_address` if host info is read from system.local or `peer` if read from system.peers.
// This IP address is also part of CQL Event emitted on topology/status changes,
// but does not uniquely identify the node in case multiple nodes use the same IP address.
//
// Returns net.IPv4zero when neither address is set.
func (h *HostInfo) nodeToNodeAddress() net.IP {
	h.mu.RLock()
	defer h.mu.RUnlock()

	if validIpAddr(h.broadcastAddress) {
		return h.broadcastAddress
	} else if validIpAddr(h.peer) {
		return h.peer
	}
	return net.IPv4zero
}
|
||||
|
||||
// Returns the address that should be used to connect to the host.
// If you wish to override this, use an AddressTranslator or
// use a HostFilter to SetConnectAddress()
//
// Panics when no valid address can be derived for this host.
func (h *HostInfo) ConnectAddress() net.IP {
	h.mu.RLock()
	defer h.mu.RUnlock()

	if addr, _ := h.connectAddressLocked(); validIpAddr(addr) {
		return addr
	}
	panic(fmt.Sprintf("no valid connect address for host: %v. Is your cluster configured correctly?", h))
}

// UntranslatedConnectAddress returns the connect address as seen before any
// AddressTranslator was applied, falling back to the (translated) connect
// address when the untranslated one was never recorded. Panics when no
// valid address exists at all.
func (h *HostInfo) UntranslatedConnectAddress() net.IP {
	h.mu.RLock()
	defer h.mu.RUnlock()

	if len(h.untranslatedConnectAddress) != 0 {
		return h.untranslatedConnectAddress
	}

	if addr, _ := h.connectAddressLocked(); validIpAddr(addr) {
		return addr
	}
	panic(fmt.Sprintf("no valid connect address for host: %v. Is your cluster configured correctly?", h))
}

// SetConnectAddress overrides the address used to connect to this host and
// returns h for chaining.
func (h *HostInfo) SetConnectAddress(address net.IP) *HostInfo {
	// TODO(zariel): should this not be exported?
	h.mu.Lock()
	defer h.mu.Unlock()
	h.connectAddress = address
	return h
}
|
||||
|
||||
// BroadcastAddress returns the node's broadcast_address, if known.
func (h *HostInfo) BroadcastAddress() net.IP {
	h.mu.RLock()
	defer h.mu.RUnlock()
	return h.broadcastAddress
}

// ListenAddress returns the node's listen_address, if known.
func (h *HostInfo) ListenAddress() net.IP {
	h.mu.RLock()
	defer h.mu.RUnlock()
	return h.listenAddress
}

// RPCAddress returns the node's rpc_address (or native_address), if known.
func (h *HostInfo) RPCAddress() net.IP {
	h.mu.RLock()
	defer h.mu.RUnlock()
	return h.rpcAddress
}

// PreferredIP returns the node's preferred_ip, if known.
func (h *HostInfo) PreferredIP() net.IP {
	h.mu.RLock()
	defer h.mu.RUnlock()
	return h.preferredIP
}

// DataCenter returns the data center this node belongs to.
func (h *HostInfo) DataCenter() string {
	h.mu.RLock()
	dc := h.dataCenter
	h.mu.RUnlock()
	return dc
}

// Rack returns the rack this node belongs to.
func (h *HostInfo) Rack() string {
	h.mu.RLock()
	rack := h.rack
	h.mu.RUnlock()
	return rack
}

// HostID returns the node's unique host ID as a string.
func (h *HostInfo) HostID() string {
	h.mu.RLock()
	defer h.mu.RUnlock()
	return h.hostId
}

// SetHostID overrides the node's host ID.
func (h *HostInfo) SetHostID(hostID string) {
	h.mu.Lock()
	defer h.mu.Unlock()
	h.hostId = hostID
}

// WorkLoad returns the DSE workload label reported by the node, if any.
func (h *HostInfo) WorkLoad() string {
	h.mu.RLock()
	defer h.mu.RUnlock()
	return h.workload
}

// Graph reports whether the node has DSE Graph enabled.
func (h *HostInfo) Graph() bool {
	h.mu.RLock()
	defer h.mu.RUnlock()
	return h.graph
}

// DSEVersion returns the DSE version string, if the node runs DSE.
func (h *HostInfo) DSEVersion() string {
	h.mu.RLock()
	defer h.mu.RUnlock()
	return h.dseVersion
}

// Partitioner returns the cluster partitioner class name reported by the node.
func (h *HostInfo) Partitioner() string {
	h.mu.RLock()
	defer h.mu.RUnlock()
	return h.partitioner
}

// ClusterName returns the cluster name reported by the node.
func (h *HostInfo) ClusterName() string {
	h.mu.RLock()
	defer h.mu.RUnlock()
	return h.clusterName
}

// Version returns the node's parsed Cassandra release version.
func (h *HostInfo) Version() cassVersion {
	h.mu.RLock()
	defer h.mu.RUnlock()
	return h.version
}

// State returns the node's current up/down state.
func (h *HostInfo) State() nodeState {
	h.mu.RLock()
	defer h.mu.RUnlock()
	return h.state
}

// setState updates the node's up/down state and returns h for chaining.
func (h *HostInfo) setState(state nodeState) *HostInfo {
	h.mu.Lock()
	defer h.mu.Unlock()
	h.state = state
	return h
}

// Tokens returns the node's token ring ownership strings.
func (h *HostInfo) Tokens() []string {
	h.mu.RLock()
	defer h.mu.RUnlock()
	return h.tokens
}

// Port returns the (possibly translated) native protocol port for this node.
func (h *HostInfo) Port() int {
	h.mu.RLock()
	defer h.mu.RUnlock()
	return h.port
}
|
||||
|
||||
// update fills in any zero-valued fields of h with the corresponding values
// from `from`. Existing non-zero values in h are never overwritten. The
// field-by-field body below is produced by the genhostinfo generator.
func (h *HostInfo) update(from *HostInfo) {
	if h == from {
		return
	}

	h.mu.Lock()
	defer h.mu.Unlock()

	from.mu.RLock()
	defer from.mu.RUnlock()

	// autogenerated do not update
	if h.peer == nil {
		h.peer = from.peer
	}
	if h.broadcastAddress == nil {
		h.broadcastAddress = from.broadcastAddress
	}
	if h.listenAddress == nil {
		h.listenAddress = from.listenAddress
	}
	if h.rpcAddress == nil {
		h.rpcAddress = from.rpcAddress
	}
	if h.preferredIP == nil {
		h.preferredIP = from.preferredIP
	}
	if h.connectAddress == nil {
		h.connectAddress = from.connectAddress
	}
	if h.port == 0 {
		h.port = from.port
	}
	if h.dataCenter == "" {
		h.dataCenter = from.dataCenter
	}
	if h.rack == "" {
		h.rack = from.rack
	}
	if h.hostId == "" {
		h.hostId = from.hostId
	}
	if h.workload == "" {
		h.workload = from.workload
	}
	if h.dseVersion == "" {
		h.dseVersion = from.dseVersion
	}
	if h.partitioner == "" {
		h.partitioner = from.partitioner
	}
	if h.clusterName == "" {
		h.clusterName = from.clusterName
	}
	if h.version == (cassVersion{}) {
		h.version = from.version
	}
	if h.tokens == nil {
		h.tokens = from.tokens
	}
}
|
||||
|
||||
// IsUp reports whether the host is known and currently marked up.
// Safe to call on a nil receiver.
func (h *HostInfo) IsUp() bool {
	return h != nil && h.State() == NodeUp
}
|
||||
|
||||
func (h *HostInfo) IsBusy(s *Session) bool {
|
||||
pool, ok := s.pool.getPool(h)
|
||||
return ok && h != nil && pool.InFlight() >= MAX_IN_FLIGHT_THRESHOLD
|
||||
}
|
||||
|
||||
// HostnameAndPort returns "hostname:port" for this host, lazily populating
// the hostname from the connect address the first time it is needed.
func (h *HostInfo) HostnameAndPort() string {
	// Fast path: in most cases hostname is not empty
	var (
		hostname string
		port     int
	)
	h.mu.RLock()
	hostname = h.hostname
	port = h.port
	h.mu.RUnlock()

	if hostname != "" {
		return net.JoinHostPort(hostname, strconv.Itoa(port))
	}

	// Slow path: hostname is empty, take the write lock to fill it in.
	h.mu.Lock()
	defer h.mu.Unlock()
	if h.hostname == "" { // recheck is hostname empty
		// if yes - fill it (double-checked because another goroutine may
		// have filled it between the RUnlock above and this Lock)
		addr, _ := h.connectAddressLocked()
		h.hostname = addr.String()
	}
	return net.JoinHostPort(h.hostname, strconv.Itoa(h.port))
}
|
||||
|
||||
// Hostname returns the host's name, lazily populating it from the connect
// address the first time it is needed.
func (h *HostInfo) Hostname() string {
	// Fast path: in most cases hostname is not empty
	var hostname string
	h.mu.RLock()
	hostname = h.hostname
	h.mu.RUnlock()

	if hostname != "" {
		return hostname
	}

	// Slow path: hostname is empty, take the write lock and fill it in
	// (rechecked, since another goroutine may have beaten us to it).
	h.mu.Lock()
	defer h.mu.Unlock()
	if h.hostname == "" {
		addr, _ := h.connectAddressLocked()
		h.hostname = addr.String()
	}
	return h.hostname
}
|
||||
|
||||
func (h *HostInfo) ConnectAddressAndPort() string {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
addr, _ := h.connectAddressLocked()
|
||||
return net.JoinHostPort(addr.String(), strconv.Itoa(h.port))
|
||||
}
|
||||
|
||||
// String renders a diagnostic description of the host, including which
// source column the connect address was derived from.
func (h *HostInfo) String() string {
	h.mu.RLock()
	defer h.mu.RUnlock()

	connectAddr, source := h.connectAddressLocked()
	return fmt.Sprintf("[HostInfo hostname=%q connectAddress=%q peer=%q rpc_address=%q broadcast_address=%q "+
		"preferred_ip=%q connect_addr=%q connect_addr_source=%q "+
		"port=%d data_centre=%q rack=%q host_id=%q version=%q state=%s num_tokens=%d]",
		h.hostname, h.connectAddress, h.peer, h.rpcAddress, h.broadcastAddress, h.preferredIP,
		connectAddr, source,
		h.port, h.dataCenter, h.rack, h.hostId, h.version, h.state, len(h.tokens))
}
|
||||
|
||||
// setScyllaSupported records the Scylla shard-aware ports advertised in the
// node's SUPPORTED response.
func (h *HostInfo) setScyllaSupported(s scyllaSupported) {
	h.mu.Lock()
	defer h.mu.Unlock()
	h.scyllaShardAwarePort = s.shardAwarePort
	h.scyllaShardAwarePortTLS = s.shardAwarePortSSL
}

// ScyllaShardAwarePort returns the shard aware port of this host.
// Returns zero if the shard aware port is not known.
func (h *HostInfo) ScyllaShardAwarePort() uint16 {
	h.mu.RLock()
	defer h.mu.RUnlock()
	return h.scyllaShardAwarePort
}

// ScyllaShardAwarePortTLS returns the TLS-enabled shard aware port of this host.
// Returns zero if the shard aware port is not known.
func (h *HostInfo) ScyllaShardAwarePortTLS() uint16 {
	h.mu.RLock()
	defer h.mu.RUnlock()
	return h.scyllaShardAwarePortTLS
}
|
||||
|
||||
// Returns true if we are using system_schema.keyspaces instead of system.schema_keyspaces
// (i.e. Cassandra 3.0+ schema tables). A syntax error from the probe query
// is taken to mean the new tables do not exist; any other error is returned.
func checkSystemSchema(control controlConnection) (bool, error) {
	iter := control.query("SELECT * FROM system_schema.keyspaces" + control.getSession().usingTimeoutClause)
	if err := iter.err; err != nil {
		if errf, ok := err.(*errorFrame); ok {
			if errf.code == ErrCodeSyntax {
				return false, nil
			}
		}

		return false, err
	}

	return true, nil
}
|
||||
|
||||
// Given a map that represents a row from either system.local or system.peers
// return as much information as we can in *HostInfo.
//
// Unknown columns are ignored; a column whose value has an unexpected Go
// type produces an error. After the row is consumed, the connect address is
// run through translateAddressPort and both the translated and the original
// (untranslated) address are recorded on the host.
func hostInfoFromMap(row map[string]interface{}, host *HostInfo, translateAddressPort func(addr net.IP, port int) (net.IP, int)) (*HostInfo, error) {
	const assertErrorMsg = "Assertion failed for %s"
	var ok bool

	// Default to our connected port if the cluster doesn't have port information
	for key, value := range row {
		switch key {
		case "data_center":
			host.dataCenter, ok = value.(string)
			if !ok {
				return nil, fmt.Errorf(assertErrorMsg, "data_center")
			}
		case "rack":
			host.rack, ok = value.(string)
			if !ok {
				return nil, fmt.Errorf(assertErrorMsg, "rack")
			}
		case "host_id":
			hostId, ok := value.(UUID)
			if !ok {
				return nil, fmt.Errorf(assertErrorMsg, "host_id")
			}
			host.hostId = hostId.String()
		case "release_version":
			version, ok := value.(string)
			if !ok {
				return nil, fmt.Errorf(assertErrorMsg, "release_version")
			}
			// NOTE(review): the parse error from Set is silently dropped, so
			// a malformed release_version leaves host.version at its zero
			// value — presumably deliberate best-effort; confirm.
			host.version.Set(version)
		case "peer":
			ip, ok := value.(string)
			if !ok {
				return nil, fmt.Errorf(assertErrorMsg, "peer")
			}
			host.peer = net.ParseIP(ip)
		case "cluster_name":
			host.clusterName, ok = value.(string)
			if !ok {
				return nil, fmt.Errorf(assertErrorMsg, "cluster_name")
			}
		case "partitioner":
			host.partitioner, ok = value.(string)
			if !ok {
				return nil, fmt.Errorf(assertErrorMsg, "partitioner")
			}
		case "broadcast_address":
			ip, ok := value.(string)
			if !ok {
				return nil, fmt.Errorf(assertErrorMsg, "broadcast_address")
			}
			host.broadcastAddress = net.ParseIP(ip)
		case "preferred_ip":
			ip, ok := value.(string)
			if !ok {
				return nil, fmt.Errorf(assertErrorMsg, "preferred_ip")
			}
			host.preferredIP = net.ParseIP(ip)
		case "rpc_address":
			ip, ok := value.(string)
			if !ok {
				return nil, fmt.Errorf(assertErrorMsg, "rpc_address")
			}
			host.rpcAddress = net.ParseIP(ip)
		case "native_address":
			// Newer schema name for the same concept as rpc_address.
			ip, ok := value.(string)
			if !ok {
				return nil, fmt.Errorf(assertErrorMsg, "native_address")
			}
			host.rpcAddress = net.ParseIP(ip)
		case "listen_address":
			ip, ok := value.(string)
			if !ok {
				return nil, fmt.Errorf(assertErrorMsg, "listen_address")
			}
			host.listenAddress = net.ParseIP(ip)
		case "native_port":
			native_port, ok := value.(int)
			if !ok {
				return nil, fmt.Errorf(assertErrorMsg, "native_port")
			}
			host.port = native_port
		case "workload":
			host.workload, ok = value.(string)
			if !ok {
				return nil, fmt.Errorf(assertErrorMsg, "workload")
			}
		case "graph":
			host.graph, ok = value.(bool)
			if !ok {
				return nil, fmt.Errorf(assertErrorMsg, "graph")
			}
		case "tokens":
			host.tokens, ok = value.([]string)
			if !ok {
				return nil, fmt.Errorf(assertErrorMsg, "tokens")
			}
		case "dse_version":
			host.dseVersion, ok = value.(string)
			if !ok {
				return nil, fmt.Errorf(assertErrorMsg, "dse_version")
			}
		case "schema_version":
			schemaVersion, ok := value.(UUID)
			if !ok {
				return nil, fmt.Errorf(assertErrorMsg, "schema_version")
			}
			host.schemaVersion = schemaVersion.String()
		}
		// TODO(thrawn01): Add 'port'? once CASSANDRA-7544 is complete
		// Not sure what the port field will be called until the JIRA issue is complete
	}

	// Record the pre-translation address, then apply the translator.
	host.untranslatedConnectAddress = host.ConnectAddress()
	ip, port := translateAddressPort(host.untranslatedConnectAddress, host.port)
	host.connectAddress = ip
	host.port = port

	return host, nil
}
|
||||
|
||||
func hostInfoFromIter(iter *Iter, connectAddress net.IP, defaultPort int, translateAddressPort func(addr net.IP, port int) (net.IP, int)) (*HostInfo, error) {
|
||||
rows, err := iter.SliceMap()
|
||||
if err != nil {
|
||||
// TODO(zariel): make typed error
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(rows) == 0 {
|
||||
return nil, errors.New("query returned 0 rows")
|
||||
}
|
||||
|
||||
host, err := hostInfoFromMap(rows[0], &HostInfo{connectAddress: connectAddress, port: defaultPort}, translateAddressPort)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return host, nil
|
||||
}
|
||||
|
||||
// debounceRingRefresh submits a ring refresh request to the ring refresh debouncer.
func (s *Session) debounceRingRefresh() {
	s.ringRefresher.Debounce()
}

// refreshRingNow executes a ring refresh immediately and cancels pending
// debounced ring refresh requests. The refresher's result channel is closed
// (ok == false) when a stop was requested before the refresh could run.
func (s *Session) refreshRingNow() error {
	err, ok := <-s.ringRefresher.RefreshNow()
	if !ok {
		return errors.New("could not refresh ring because stop was requested")
	}

	return err
}
|
||||
|
||||
// refreshRing reconciles the session's host map against the hosts reported
// by the cluster's system tables: new hosts get pools started, known hosts
// are updated in place, hosts whose IP changed are replaced, and hosts no
// longer reported are removed. Finally the partitioner is (re)applied to
// the host-selection policy.
func (s *Session) refreshRing() error {
	hosts, partitioner, err := s.hostSource.GetHostsFromSystem()
	if err != nil {
		return err
	}
	// Snapshot of currently known hosts; entries still present after the
	// loop correspond to hosts that disappeared from the cluster.
	prevHosts := s.hostSource.getHostsMap()

	for _, h := range hosts {
		if s.cfg.filterHost(h) {
			continue
		}

		if host, ok := s.hostSource.addHostIfMissing(h); !ok {
			s.startPoolFill(h)
		} else {
			// host (by hostID) already exists; determine if IP has changed
			newHostID := h.HostID()
			existing, ok := prevHosts[newHostID]
			if !ok {
				return fmt.Errorf("get existing host=%s from prevHosts: %w", h, ErrCannotFindHost)
			}
			// NOTE(review): h.connectAddress is read here without taking
			// h.mu, unlike the accessor methods — confirm this is safe for
			// freshly built HostInfo values.
			if h.connectAddress.Equal(existing.connectAddress) && h.nodeToNodeAddress().Equal(existing.nodeToNodeAddress()) {
				// no host IP change
				host.update(h)
			} else {
				// host IP has changed
				// remove old HostInfo (w/old IP)
				s.removeHost(existing)
				if _, alreadyExists := s.hostSource.addHostIfMissing(h); alreadyExists {
					return fmt.Errorf("add new host=%s after removal: %w", h, ErrHostAlreadyExists)
				}
				// add new HostInfo (same hostID, new IP)
				s.startPoolFill(h)
			}
		}
		delete(prevHosts, h.HostID())
	}

	// Anything left in prevHosts is gone from the cluster: drop it.
	for _, host := range prevHosts {
		s.metadataDescriber.removeTabletsWithHost(host)
		s.removeHost(host)
	}
	s.policy.SetPartitioner(partitioner)

	return nil
}
|
46
vendor/github.com/gocql/gocql/host_source_gen.go
generated
vendored
Normal file
46
vendor/github.com/gocql/gocql/host_source_gen.go
generated
vendored
Normal file
@@ -0,0 +1,46 @@
|
||||
//go:build genhostinfo
|
||||
// +build genhostinfo
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sync"
|
||||
|
||||
"github.com/gocql/gocql"
|
||||
)
|
||||
|
||||
// gen emits one guarded copy statement for a HostInfo field: the field is
// copied from `from` only while the destination still holds its zero value
// (spelled by clause).
func gen(clause, field string) {
	fmt.Printf("if h.%s == %s {\n\th.%s = from.%s\n}\n", field, clause, field, field)
}
|
||||
|
||||
// main reflects over gocql.HostInfo and prints the body of HostInfo.update:
// one guarded copy per field, keyed by that field type's zero value. The
// mutex field, bools and int32s (state) are intentionally skipped; any
// other field kind is a programming error and panics.
func main() {
	t := reflect.ValueOf(&gocql.HostInfo{}).Elem().Type()
	mu := reflect.TypeOf(sync.RWMutex{})

	for i := 0; i < t.NumField(); i++ {
		f := t.Field(i)
		if f.Type == mu {
			continue
		}

		switch f.Type.Kind() {
		case reflect.Slice:
			gen("nil", f.Name)
		case reflect.String:
			gen(`""`, f.Name)
		case reflect.Int:
			gen("0", f.Name)
		case reflect.Struct:
			gen("("+f.Type.Name()+"{})", f.Name)
		case reflect.Bool, reflect.Int32:
			continue
		default:
			panic(fmt.Sprintf("unknown field: %s", f))
		}
	}

}
|
7
vendor/github.com/gocql/gocql/host_source_scylla.go
generated
vendored
Normal file
7
vendor/github.com/gocql/gocql/host_source_scylla.go
generated
vendored
Normal file
@@ -0,0 +1,7 @@
|
||||
package gocql
|
||||
|
||||
// SetDatacenter overrides the data center recorded for this host.
func (h *HostInfo) SetDatacenter(dc string) {
	h.mu.Lock()
	defer h.mu.Unlock()
	h.dataCenter = dc
}
|
0
vendor/github.com/gocql/gocql/install_test_deps.sh
generated
vendored
Normal file
0
vendor/github.com/gocql/gocql/install_test_deps.sh
generated
vendored
Normal file
54
vendor/github.com/gocql/gocql/integration.sh
generated
vendored
Normal file
54
vendor/github.com/gocql/gocql/integration.sh
generated
vendored
Normal file
@@ -0,0 +1,54 @@
|
||||
#!/bin/bash
#
# Copyright (C) 2017 ScyllaDB
#
# Restarts a Scylla docker-compose cluster and runs the gocql integration
# test suite for the build tags given as arguments. The special tag
# "tablet" selects the multi-node tablet cluster instead.

readonly SCYLLA_IMAGE=${SCYLLA_IMAGE}

set -eu -o pipefail

# Start the cluster; if any service comes up unhealthy, dump its logs and fail.
function scylla_up() {
  local -r exec="docker compose exec -T"

  echo "==> Running Scylla ${SCYLLA_IMAGE}"
  docker pull ${SCYLLA_IMAGE}
  docker compose up -d --wait || ( docker compose ps --format json | jq -M 'select(.Health == "unhealthy") | .Service' | xargs docker compose logs; exit 1 )
}

# Tear the cluster down.
function scylla_down() {
  echo "==> Stopping Scylla"
  docker compose down
}

# Restart from a clean slate.
function scylla_restart() {
  scylla_down
  scylla_up
}

scylla_restart

# Let the tests reach the CQL socket without extra privileges.
sudo chmod 0777 /tmp/scylla/cql.m

readonly clusterSize=1
readonly multiNodeClusterSize=3
readonly scylla_liveset="192.168.100.11"
readonly scylla_tablet_liveset="192.168.100.12"
readonly cversion="3.11.4"
readonly proto=4
readonly args="-gocql.timeout=60s -proto=${proto} -rf=${clusterSize} -clusterSize=${clusterSize} -autowait=2000ms -compressor=snappy -gocql.cversion=${cversion} -cluster=${scylla_liveset}"
readonly tabletArgs="-gocql.timeout=60s -proto=${proto} -rf=1 -clusterSize=${multiNodeClusterSize} -autowait=2000ms -compressor=snappy -gocql.cversion=${cversion} -multiCluster=${scylla_tablet_liveset}"

# Run the tablet suite when requested.
if [[ "$*" == *"tablet"* ]];
then
  echo "==> Running tablet tests with args: ${tabletArgs}"
  go test -timeout=5m -race -tags="tablet" ${tabletArgs} ./...
fi

# Remaining tags (with "tablet" stripped) run against the single-node cluster.
TAGS=$*
TAGS=${TAGS//"tablet"/}

if [ ! -z "$TAGS" ];
then
  echo "==> Running ${TAGS} tests with args: ${args}"
  go test -timeout=5m -race -tags="$TAGS" ${args} ./...
fi
|
127
vendor/github.com/gocql/gocql/internal/lru/lru.go
generated
vendored
Normal file
127
vendor/github.com/gocql/gocql/internal/lru/lru.go
generated
vendored
Normal file
@@ -0,0 +1,127 @@
|
||||
/*
|
||||
Copyright 2015 To gocql authors
|
||||
Copyright 2013 Google Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
*/
|
||||
|
||||
// Package lru implements an LRU cache.
|
||||
package lru
|
||||
|
||||
import "container/list"
|
||||
|
||||
// Cache is an LRU cache. It is not safe for concurrent access.
|
||||
//
|
||||
// This cache has been forked from github.com/golang/groupcache/lru, but
|
||||
// specialized with string keys to avoid the allocations caused by wrapping them
|
||||
// in interface{}.
|
||||
// Cache is an LRU cache. It is not safe for concurrent access.
//
// Forked from github.com/golang/groupcache/lru, but specialized to string
// keys so lookups avoid boxing keys in interface{}.
type Cache struct {
	// MaxEntries is the maximum number of cache entries before
	// an item is evicted. Zero means no limit.
	MaxEntries int

	// OnEvicted optionally specifies a callback function to be
	// executed when an entry is purged from the cache.
	OnEvicted func(key string, value interface{})

	ll    *list.List
	cache map[string]*list.Element
}

// entry is what each list element stores; the key is kept alongside the
// value so the map entry can be deleted when the element is evicted.
type entry struct {
	key   string
	value interface{}
}

// New creates a new Cache.
// If maxEntries is zero, the cache has no limit and it's assumed
// that eviction is done by the caller.
func New(maxEntries int) *Cache {
	c := &Cache{MaxEntries: maxEntries}
	c.ll = list.New()
	c.cache = make(map[string]*list.Element)
	return c
}

// Add inserts or updates key, marking it most recently used, and evicts the
// least recently used entry once the cache exceeds MaxEntries.
func (c *Cache) Add(key string, value interface{}) {
	if c.cache == nil {
		// Lazily initialize so the zero-value Cache is usable.
		c.cache = make(map[string]*list.Element)
		c.ll = list.New()
	}
	if existing, ok := c.cache[key]; ok {
		c.ll.MoveToFront(existing)
		existing.Value.(*entry).value = value
		return
	}
	c.cache[key] = c.ll.PushFront(&entry{key, value})
	if c.MaxEntries != 0 && c.ll.Len() > c.MaxEntries {
		c.RemoveOldest()
	}
}

// Get looks up a key's value, marking it most recently used on a hit.
func (c *Cache) Get(key string) (value interface{}, ok bool) {
	if c.cache == nil {
		return nil, false
	}
	el, found := c.cache[key]
	if !found {
		return nil, false
	}
	c.ll.MoveToFront(el)
	return el.Value.(*entry).value, true
}

// Remove deletes key from the cache, reporting whether it was present.
func (c *Cache) Remove(key string) bool {
	if c.cache == nil {
		return false
	}

	el, found := c.cache[key]
	if !found {
		return false
	}
	c.removeElement(el)
	return true
}

// RemoveOldest evicts the least recently used entry, if any.
func (c *Cache) RemoveOldest() {
	if c.cache == nil {
		return
	}
	if el := c.ll.Back(); el != nil {
		c.removeElement(el)
	}
}

// removeElement unlinks e from both the list and the map, then fires the
// OnEvicted callback when one is set.
func (c *Cache) removeElement(e *list.Element) {
	c.ll.Remove(e)
	kv := e.Value.(*entry)
	delete(c.cache, kv.key)
	if c.OnEvicted != nil {
		c.OnEvicted(kv.key, kv.value)
	}
}

// Len returns the number of items in the cache.
func (c *Cache) Len() int {
	if c.cache == nil {
		return 0
	}
	return c.ll.Len()
}
|
135
vendor/github.com/gocql/gocql/internal/murmur/murmur.go
generated
vendored
Normal file
135
vendor/github.com/gocql/gocql/internal/murmur/murmur.go
generated
vendored
Normal file
@@ -0,0 +1,135 @@
|
||||
package murmur
|
||||
|
||||
// Murmur3 x64-128 constants, expressed as signed 64-bit values to mirror
// the Cassandra (Java) reference implementation.
const (
	c1    int64 = -8663945395140668459 // 0x87c37b91114253d5
	c2    int64 = 5545529020109919103  // 0x4cf5ad432745937f
	fmix1 int64 = -49064778989728563   // 0xff51afd7ed558ccd
	fmix2 int64 = -4265267296055464877 // 0xc4ceb9fe1a85ec53
)

// fmix is the Murmur3 finalization mix, avalanching the bits of x. Shifts
// go through uint64 to get a logical right shift (Java's >>> operator).
func fmix(x int64) int64 {
	x ^= int64(uint64(x) >> 33)
	x *= fmix1
	x ^= int64(uint64(x) >> 33)
	x *= fmix2
	x ^= int64(uint64(x) >> 33)

	return x
}
||||
|
||||
// block sign-extends a single byte to int64; the reference (Java) Murmur3
// implementation treats bytes as signed, and this mirrors that.
func block(v byte) int64 {
	return int64(int8(v))
}

// rotl rotates x left by r bits. The right shift is performed on uint64 so
// the vacated high bits fill logically, matching the C*/Java behavior.
func rotl(x int64, r uint8) int64 {
	return (x << r) | (int64)((uint64(x) >> (64 - r)))
}
|
||||
|
||||
// Murmur3H1 returns the first 64 bits (h1) of the 128-bit x64 Murmur3 hash
// of data, which is what Cassandra uses for token computation. h2 is fully
// computed but discarded at the end.
func Murmur3H1(data []byte) int64 {
	length := len(data)

	var h1, h2, k1, k2 int64

	// body: mix the data sixteen bytes (two little-endian int64 lanes) at a time
	nBlocks := length / 16
	for i := 0; i < nBlocks; i++ {
		k1, k2 = getBlock(data, i)

		k1 *= c1
		k1 = rotl(k1, 31)
		k1 *= c2
		h1 ^= k1

		h1 = rotl(h1, 27)
		h1 += h2
		h1 = h1*5 + 0x52dce729

		k2 *= c2
		k2 = rotl(k2, 33)
		k2 *= c1
		h2 ^= k2

		h2 = rotl(h2, 31)
		h2 += h1
		h2 = h2*5 + 0x38495ab5
	}

	// tail: fold in the remaining 0-15 bytes; the fallthrough chain builds
	// the two lanes byte by byte, high byte first
	tail := data[nBlocks*16:]
	k1 = 0
	k2 = 0
	switch length & 15 {
	case 15:
		k2 ^= block(tail[14]) << 48
		fallthrough
	case 14:
		k2 ^= block(tail[13]) << 40
		fallthrough
	case 13:
		k2 ^= block(tail[12]) << 32
		fallthrough
	case 12:
		k2 ^= block(tail[11]) << 24
		fallthrough
	case 11:
		k2 ^= block(tail[10]) << 16
		fallthrough
	case 10:
		k2 ^= block(tail[9]) << 8
		fallthrough
	case 9:
		k2 ^= block(tail[8])

		k2 *= c2
		k2 = rotl(k2, 33)
		k2 *= c1
		h2 ^= k2

		fallthrough
	case 8:
		k1 ^= block(tail[7]) << 56
		fallthrough
	case 7:
		k1 ^= block(tail[6]) << 48
		fallthrough
	case 6:
		k1 ^= block(tail[5]) << 40
		fallthrough
	case 5:
		k1 ^= block(tail[4]) << 32
		fallthrough
	case 4:
		k1 ^= block(tail[3]) << 24
		fallthrough
	case 3:
		k1 ^= block(tail[2]) << 16
		fallthrough
	case 2:
		k1 ^= block(tail[1]) << 8
		fallthrough
	case 1:
		k1 ^= block(tail[0])

		k1 *= c1
		k1 = rotl(k1, 31)
		k1 *= c2
		h1 ^= k1
	}

	// finalization: fold in the length and force-avalanche both halves
	h1 ^= int64(length)
	h2 ^= int64(length)

	h1 += h2
	h2 += h1

	h1 = fmix(h1)
	h2 = fmix(h2)

	h1 += h2
	// the following is extraneous since h2 is discarded
	// h2 += h1

	return h1
}
|
12
vendor/github.com/gocql/gocql/internal/murmur/murmur_appengine.go
generated
vendored
Normal file
12
vendor/github.com/gocql/gocql/internal/murmur/murmur_appengine.go
generated
vendored
Normal file
@@ -0,0 +1,12 @@
|
||||
//go:build appengine || s390x
|
||||
// +build appengine s390x
|
||||
|
||||
package murmur
|
||||
|
||||
import "encoding/binary"
|
||||
|
||||
func getBlock(data []byte, n int) (int64, int64) {
|
||||
k1 := int64(binary.LittleEndian.Uint64(data[n*16:]))
|
||||
k2 := int64(binary.LittleEndian.Uint64(data[(n*16)+8:]))
|
||||
return k1, k2
|
||||
}
|
16
vendor/github.com/gocql/gocql/internal/murmur/murmur_unsafe.go
generated
vendored
Normal file
16
vendor/github.com/gocql/gocql/internal/murmur/murmur_unsafe.go
generated
vendored
Normal file
@@ -0,0 +1,16 @@
|
||||
//go:build !appengine && !s390x
|
||||
// +build !appengine,!s390x
|
||||
|
||||
package murmur
|
||||
|
||||
import (
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
func getBlock(data []byte, n int) (int64, int64) {
|
||||
block := (*[2]int64)(unsafe.Pointer(&data[n*16]))
|
||||
|
||||
k1 := block[0]
|
||||
k2 := block[1]
|
||||
return k1, k2
|
||||
}
|
146
vendor/github.com/gocql/gocql/internal/streams/streams.go
generated
vendored
Normal file
146
vendor/github.com/gocql/gocql/internal/streams/streams.go
generated
vendored
Normal file
@@ -0,0 +1,146 @@
|
||||
package streams
|
||||
|
||||
import (
|
||||
"math"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
const bucketBits = 64
|
||||
|
||||
// IDGenerator tracks and allocates streams which are in use.
|
||||
type IDGenerator struct {
|
||||
NumStreams int
|
||||
inuseStreams int32
|
||||
numBuckets uint32
|
||||
|
||||
// streams is a bitset where each bit represents a stream, a 1 implies in use
|
||||
streams []uint64
|
||||
offset uint32
|
||||
}
|
||||
|
||||
func New(protocol int) *IDGenerator {
|
||||
maxStreams := 128
|
||||
if protocol > 2 {
|
||||
maxStreams = 32768
|
||||
}
|
||||
return NewLimited(maxStreams)
|
||||
}
|
||||
|
||||
func NewLimited(maxStreams int) *IDGenerator {
|
||||
// Round up maxStreams to a nearest
|
||||
// multiple of 64
|
||||
maxStreams = ((maxStreams + 63) / 64) * 64
|
||||
|
||||
buckets := maxStreams / 64
|
||||
// reserve stream 0
|
||||
streams := make([]uint64, buckets)
|
||||
streams[0] = 1 << 63
|
||||
|
||||
return &IDGenerator{
|
||||
NumStreams: maxStreams,
|
||||
streams: streams,
|
||||
numBuckets: uint32(buckets),
|
||||
offset: uint32(buckets) - 1,
|
||||
}
|
||||
}
|
||||
|
||||
func streamFromBucket(bucket, streamInBucket int) int {
|
||||
return (bucket * bucketBits) + streamInBucket
|
||||
}
|
||||
|
||||
func (s *IDGenerator) GetStream() (int, bool) {
|
||||
// Reduce collisions by offsetting the starting point
|
||||
offset := atomic.AddUint32(&s.offset, 1)
|
||||
|
||||
for i := uint32(0); i < s.numBuckets; i++ {
|
||||
pos := int((i + offset) % s.numBuckets)
|
||||
|
||||
bucket := atomic.LoadUint64(&s.streams[pos])
|
||||
if bucket == math.MaxUint64 {
|
||||
// all streams in use
|
||||
continue
|
||||
}
|
||||
|
||||
for j := 0; j < bucketBits; j++ {
|
||||
mask := uint64(1 << streamOffset(j))
|
||||
for bucket&mask == 0 {
|
||||
if atomic.CompareAndSwapUint64(&s.streams[pos], bucket, bucket|mask) {
|
||||
atomic.AddInt32(&s.inuseStreams, 1)
|
||||
return streamFromBucket(int(pos), j), true
|
||||
}
|
||||
bucket = atomic.LoadUint64(&s.streams[pos])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func bitfmt(b uint64) string {
|
||||
return strconv.FormatUint(b, 16)
|
||||
}
|
||||
|
||||
// returns the bucket offset of a given stream
|
||||
func bucketOffset(i int) int {
|
||||
return i / bucketBits
|
||||
}
|
||||
|
||||
func streamOffset(stream int) uint64 {
|
||||
return bucketBits - uint64(stream%bucketBits) - 1
|
||||
}
|
||||
|
||||
func isSet(bits uint64, stream int) bool {
|
||||
return bits>>streamOffset(stream)&1 == 1
|
||||
}
|
||||
|
||||
func (s *IDGenerator) isSet(stream int) bool {
|
||||
bits := atomic.LoadUint64(&s.streams[bucketOffset(stream)])
|
||||
return isSet(bits, stream)
|
||||
}
|
||||
|
||||
func (s *IDGenerator) String() string {
|
||||
size := s.numBuckets * (bucketBits + 1)
|
||||
buf := make([]byte, 0, size)
|
||||
for i := 0; i < int(s.numBuckets); i++ {
|
||||
bits := atomic.LoadUint64(&s.streams[i])
|
||||
buf = append(buf, bitfmt(bits)...)
|
||||
buf = append(buf, ' ')
|
||||
}
|
||||
return string(buf[: size-1 : size-1])
|
||||
}
|
||||
|
||||
func (s *IDGenerator) Clear(stream int) (inuse bool) {
|
||||
offset := bucketOffset(stream)
|
||||
bucket := atomic.LoadUint64(&s.streams[offset])
|
||||
|
||||
mask := uint64(1) << streamOffset(stream)
|
||||
if bucket&mask != mask {
|
||||
// already cleared
|
||||
return false
|
||||
}
|
||||
|
||||
for !atomic.CompareAndSwapUint64(&s.streams[offset], bucket, bucket & ^mask) {
|
||||
bucket = atomic.LoadUint64(&s.streams[offset])
|
||||
if bucket&mask != mask {
|
||||
// already cleared
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: make this account for 0 stream being reserved
|
||||
if atomic.AddInt32(&s.inuseStreams, -1) < 0 {
|
||||
// TODO(zariel): remove this
|
||||
panic("negative streams inuse")
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *IDGenerator) Available() int {
|
||||
return s.NumStreams - int(atomic.LoadInt32(&s.inuseStreams)) - 1
|
||||
}
|
||||
|
||||
func (s *IDGenerator) InUse() int {
|
||||
return int(atomic.LoadInt32(&s.inuseStreams))
|
||||
}
|
40
vendor/github.com/gocql/gocql/logger.go
generated
vendored
Normal file
40
vendor/github.com/gocql/gocql/logger.go
generated
vendored
Normal file
@@ -0,0 +1,40 @@
|
||||
package gocql
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"log"
|
||||
)
|
||||
|
||||
type StdLogger interface {
|
||||
Print(v ...interface{})
|
||||
Printf(format string, v ...interface{})
|
||||
Println(v ...interface{})
|
||||
}
|
||||
|
||||
type nopLogger struct{}
|
||||
|
||||
func (n nopLogger) Print(_ ...interface{}) {}
|
||||
|
||||
func (n nopLogger) Printf(_ string, _ ...interface{}) {}
|
||||
|
||||
func (n nopLogger) Println(_ ...interface{}) {}
|
||||
|
||||
type testLogger struct {
|
||||
capture bytes.Buffer
|
||||
}
|
||||
|
||||
func (l *testLogger) Print(v ...interface{}) { fmt.Fprint(&l.capture, v...) }
|
||||
func (l *testLogger) Printf(format string, v ...interface{}) { fmt.Fprintf(&l.capture, format, v...) }
|
||||
func (l *testLogger) Println(v ...interface{}) { fmt.Fprintln(&l.capture, v...) }
|
||||
func (l *testLogger) String() string { return l.capture.String() }
|
||||
|
||||
type defaultLogger struct{}
|
||||
|
||||
func (l *defaultLogger) Print(v ...interface{}) { log.Print(v...) }
|
||||
func (l *defaultLogger) Printf(format string, v ...interface{}) { log.Printf(format, v...) }
|
||||
func (l *defaultLogger) Println(v ...interface{}) { log.Println(v...) }
|
||||
|
||||
// Logger for logging messages.
|
||||
// Deprecated: Use ClusterConfig.Logger instead.
|
||||
var Logger StdLogger = &defaultLogger{}
|
1846
vendor/github.com/gocql/gocql/marshal.go
generated
vendored
Normal file
1846
vendor/github.com/gocql/gocql/marshal.go
generated
vendored
Normal file
File diff suppressed because it is too large
Load Diff
1466
vendor/github.com/gocql/gocql/metadata_cassandra.go
generated
vendored
Normal file
1466
vendor/github.com/gocql/gocql/metadata_cassandra.go
generated
vendored
Normal file
File diff suppressed because it is too large
Load Diff
1102
vendor/github.com/gocql/gocql/metadata_scylla.go
generated
vendored
Normal file
1102
vendor/github.com/gocql/gocql/metadata_scylla.go
generated
vendored
Normal file
File diff suppressed because it is too large
Load Diff
1326
vendor/github.com/gocql/gocql/policies.go
generated
vendored
Normal file
1326
vendor/github.com/gocql/gocql/policies.go
generated
vendored
Normal file
File diff suppressed because it is too large
Load Diff
78
vendor/github.com/gocql/gocql/prepared_cache.go
generated
vendored
Normal file
78
vendor/github.com/gocql/gocql/prepared_cache.go
generated
vendored
Normal file
@@ -0,0 +1,78 @@
|
||||
package gocql
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"sync"
|
||||
|
||||
"github.com/gocql/gocql/internal/lru"
|
||||
)
|
||||
|
||||
const defaultMaxPreparedStmts = 1000
|
||||
|
||||
// preparedLRU is the prepared statement cache
|
||||
type preparedLRU struct {
|
||||
mu sync.Mutex
|
||||
lru *lru.Cache
|
||||
}
|
||||
|
||||
func (p *preparedLRU) clear() {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
for p.lru.Len() > 0 {
|
||||
p.lru.RemoveOldest()
|
||||
}
|
||||
}
|
||||
|
||||
func (p *preparedLRU) add(key string, val *inflightPrepare) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.lru.Add(key, val)
|
||||
}
|
||||
|
||||
func (p *preparedLRU) remove(key string) bool {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
return p.lru.Remove(key)
|
||||
}
|
||||
|
||||
func (p *preparedLRU) execIfMissing(key string, fn func(lru *lru.Cache) *inflightPrepare) (*inflightPrepare, bool) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
val, ok := p.lru.Get(key)
|
||||
if ok {
|
||||
return val.(*inflightPrepare), true
|
||||
}
|
||||
|
||||
return fn(p.lru), false
|
||||
}
|
||||
|
||||
func (p *preparedLRU) keyFor(hostID, keyspace, statement string) string {
|
||||
// TODO: we should just use a struct for the key in the map
|
||||
return hostID + keyspace + statement
|
||||
}
|
||||
|
||||
func (p *preparedLRU) evictPreparedID(key string, id []byte) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
val, ok := p.lru.Get(key)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
ifp, ok := val.(*inflightPrepare)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ifp.done:
|
||||
if bytes.Equal(id, ifp.preparedStatment.id) {
|
||||
p.lru.Remove(key)
|
||||
}
|
||||
default:
|
||||
}
|
||||
|
||||
}
|
238
vendor/github.com/gocql/gocql/query_executor.go
generated
vendored
Normal file
238
vendor/github.com/gocql/gocql/query_executor.go
generated
vendored
Normal file
@@ -0,0 +1,238 @@
|
||||
package gocql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ExecutableQuery interface {
|
||||
borrowForExecution() // Used to ensure that the query stays alive for lifetime of a particular execution goroutine.
|
||||
releaseAfterExecution() // Used when a goroutine finishes its execution attempts, either with ok result or an error.
|
||||
execute(ctx context.Context, conn *Conn) *Iter
|
||||
attempt(keyspace string, end, start time.Time, iter *Iter, host *HostInfo)
|
||||
retryPolicy() RetryPolicy
|
||||
speculativeExecutionPolicy() SpeculativeExecutionPolicy
|
||||
GetRoutingKey() ([]byte, error)
|
||||
Keyspace() string
|
||||
Table() string
|
||||
IsIdempotent() bool
|
||||
IsLWT() bool
|
||||
GetCustomPartitioner() Partitioner
|
||||
|
||||
withContext(context.Context) ExecutableQuery
|
||||
|
||||
RetryableQuery
|
||||
|
||||
GetSession() *Session
|
||||
}
|
||||
|
||||
type queryExecutor struct {
|
||||
pool *policyConnPool
|
||||
policy HostSelectionPolicy
|
||||
}
|
||||
|
||||
func (q *queryExecutor) attemptQuery(ctx context.Context, qry ExecutableQuery, conn *Conn) *Iter {
|
||||
start := time.Now()
|
||||
iter := qry.execute(ctx, conn)
|
||||
end := time.Now()
|
||||
|
||||
qry.attempt(q.pool.keyspace, end, start, iter, conn.host)
|
||||
|
||||
return iter
|
||||
}
|
||||
|
||||
func (q *queryExecutor) speculate(ctx context.Context, qry ExecutableQuery, sp SpeculativeExecutionPolicy,
|
||||
hostIter NextHost, results chan *Iter) *Iter {
|
||||
ticker := time.NewTicker(sp.Delay())
|
||||
defer ticker.Stop()
|
||||
|
||||
for i := 0; i < sp.Attempts(); i++ {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
qry.borrowForExecution() // ensure liveness in case of executing Query to prevent races with Query.Release().
|
||||
go q.run(ctx, qry, hostIter, results)
|
||||
case <-ctx.Done():
|
||||
return &Iter{err: ctx.Err()}
|
||||
case iter := <-results:
|
||||
return iter
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) {
|
||||
hostIter := q.policy.Pick(qry)
|
||||
|
||||
// check if the query is not marked as idempotent, if
|
||||
// it is, we force the policy to NonSpeculative
|
||||
sp := qry.speculativeExecutionPolicy()
|
||||
if !qry.IsIdempotent() || sp.Attempts() == 0 {
|
||||
return q.do(qry.Context(), qry, hostIter), nil
|
||||
}
|
||||
|
||||
// When speculative execution is enabled, we could be accessing the host iterator from multiple goroutines below.
|
||||
// To ensure we don't call it concurrently, we wrap the returned NextHost function here to synchronize access to it.
|
||||
var mu sync.Mutex
|
||||
origHostIter := hostIter
|
||||
hostIter = func() SelectedHost {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return origHostIter()
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(qry.Context())
|
||||
defer cancel()
|
||||
|
||||
results := make(chan *Iter, 1)
|
||||
|
||||
// Launch the main execution
|
||||
qry.borrowForExecution() // ensure liveness in case of executing Query to prevent races with Query.Release().
|
||||
go q.run(ctx, qry, hostIter, results)
|
||||
|
||||
// The speculative executions are launched _in addition_ to the main
|
||||
// execution, on a timer. So Speculation{2} would make 3 executions running
|
||||
// in total.
|
||||
if iter := q.speculate(ctx, qry, sp, hostIter, results); iter != nil {
|
||||
return iter, nil
|
||||
}
|
||||
|
||||
select {
|
||||
case iter := <-results:
|
||||
return iter, nil
|
||||
case <-ctx.Done():
|
||||
return &Iter{err: ctx.Err()}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (q *queryExecutor) do(ctx context.Context, qry ExecutableQuery, hostIter NextHost) *Iter {
|
||||
rt := qry.retryPolicy()
|
||||
if rt == nil {
|
||||
rt = &SimpleRetryPolicy{3}
|
||||
}
|
||||
|
||||
lwtRT, isRTSupportsLWT := rt.(LWTRetryPolicy)
|
||||
|
||||
var getShouldRetry func(qry RetryableQuery) bool
|
||||
var getRetryType func(error) RetryType
|
||||
|
||||
if isRTSupportsLWT && qry.IsLWT() {
|
||||
getShouldRetry = lwtRT.AttemptLWT
|
||||
getRetryType = lwtRT.GetRetryTypeLWT
|
||||
} else {
|
||||
getShouldRetry = rt.Attempt
|
||||
getRetryType = rt.GetRetryType
|
||||
}
|
||||
|
||||
var potentiallyExecuted bool
|
||||
|
||||
execute := func(qry ExecutableQuery, selectedHost SelectedHost) (iter *Iter, retry RetryType) {
|
||||
host := selectedHost.Info()
|
||||
if host == nil || !host.IsUp() {
|
||||
return &Iter{
|
||||
err: &QueryError{
|
||||
err: ErrHostDown,
|
||||
potentiallyExecuted: potentiallyExecuted,
|
||||
},
|
||||
}, RetryNextHost
|
||||
}
|
||||
pool, ok := q.pool.getPool(host)
|
||||
if !ok {
|
||||
return &Iter{
|
||||
err: &QueryError{
|
||||
err: ErrNoPool,
|
||||
potentiallyExecuted: potentiallyExecuted,
|
||||
},
|
||||
}, RetryNextHost
|
||||
}
|
||||
conn := pool.Pick(selectedHost.Token(), qry)
|
||||
if conn == nil {
|
||||
return &Iter{
|
||||
err: &QueryError{
|
||||
err: ErrNoConnectionsInPool,
|
||||
potentiallyExecuted: potentiallyExecuted,
|
||||
},
|
||||
}, RetryNextHost
|
||||
}
|
||||
iter = q.attemptQuery(ctx, qry, conn)
|
||||
iter.host = selectedHost.Info()
|
||||
// Update host
|
||||
if iter.err == nil {
|
||||
return iter, RetryType(255)
|
||||
}
|
||||
|
||||
switch {
|
||||
case errors.Is(iter.err, context.Canceled),
|
||||
errors.Is(iter.err, context.DeadlineExceeded):
|
||||
selectedHost.Mark(nil)
|
||||
potentiallyExecuted = true
|
||||
retry = Rethrow
|
||||
default:
|
||||
selectedHost.Mark(iter.err)
|
||||
retry = RetryType(255) // Don't enforce retry and get it from retry policy
|
||||
}
|
||||
|
||||
var qErr *QueryError
|
||||
if errors.As(iter.err, &qErr) {
|
||||
potentiallyExecuted = potentiallyExecuted && qErr.PotentiallyExecuted()
|
||||
qErr.potentiallyExecuted = potentiallyExecuted
|
||||
qErr.isIdempotent = qry.IsIdempotent()
|
||||
iter.err = qErr
|
||||
} else {
|
||||
iter.err = &QueryError{
|
||||
err: iter.err,
|
||||
potentiallyExecuted: potentiallyExecuted,
|
||||
isIdempotent: qry.IsIdempotent(),
|
||||
}
|
||||
}
|
||||
return iter, retry
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
selectedHost := hostIter()
|
||||
for selectedHost != nil {
|
||||
iter, retryType := execute(qry, selectedHost)
|
||||
if iter.err == nil {
|
||||
return iter
|
||||
}
|
||||
lastErr = iter.err
|
||||
|
||||
// Exit if retry policy decides to not retry anymore
|
||||
if retryType == RetryType(255) {
|
||||
if !getShouldRetry(qry) {
|
||||
return iter
|
||||
}
|
||||
retryType = getRetryType(iter.err)
|
||||
}
|
||||
|
||||
// If query is unsuccessful, check the error with RetryPolicy to retry
|
||||
switch retryType {
|
||||
case Retry:
|
||||
// retry on the same host
|
||||
continue
|
||||
case Rethrow, Ignore:
|
||||
return iter
|
||||
case RetryNextHost:
|
||||
// retry on the next host
|
||||
selectedHost = hostIter()
|
||||
continue
|
||||
default:
|
||||
// Undefined? Return nil and error, this will panic in the requester
|
||||
return &Iter{err: ErrUnknownRetryType}
|
||||
}
|
||||
}
|
||||
if lastErr != nil {
|
||||
return &Iter{err: lastErr}
|
||||
}
|
||||
return &Iter{err: ErrNoConnections}
|
||||
}
|
||||
|
||||
func (q *queryExecutor) run(ctx context.Context, qry ExecutableQuery, hostIter NextHost, results chan<- *Iter) {
|
||||
select {
|
||||
case results <- q.do(ctx, qry, hostIter):
|
||||
case <-ctx.Done():
|
||||
}
|
||||
qry.releaseAfterExecution()
|
||||
}
|
544
vendor/github.com/gocql/gocql/recreate.go
generated
vendored
Normal file
544
vendor/github.com/gocql/gocql/recreate.go
generated
vendored
Normal file
@@ -0,0 +1,544 @@
|
||||
//go:build !cassandra
|
||||
// +build !cassandra
|
||||
|
||||
// Copyright (C) 2017 ScyllaDB
|
||||
|
||||
package gocql
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"text/template"
|
||||
)
|
||||
|
||||
// ToCQL returns a CQL query that ca be used to recreate keyspace with all
|
||||
// user defined types, tables, indexes, functions, aggregates and views associated
|
||||
// with this keyspace.
|
||||
func (km *KeyspaceMetadata) ToCQL() (string, error) {
|
||||
// Be aware that `CreateStmts` is not only a cache for ToCQL,
|
||||
// but it also can be populated from response to `DESCRIBE KEYSPACE %s WITH INTERNALS`
|
||||
if len(km.CreateStmts) != 0 {
|
||||
return km.CreateStmts, nil
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
|
||||
if err := km.keyspaceToCQL(&sb); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
sortedTypes := km.typesSortedTopologically()
|
||||
for _, tm := range sortedTypes {
|
||||
if err := km.userTypeToCQL(&sb, tm); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
for _, tm := range km.Tables {
|
||||
if err := km.tableToCQL(&sb, km.Name, tm); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
for _, im := range km.Indexes {
|
||||
if err := km.indexToCQL(&sb, im); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
for _, fm := range km.Functions {
|
||||
if err := km.functionToCQL(&sb, km.Name, fm); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
for _, am := range km.Aggregates {
|
||||
if err := km.aggregateToCQL(&sb, am); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
for _, vm := range km.Views {
|
||||
if err := km.viewToCQL(&sb, vm); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
km.CreateStmts = sb.String()
|
||||
return km.CreateStmts, nil
|
||||
}
|
||||
|
||||
func (km *KeyspaceMetadata) typesSortedTopologically() []*TypeMetadata {
|
||||
sortedTypes := make([]*TypeMetadata, 0, len(km.Types))
|
||||
for _, tm := range km.Types {
|
||||
sortedTypes = append(sortedTypes, tm)
|
||||
}
|
||||
sort.Slice(sortedTypes, func(i, j int) bool {
|
||||
for _, ft := range sortedTypes[j].FieldTypes {
|
||||
if strings.Contains(ft, sortedTypes[i].Name) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
})
|
||||
return sortedTypes
|
||||
}
|
||||
|
||||
var tableCQLTemplate = template.Must(template.New("table").
|
||||
Funcs(map[string]interface{}{
|
||||
"escape": cqlHelpers.escape,
|
||||
"tableColumnToCQL": cqlHelpers.tableColumnToCQL,
|
||||
"tablePropertiesToCQL": cqlHelpers.tablePropertiesToCQL,
|
||||
}).
|
||||
Parse(`
|
||||
CREATE TABLE {{ .KeyspaceName }}.{{ .Tm.Name }} (
|
||||
{{ tableColumnToCQL .Tm }}
|
||||
) WITH {{ tablePropertiesToCQL .Tm.ClusteringColumns .Tm.Options .Tm.Flags .Tm.Extensions }};
|
||||
`))
|
||||
|
||||
func (km *KeyspaceMetadata) tableToCQL(w io.Writer, kn string, tm *TableMetadata) error {
|
||||
if err := tableCQLTemplate.Execute(w, map[string]interface{}{
|
||||
"Tm": tm,
|
||||
"KeyspaceName": kn,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var functionTemplate = template.Must(template.New("functions").
|
||||
Funcs(map[string]interface{}{
|
||||
"escape": cqlHelpers.escape,
|
||||
"zip": cqlHelpers.zip,
|
||||
"stripFrozen": cqlHelpers.stripFrozen,
|
||||
}).
|
||||
Parse(`
|
||||
CREATE FUNCTION {{ .keyspaceName }}.{{ .fm.Name }} (
|
||||
{{- range $i, $args := zip .fm.ArgumentNames .fm.ArgumentTypes }}
|
||||
{{- if ne $i 0 }}, {{ end }}
|
||||
{{- (index $args 0) }}
|
||||
{{ stripFrozen (index $args 1) }}
|
||||
{{- end -}})
|
||||
{{ if .fm.CalledOnNullInput }}CALLED{{ else }}RETURNS NULL{{ end }} ON NULL INPUT
|
||||
RETURNS {{ .fm.ReturnType }}
|
||||
LANGUAGE {{ .fm.Language }}
|
||||
AS $${{ .fm.Body }}$$;
|
||||
`))
|
||||
|
||||
func (km *KeyspaceMetadata) functionToCQL(w io.Writer, keyspaceName string, fm *FunctionMetadata) error {
|
||||
if err := functionTemplate.Execute(w, map[string]interface{}{
|
||||
"fm": fm,
|
||||
"keyspaceName": keyspaceName,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var viewTemplate = template.Must(template.New("views").
|
||||
Funcs(map[string]interface{}{
|
||||
"zip": cqlHelpers.zip,
|
||||
"partitionKeyString": cqlHelpers.partitionKeyString,
|
||||
"tablePropertiesToCQL": cqlHelpers.tablePropertiesToCQL,
|
||||
}).
|
||||
Parse(`
|
||||
CREATE MATERIALIZED VIEW {{ .vm.KeyspaceName }}.{{ .vm.ViewName }} AS
|
||||
SELECT {{ if .vm.IncludeAllColumns }}*{{ else }}
|
||||
{{- range $i, $col := .vm.OrderedColumns }}
|
||||
{{- if ne $i 0 }}, {{ end }}
|
||||
{{ $col }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
FROM {{ .vm.KeyspaceName }}.{{ .vm.BaseTableName }}
|
||||
WHERE {{ .vm.WhereClause }}
|
||||
PRIMARY KEY ({{ partitionKeyString .vm.PartitionKey .vm.ClusteringColumns }})
|
||||
WITH {{ tablePropertiesToCQL .vm.ClusteringColumns .vm.Options .flags .vm.Extensions }};
|
||||
`))
|
||||
|
||||
func (km *KeyspaceMetadata) viewToCQL(w io.Writer, vm *ViewMetadata) error {
|
||||
if err := viewTemplate.Execute(w, map[string]interface{}{
|
||||
"vm": vm,
|
||||
"flags": []string{},
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var aggregatesTemplate = template.Must(template.New("aggregate").
|
||||
Funcs(map[string]interface{}{
|
||||
"stripFrozen": cqlHelpers.stripFrozen,
|
||||
}).
|
||||
Parse(`
|
||||
CREATE AGGREGATE {{ .Keyspace }}.{{ .Name }}(
|
||||
{{- range $i, $arg := .ArgumentTypes }}
|
||||
{{- if ne $i 0 }}, {{ end }}
|
||||
{{ stripFrozen $arg }}
|
||||
{{- end -}})
|
||||
SFUNC {{ .StateFunc.Name }}
|
||||
STYPE {{ stripFrozen .StateType }}
|
||||
{{- if ne .FinalFunc.Name "" }}
|
||||
FINALFUNC {{ .FinalFunc.Name }}
|
||||
{{- end -}}
|
||||
{{- if ne .InitCond "" }}
|
||||
INITCOND {{ .InitCond }}
|
||||
{{- end -}}
|
||||
;
|
||||
`))
|
||||
|
||||
func (km *KeyspaceMetadata) aggregateToCQL(w io.Writer, am *AggregateMetadata) error {
|
||||
if err := aggregatesTemplate.Execute(w, am); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var typeCQLTemplate = template.Must(template.New("types").
|
||||
Funcs(map[string]interface{}{
|
||||
"zip": cqlHelpers.zip,
|
||||
}).
|
||||
Parse(`
|
||||
CREATE TYPE {{ .Keyspace }}.{{ .Name }} (
|
||||
{{- range $i, $fields := zip .FieldNames .FieldTypes }} {{- if ne $i 0 }},{{ end }}
|
||||
{{ index $fields 0 }} {{ index $fields 1 }}
|
||||
{{- end }}
|
||||
);
|
||||
`))
|
||||
|
||||
func (km *KeyspaceMetadata) userTypeToCQL(w io.Writer, tm *TypeMetadata) error {
|
||||
if err := typeCQLTemplate.Execute(w, tm); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (km *KeyspaceMetadata) indexToCQL(w io.Writer, im *IndexMetadata) error {
|
||||
// Scylla doesn't support any custom indexes
|
||||
if im.Kind == IndexKindCustom {
|
||||
return nil
|
||||
}
|
||||
|
||||
options := im.Options
|
||||
indexTarget := options["target"]
|
||||
|
||||
// secondary index
|
||||
si := struct {
|
||||
ClusteringKeys []string `json:"ck"`
|
||||
PartitionKeys []string `json:"pk"`
|
||||
}{}
|
||||
|
||||
if err := json.Unmarshal([]byte(indexTarget), &si); err == nil {
|
||||
indexTarget = fmt.Sprintf("(%s), %s",
|
||||
strings.Join(si.PartitionKeys, ","),
|
||||
strings.Join(si.ClusteringKeys, ","),
|
||||
)
|
||||
}
|
||||
|
||||
_, err := fmt.Fprintf(w, "\nCREATE INDEX %s ON %s.%s (%s);\n",
|
||||
im.Name,
|
||||
im.KeyspaceName,
|
||||
im.TableName,
|
||||
indexTarget,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var keyspaceCQLTemplate = template.Must(template.New("keyspace").
|
||||
Funcs(map[string]interface{}{
|
||||
"escape": cqlHelpers.escape,
|
||||
"fixStrategy": cqlHelpers.fixStrategy,
|
||||
}).
|
||||
Parse(`CREATE KEYSPACE {{ .Name }} WITH replication = {
|
||||
'class': {{ escape ( fixStrategy .StrategyClass) }}
|
||||
{{- range $key, $value := .StrategyOptions }},
|
||||
{{ escape $key }}: {{ escape $value }}
|
||||
{{- end }}
|
||||
}{{ if not .DurableWrites }} AND durable_writes = 'false'{{ end }};
|
||||
`))
|
||||
|
||||
func (km *KeyspaceMetadata) keyspaceToCQL(w io.Writer) error {
|
||||
if err := keyspaceCQLTemplate.Execute(w, km); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func contains(in []string, v string) bool {
|
||||
for _, e := range in {
|
||||
if e == v {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type toCQLHelpers struct{}
|
||||
|
||||
var cqlHelpers = toCQLHelpers{}
|
||||
|
||||
func (h toCQLHelpers) zip(a []string, b []string) [][]string {
|
||||
m := make([][]string, len(a))
|
||||
for i := range a {
|
||||
m[i] = []string{a[i], b[i]}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func (h toCQLHelpers) escape(e interface{}) string {
|
||||
switch v := e.(type) {
|
||||
case int, float64:
|
||||
return fmt.Sprint(v)
|
||||
case bool:
|
||||
if v {
|
||||
return "true"
|
||||
}
|
||||
return "false"
|
||||
case string:
|
||||
return "'" + strings.ReplaceAll(v, "'", "''") + "'"
|
||||
case []byte:
|
||||
return string(v)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (h toCQLHelpers) stripFrozen(v string) string {
|
||||
return strings.TrimSuffix(strings.TrimPrefix(v, "frozen<"), ">")
|
||||
}
|
||||
func (h toCQLHelpers) fixStrategy(v string) string {
|
||||
return strings.TrimPrefix(v, "org.apache.cassandra.locator.")
|
||||
}
|
||||
|
||||
func (h toCQLHelpers) fixQuote(v string) string {
|
||||
return strings.ReplaceAll(v, `"`, `'`)
|
||||
}
|
||||
|
||||
func (h toCQLHelpers) tableOptionsToCQL(ops TableMetadataOptions) ([]string, error) {
|
||||
opts := map[string]interface{}{
|
||||
"bloom_filter_fp_chance": ops.BloomFilterFpChance,
|
||||
"comment": ops.Comment,
|
||||
"crc_check_chance": ops.CrcCheckChance,
|
||||
"default_time_to_live": ops.DefaultTimeToLive,
|
||||
"gc_grace_seconds": ops.GcGraceSeconds,
|
||||
"max_index_interval": ops.MaxIndexInterval,
|
||||
"memtable_flush_period_in_ms": ops.MemtableFlushPeriodInMs,
|
||||
"min_index_interval": ops.MinIndexInterval,
|
||||
"speculative_retry": ops.SpeculativeRetry,
|
||||
}
|
||||
|
||||
var err error
|
||||
opts["caching"], err = json.Marshal(ops.Caching)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
opts["compaction"], err = json.Marshal(ops.Compaction)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
opts["compression"], err = json.Marshal(ops.Compression)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cdc, err := json.Marshal(ops.CDC)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if string(cdc) != "null" {
|
||||
opts["cdc"] = cdc
|
||||
}
|
||||
|
||||
if ops.InMemory {
|
||||
opts["in_memory"] = ops.InMemory
|
||||
}
|
||||
|
||||
out := make([]string, 0, len(opts))
|
||||
for key, opt := range opts {
|
||||
out = append(out, fmt.Sprintf("%s = %s", key, h.fixQuote(h.escape(opt))))
|
||||
}
|
||||
|
||||
sort.Strings(out)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (h toCQLHelpers) tableExtensionsToCQL(extensions map[string]interface{}) ([]string, error) {
|
||||
exts := map[string]interface{}{}
|
||||
|
||||
if blob, ok := extensions["scylla_encryption_options"]; ok {
|
||||
encOpts := &scyllaEncryptionOptions{}
|
||||
if err := encOpts.UnmarshalBinary(blob.([]byte)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var err error
|
||||
exts["scylla_encryption_options"], err = json.Marshal(encOpts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
out := make([]string, 0, len(exts))
|
||||
for key, ext := range exts {
|
||||
out = append(out, fmt.Sprintf("%s = %s", key, h.fixQuote(h.escape(ext))))
|
||||
}
|
||||
|
||||
sort.Strings(out)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// tablePropertiesToCQL builds the WITH-clause body of a CREATE TABLE
// statement: COMPACT STORAGE (when the table flags call for it), the
// CLUSTERING ORDER BY clause, the table options, and any recognized
// extensions, all joined with "\n AND ".
func (h toCQLHelpers) tablePropertiesToCQL(cks []*ColumnMetadata, opts TableMetadataOptions,
	flags []string, extensions map[string]interface{}) (string, error) {
	var sb strings.Builder

	var properties []string

	// Dense or super tables, and tables lacking the "compound" flag, were
	// created WITH COMPACT STORAGE.
	compactStorage := len(flags) > 0 && (contains(flags, TableFlagDense) ||
		contains(flags, TableFlagSuper) ||
		!contains(flags, TableFlagCompound))

	if compactStorage {
		properties = append(properties, "COMPACT STORAGE")
	}

	if len(cks) > 0 {
		var inner []string
		for _, col := range cks {
			inner = append(inner, fmt.Sprintf("%s %s", col.Name, col.ClusteringOrder))
		}
		properties = append(properties, fmt.Sprintf("CLUSTERING ORDER BY (%s)", strings.Join(inner, ", ")))
	}

	options, err := h.tableOptionsToCQL(opts)
	if err != nil {
		return "", err
	}
	properties = append(properties, options...)

	exts, err := h.tableExtensionsToCQL(extensions)
	if err != nil {
		return "", err
	}
	properties = append(properties, exts...)

	sb.WriteString(strings.Join(properties, "\n AND "))
	return sb.String(), nil
}
|
||||
|
||||
// tableColumnToCQL renders the column-definition section of a CREATE TABLE
// statement: one "name type" entry per column (static columns are marked),
// plus a PRIMARY KEY clause when the key spans multiple columns.
func (h toCQLHelpers) tableColumnToCQL(tm *TableMetadata) string {
	var sb strings.Builder

	var columns []string
	for _, cn := range tm.OrderedColumns {
		cm := tm.Columns[cn]
		column := fmt.Sprintf("%s %s", cn, cm.Type)
		if cm.Kind == ColumnStatic {
			column += " static"
		}
		columns = append(columns, column)
	}
	// A single-column partition key with no clustering columns is declared
	// inline on the first column instead of in a separate PRIMARY KEY clause.
	if len(tm.PartitionKey) == 1 && len(tm.ClusteringColumns) == 0 && len(columns) > 0 {
		columns[0] += " PRIMARY KEY"
	}

	sb.WriteString(strings.Join(columns, ",\n "))

	if len(tm.PartitionKey) > 1 || len(tm.ClusteringColumns) > 0 {
		sb.WriteString(",\n PRIMARY KEY (")
		sb.WriteString(h.partitionKeyString(tm.PartitionKey, tm.ClusteringColumns))
		sb.WriteRune(')')
	}

	return sb.String()
}
|
||||
|
||||
func (h toCQLHelpers) partitionKeyString(pks, cks []*ColumnMetadata) string {
|
||||
var sb strings.Builder
|
||||
|
||||
if len(pks) > 1 {
|
||||
sb.WriteRune('(')
|
||||
for i, pk := range pks {
|
||||
if i != 0 {
|
||||
sb.WriteString(", ")
|
||||
}
|
||||
sb.WriteString(pk.Name)
|
||||
}
|
||||
sb.WriteRune(')')
|
||||
} else {
|
||||
sb.WriteString(pks[0].Name)
|
||||
}
|
||||
|
||||
if len(cks) > 0 {
|
||||
sb.WriteString(", ")
|
||||
for i, ck := range cks {
|
||||
if i != 0 {
|
||||
sb.WriteString(", ")
|
||||
}
|
||||
sb.WriteString(ck.Name)
|
||||
}
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// scyllaEncryptionOptions holds the decoded form of the
// "scylla_encryption_options" table extension blob.
type scyllaEncryptionOptions struct {
	CipherAlgorithm   string `json:"cipher_algorithm"`
	SecretKeyStrength int    `json:"secret_key_strength"`
	KeyProvider       string `json:"key_provider"`
	SecretKeyFile     string `json:"secret_key_file"`
}

// UnmarshalBinary deserializes blob into scyllaEncryptionOptions.
// Format:
// - 4 bytes - size of KV map
// Size times:
// - 4 bytes - length of key
// - len_of_key bytes - key
// - 4 bytes - length of value
// - len_of_value bytes - value
// All integers are little-endian.
//
// Unlike the previous version, truncated blobs and length fields pointing
// past the end of data now return an error instead of panicking with an
// out-of-range slice index.
func (enc *scyllaEncryptionOptions) UnmarshalBinary(data []byte) error {
	if len(data) < 4 {
		return fmt.Errorf("scylla_encryption_options: blob too short: %d bytes", len(data))
	}
	size := binary.LittleEndian.Uint32(data[0:4])

	m := make(map[string]string, size)

	off := uint32(4)
	// readChunk reads a 4-byte length followed by that many bytes,
	// validating bounds before every access.
	readChunk := func() (string, error) {
		if uint32(len(data))-off < 4 {
			return "", fmt.Errorf("scylla_encryption_options: truncated length field at offset %d", off)
		}
		n := binary.LittleEndian.Uint32(data[off : off+4])
		off += 4
		if uint32(len(data))-off < n {
			return "", fmt.Errorf("scylla_encryption_options: truncated payload at offset %d", off)
		}
		s := string(data[off : off+n])
		off += n
		return s, nil
	}

	for i := uint32(0); i < size; i++ {
		key, err := readChunk()
		if err != nil {
			return err
		}
		value, err := readChunk()
		if err != nil {
			return err
		}
		m[key] = value
	}

	enc.CipherAlgorithm = m["cipher_algorithm"]
	enc.KeyProvider = m["key_provider"]
	enc.SecretKeyFile = m["secret_key_file"]
	if secretKeyStrength, ok := m["secret_key_strength"]; ok {
		sks, err := strconv.Atoi(secretKeyStrength)
		if err != nil {
			return err
		}
		enc.SecretKeyStrength = sks
	}

	return nil
}
|
275
vendor/github.com/gocql/gocql/ring_describer.go
generated
vendored
Normal file
275
vendor/github.com/gocql/gocql/ring_describer.go
generated
vendored
Normal file
@@ -0,0 +1,275 @@
|
||||
package gocql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Polls system.peers at a specific interval to find new hosts
type ringDescriber struct {
	control controlConnection
	cfg     *ClusterConfig
	logger  StdLogger
	// mu guards the fields below.
	mu sync.RWMutex
	// prevHosts/prevPartitioner cache the last successful ring snapshot so
	// it can be served when a refresh fails.
	prevHosts       []*HostInfo
	prevPartitioner string

	// hosts are the set of all hosts in the cassandra ring that we know of.
	// key of map is host_id.
	hosts map[string]*HostInfo
	// hostIPToUUID maps host native address to host_id.
	hostIPToUUID map[string]string
}
|
||||
|
||||
// setControlConn swaps in the control connection used for ring queries.
// Safe for concurrent use.
func (r *ringDescriber) setControlConn(c controlConnection) {
	r.mu.Lock()
	defer r.mu.Unlock()

	r.control = c
}
|
||||
|
||||
// Ask the control node for the local host info
//
// getLocalHostInfo queries system.local over conn and converts the row into
// a HostInfo. Returns errNoControl when the query yields no iterator.
func (r *ringDescriber) getLocalHostInfo(conn ConnInterface) (*HostInfo, error) {
	iter := conn.querySystem(context.TODO(), qrySystemLocal)

	if iter == nil {
		return nil, errNoControl
	}

	host, err := hostInfoFromIter(iter, nil, r.cfg.Port, r.cfg.translateAddressPort)
	if err != nil {
		return nil, fmt.Errorf("could not retrieve local host info: %w", err)
	}
	return host, nil
}
|
||||
|
||||
// Ask the control node for host info on all it's known peers
|
||||
func (r *ringDescriber) getClusterPeerInfo(localHost *HostInfo, c ConnInterface) ([]*HostInfo, error) {
|
||||
var iter *Iter
|
||||
if c.getIsSchemaV2() {
|
||||
iter = c.querySystem(context.TODO(), qrySystemPeersV2)
|
||||
} else {
|
||||
iter = c.querySystem(context.TODO(), qrySystemPeers)
|
||||
}
|
||||
|
||||
if iter == nil {
|
||||
return nil, errNoControl
|
||||
}
|
||||
|
||||
rows, err := iter.SliceMap()
|
||||
if err != nil {
|
||||
// TODO(zariel): make typed error
|
||||
return nil, fmt.Errorf("unable to fetch peer host info: %s", err)
|
||||
}
|
||||
|
||||
return getPeersFromQuerySystemPeers(rows, r.cfg.Port, r.cfg.translateAddressPort, r.logger)
|
||||
}
|
||||
|
||||
// getPeersFromQuerySystemPeers converts raw system.peers rows into HostInfo
// values. Rows describing invalid peers (missing RPC address, host ID,
// datacenter, or rack) are logged and skipped; zero-token (tokenless) hosts
// are skipped silently.
func getPeersFromQuerySystemPeers(querySystemPeerRows []map[string]interface{}, port int, translateAddressPort func(addr net.IP, port int) (net.IP, int), logger StdLogger) ([]*HostInfo, error) {
	var peers []*HostInfo

	for _, row := range querySystemPeerRows {
		// extract all available info about the peer
		host, err := hostInfoFromMap(row, &HostInfo{port: port}, translateAddressPort)
		if err != nil {
			return nil, err
		} else if !isValidPeer(host) {
			// If it's not a valid peer
			logger.Printf("Found invalid peer '%s' "+
				"Likely due to a gossip or snitch issue, this host will be ignored", host)
			continue
		} else if isZeroToken(host) {
			continue
		}

		peers = append(peers, host)
	}

	return peers, nil
}
|
||||
|
||||
// Return true if the host is a valid peer
|
||||
func isValidPeer(host *HostInfo) bool {
|
||||
return !(len(host.RPCAddress()) == 0 ||
|
||||
host.hostId == "" ||
|
||||
host.dataCenter == "" ||
|
||||
host.rack == "")
|
||||
}
|
||||
|
||||
func isZeroToken(host *HostInfo) bool {
|
||||
return len(host.tokens) == 0
|
||||
}
|
||||
|
||||
// GetHostsFromSystem returns a list of hosts found via queries to system.local and system.peers
//
// On any failure the previously cached hosts and partitioner are returned
// alongside the error.
func (r *ringDescriber) GetHostsFromSystem() ([]*HostInfo, string, error) {
	r.mu.Lock()
	defer r.mu.Unlock()

	if r.control == nil {
		return r.prevHosts, r.prevPartitioner, errNoControl
	}

	ch := r.control.getConn()
	localHost, err := r.getLocalHostInfo(ch.conn)
	if err != nil {
		return r.prevHosts, r.prevPartitioner, err
	}

	peerHosts, err := r.getClusterPeerInfo(localHost, ch.conn)
	if err != nil {
		return r.prevHosts, r.prevPartitioner, err
	}

	// A zero-token (tokenless) local node is excluded from the ring view,
	// matching the peer filtering in getPeersFromQuerySystemPeers.
	var hosts []*HostInfo
	if !isZeroToken(localHost) {
		hosts = []*HostInfo{localHost}
	}
	hosts = append(hosts, peerHosts...)

	var partitioner string
	if len(hosts) > 0 {
		partitioner = hosts[0].Partitioner()
	}

	// Cache the fresh snapshot for future error paths.
	r.prevHosts = hosts
	r.prevPartitioner = partitioner

	return hosts, partitioner, nil
}
|
||||
|
||||
// Given an ip/port return HostInfo for the specified ip/port
//
// NOTE(review): despite the comment above, the lookup key is actually a host
// ID. The function scans system.peers and then system.local through the
// control connection until a row with a matching host ID is found.
func (r *ringDescriber) getHostInfo(hostID UUID) (*HostInfo, error) {
	var host *HostInfo
	for _, table := range []string{"system.peers", "system.local"} {
		ch := r.control.getConn()
		var iter *Iter
		// Fast path: the control connection's own host already matches.
		if ch.host.HostID() == hostID.String() {
			host = ch.host
			iter = nil
		}

		if table == "system.peers" {
			if ch.conn.getIsSchemaV2() {
				iter = ch.conn.querySystem(context.TODO(), qrySystemPeersV2)
			} else {
				iter = ch.conn.querySystem(context.TODO(), qrySystemPeers)
			}
		} else {
			iter = ch.conn.query(context.TODO(), fmt.Sprintf("SELECT * FROM %s", table))
		}

		if iter != nil {
			rows, err := iter.SliceMap()
			if err != nil {
				return nil, err
			}

			for _, row := range rows {
				h, err := hostInfoFromMap(row, &HostInfo{port: r.cfg.Port}, r.cfg.translateAddressPort)
				if err != nil {
					return nil, err
				}

				if h.HostID() == hostID.String() {
					host = h
					break
				}
			}
		}
	}

	if host == nil {
		return nil, errors.New("unable to fetch host info: invalid control connection")
	} else if host.invalidConnectAddr() {
		return nil, fmt.Errorf("host ConnectAddress invalid ip=%v: %v", host.connectAddress, host)
	}

	return host, nil
}
|
||||
|
||||
func (r *ringDescriber) getHostByIP(ip string) (*HostInfo, bool) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
hi, ok := r.hostIPToUUID[ip]
|
||||
return r.hosts[hi], ok
|
||||
}
|
||||
|
||||
func (r *ringDescriber) getHost(hostID string) *HostInfo {
|
||||
r.mu.RLock()
|
||||
host := r.hosts[hostID]
|
||||
r.mu.RUnlock()
|
||||
return host
|
||||
}
|
||||
|
||||
func (r *ringDescriber) getHostsList() []*HostInfo {
|
||||
r.mu.RLock()
|
||||
hosts := make([]*HostInfo, 0, len(r.hosts))
|
||||
for _, host := range r.hosts {
|
||||
hosts = append(hosts, host)
|
||||
}
|
||||
r.mu.RUnlock()
|
||||
return hosts
|
||||
}
|
||||
|
||||
func (r *ringDescriber) getHostsMap() map[string]*HostInfo {
|
||||
r.mu.RLock()
|
||||
hosts := make(map[string]*HostInfo, len(r.hosts))
|
||||
for k, v := range r.hosts {
|
||||
hosts[k] = v
|
||||
}
|
||||
r.mu.RUnlock()
|
||||
return hosts
|
||||
}
|
||||
|
||||
// addOrUpdate inserts host if its ID is new, otherwise merges host's fields
// into the already-known entry. Returns the canonical *HostInfo for the ID
// (the pre-existing entry when there was one, host itself otherwise).
func (r *ringDescriber) addOrUpdate(host *HostInfo) *HostInfo {
	if existingHost, ok := r.addHostIfMissing(host); ok {
		existingHost.update(host)
		host = existingHost
	}
	return host
}
|
||||
|
||||
// addHostIfMissing registers host under its host ID unless an entry already
// exists. Returns the stored entry and whether it was already present.
// Panics when the host's connect address is invalid — callers are expected
// to have validated the address beforehand.
func (r *ringDescriber) addHostIfMissing(host *HostInfo) (*HostInfo, bool) {
	if host.invalidConnectAddr() {
		panic(fmt.Sprintf("invalid host: %v", host))
	}
	hostID := host.HostID()

	r.mu.Lock()
	// Lazily initialize the maps so a zero-value ringDescriber works.
	if r.hosts == nil {
		r.hosts = make(map[string]*HostInfo)
	}
	if r.hostIPToUUID == nil {
		r.hostIPToUUID = make(map[string]string)
	}

	existing, ok := r.hosts[hostID]
	if !ok {
		r.hosts[hostID] = host
		r.hostIPToUUID[host.nodeToNodeAddress().String()] = hostID
		existing = host
	}
	r.mu.Unlock()
	return existing, ok
}
|
||||
|
||||
// removeHost deletes the host with the given ID and its IP-index entry.
// Returns whether the host was present.
func (r *ringDescriber) removeHost(hostID string) bool {
	r.mu.Lock()
	// Lazily initialize the maps so a zero-value ringDescriber works.
	if r.hosts == nil {
		r.hosts = make(map[string]*HostInfo)
	}
	if r.hostIPToUUID == nil {
		r.hostIPToUUID = make(map[string]string)
	}

	h, ok := r.hosts[hostID]
	if ok {
		// Remove the reverse-index entry first, while h is still available.
		delete(r.hostIPToUUID, h.nodeToNodeAddress().String())
	}
	delete(r.hosts, hostID)
	r.mu.Unlock()
	return ok
}
|
874
vendor/github.com/gocql/gocql/scylla.go
generated
vendored
Normal file
874
vendor/github.com/gocql/gocql/scylla.go
generated
vendored
Normal file
@@ -0,0 +1,874 @@
|
||||
package gocql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
// scyllaSupported represents Scylla connection options as sent in SUPPORTED
// frame.
// FIXME: Should also follow `cqlProtocolExtension` interface.
type scyllaSupported struct {
	shard             int    // shard this connection landed on (SCYLLA_SHARD)
	nrShards          int    // total shard count on the node (SCYLLA_NR_SHARDS)
	msbIgnore         uint64 // SCYLLA_SHARDING_IGNORE_MSB value
	partitioner       string // SCYLLA_PARTITIONER value
	shardingAlgorithm string // SCYLLA_SHARDING_ALGORITHM value
	shardAwarePort    uint16 // SCYLLA_SHARD_AWARE_PORT, 0 when absent
	shardAwarePortSSL uint16 // SCYLLA_SHARD_AWARE_PORT_SSL, 0 when absent
	lwtFlagMask       int
}

// CQL Protocol extension interface for Scylla.
// Each extension is identified by a name and defines a way to serialize itself
// in STARTUP message payload.
type cqlProtocolExtension interface {
	name() string
	serialize() map[string]string
}
|
||||
|
||||
func findCQLProtoExtByName(exts []cqlProtocolExtension, name string) cqlProtocolExtension {
|
||||
for i := range exts {
|
||||
if exts[i].name() == name {
|
||||
return exts[i]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Top-level keys used for serialization/deserialization of CQL protocol
// extensions in SUPPORTED/STARTUP messages.
// Each key identifies a single extension.
const (
	// LWT prepared-statement metadata-mark extension.
	lwtAddMetadataMarkKey = "SCYLLA_LWT_ADD_METADATA_MARK"
	// Rate-limit error extension.
	rateLimitError = "SCYLLA_RATE_LIMIT_ERROR"
	// Tablets routing information extension.
	tabletsRoutingV1 = "TABLETS_ROUTING_V1"
)
|
||||
|
||||
// "tabletsRoutingV1" CQL Protocol Extension.
// This extension, if enabled (properly negotiated), allows Scylla server
// to send a tablet information in `custom_payload`.
//
// Implements cqlProtocolExtension interface.
type tabletsRoutingV1Ext struct {
}

var _ cqlProtocolExtension = &tabletsRoutingV1Ext{}

// Factory function to deserialize and create an `tabletsRoutingV1Ext` instance
// from SUPPORTED message payload.
//
// Returns nil when the server did not advertise the extension.
func newTabletsRoutingV1Ext(supported map[string][]string) *tabletsRoutingV1Ext {
	if _, found := supported[tabletsRoutingV1]; found {
		return &tabletsRoutingV1Ext{}
	}
	return nil
}

// serialize returns the STARTUP payload entry announcing this extension.
func (ext *tabletsRoutingV1Ext) serialize() map[string]string {
	return map[string]string{
		tabletsRoutingV1: "",
	}
}

// name returns the extension's SUPPORTED/STARTUP key.
func (ext *tabletsRoutingV1Ext) name() string {
	return tabletsRoutingV1
}
|
||||
|
||||
// "Rate limit" CQL Protocol Extension.
|
||||
// This extension, if enabled (properly negotiated), allows Scylla server
|
||||
// to send a special kind of error.
|
||||
//
|
||||
// Implements cqlProtocolExtension interface.
|
||||
type rateLimitExt struct {
|
||||
rateLimitErrorCode int
|
||||
}
|
||||
|
||||
var _ cqlProtocolExtension = &rateLimitExt{}
|
||||
|
||||
// Factory function to deserialize and create an `rateLimitExt` instance
|
||||
// from SUPPORTED message payload.
|
||||
func newRateLimitExt(supported map[string][]string) *rateLimitExt {
|
||||
const rateLimitErrorCode = "ERROR_CODE"
|
||||
|
||||
if v, found := supported[rateLimitError]; found {
|
||||
for i := range v {
|
||||
splitVal := strings.Split(v[i], "=")
|
||||
if splitVal[0] == rateLimitErrorCode {
|
||||
var (
|
||||
err error
|
||||
errorCode int
|
||||
)
|
||||
if errorCode, err = strconv.Atoi(splitVal[1]); err != nil {
|
||||
if gocqlDebug {
|
||||
Logger.Printf("scylla: failed to parse %s value %v: %s", rateLimitErrorCode, splitVal[1], err)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return &rateLimitExt{
|
||||
rateLimitErrorCode: errorCode,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ext *rateLimitExt) serialize() map[string]string {
|
||||
return map[string]string{
|
||||
rateLimitError: "",
|
||||
}
|
||||
}
|
||||
|
||||
func (ext *rateLimitExt) name() string {
|
||||
return rateLimitError
|
||||
}
|
||||
|
||||
// "LWT prepared statements metadata mark" CQL Protocol Extension.
|
||||
// This extension, if enabled (properly negotiated), allows Scylla server
|
||||
// to set a special bit in prepared statements metadata, which would indicate
|
||||
// whether the statement at hand is LWT statement or not.
|
||||
//
|
||||
// This is further used to consistently choose primary replicas in a predefined
|
||||
// order for these queries, which can reduce contention over hot keys and thus
|
||||
// increase LWT performance.
|
||||
//
|
||||
// Implements cqlProtocolExtension interface.
|
||||
type lwtAddMetadataMarkExt struct {
|
||||
lwtOptMetaBitMask int
|
||||
}
|
||||
|
||||
var _ cqlProtocolExtension = &lwtAddMetadataMarkExt{}
|
||||
|
||||
// Factory function to deserialize and create an `lwtAddMetadataMarkExt` instance
|
||||
// from SUPPORTED message payload.
|
||||
func newLwtAddMetaMarkExt(supported map[string][]string) *lwtAddMetadataMarkExt {
|
||||
const lwtOptMetaBitMaskKey = "LWT_OPTIMIZATION_META_BIT_MASK"
|
||||
|
||||
if v, found := supported[lwtAddMetadataMarkKey]; found {
|
||||
for i := range v {
|
||||
splitVal := strings.Split(v[i], "=")
|
||||
if splitVal[0] == lwtOptMetaBitMaskKey {
|
||||
var (
|
||||
err error
|
||||
bitMask int
|
||||
)
|
||||
if bitMask, err = strconv.Atoi(splitVal[1]); err != nil {
|
||||
if gocqlDebug {
|
||||
Logger.Printf("scylla: failed to parse %s value %v: %s", lwtOptMetaBitMaskKey, splitVal[1], err)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return &lwtAddMetadataMarkExt{
|
||||
lwtOptMetaBitMask: bitMask,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ext *lwtAddMetadataMarkExt) serialize() map[string]string {
|
||||
return map[string]string{
|
||||
lwtAddMetadataMarkKey: fmt.Sprintf("LWT_OPTIMIZATION_META_BIT_MASK=%d", ext.lwtOptMetaBitMask),
|
||||
}
|
||||
}
|
||||
|
||||
func (ext *lwtAddMetadataMarkExt) name() string {
|
||||
return lwtAddMetadataMarkKey
|
||||
}
|
||||
|
||||
// parseSupported extracts Scylla sharding information from a SUPPORTED
// frame payload. When the advertised configuration is not the supported
// combination (Murmur3 partitioner + biased-token-round-robin sharding with
// non-zero shard count and MSB-ignore), the zero value is returned, which
// disables shard-awareness for the connection.
func parseSupported(supported map[string][]string) scyllaSupported {
	const (
		scyllaShard             = "SCYLLA_SHARD"
		scyllaNrShards          = "SCYLLA_NR_SHARDS"
		scyllaPartitioner       = "SCYLLA_PARTITIONER"
		scyllaShardingAlgorithm = "SCYLLA_SHARDING_ALGORITHM"
		scyllaShardingIgnoreMSB = "SCYLLA_SHARDING_IGNORE_MSB"
		scyllaShardAwarePort    = "SCYLLA_SHARD_AWARE_PORT"
		scyllaShardAwarePortSSL = "SCYLLA_SHARD_AWARE_PORT_SSL"
	)

	var (
		si  scyllaSupported
		err error
	)

	// Parse failures below are logged (in debug mode) but are not fatal;
	// the validation at the end decides whether sharding is usable.
	if s, ok := supported[scyllaShard]; ok {
		if si.shard, err = strconv.Atoi(s[0]); err != nil {
			if gocqlDebug {
				Logger.Printf("scylla: failed to parse %s value %v: %s", scyllaShard, s, err)
			}
		}
	}
	if s, ok := supported[scyllaNrShards]; ok {
		if si.nrShards, err = strconv.Atoi(s[0]); err != nil {
			if gocqlDebug {
				Logger.Printf("scylla: failed to parse %s value %v: %s", scyllaNrShards, s, err)
			}
		}
	}
	if s, ok := supported[scyllaShardingIgnoreMSB]; ok {
		if si.msbIgnore, err = strconv.ParseUint(s[0], 10, 64); err != nil {
			if gocqlDebug {
				Logger.Printf("scylla: failed to parse %s value %v: %s", scyllaShardingIgnoreMSB, s, err)
			}
		}
	}

	if s, ok := supported[scyllaPartitioner]; ok {
		si.partitioner = s[0]
	}
	if s, ok := supported[scyllaShardingAlgorithm]; ok {
		si.shardingAlgorithm = s[0]
	}
	if s, ok := supported[scyllaShardAwarePort]; ok {
		if shardAwarePort, err := strconv.ParseUint(s[0], 10, 16); err != nil {
			if gocqlDebug {
				Logger.Printf("scylla: failed to parse %s value %v: %s", scyllaShardAwarePort, s, err)
			}
		} else {
			si.shardAwarePort = uint16(shardAwarePort)
		}
	}
	if s, ok := supported[scyllaShardAwarePortSSL]; ok {
		if shardAwarePortSSL, err := strconv.ParseUint(s[0], 10, 16); err != nil {
			if gocqlDebug {
				Logger.Printf("scylla: failed to parse %s value %v: %s", scyllaShardAwarePortSSL, s, err)
			}
		} else {
			si.shardAwarePortSSL = uint16(shardAwarePortSSL)
		}
	}

	// Reject anything other than the exact sharding scheme this driver
	// implements in scyllaConnPicker.shardOf.
	if si.partitioner != "org.apache.cassandra.dht.Murmur3Partitioner" || si.shardingAlgorithm != "biased-token-round-robin" || si.nrShards == 0 || si.msbIgnore == 0 {
		if gocqlDebug {
			Logger.Printf("scylla: unsupported sharding configuration, partitioner=%s, algorithm=%s, no_shards=%d, msb_ignore=%d",
				si.partitioner, si.shardingAlgorithm, si.nrShards, si.msbIgnore)
		}
		return scyllaSupported{}
	}

	return si
}
|
||||
|
||||
func parseCQLProtocolExtensions(supported map[string][]string) []cqlProtocolExtension {
|
||||
exts := []cqlProtocolExtension{}
|
||||
|
||||
lwtExt := newLwtAddMetaMarkExt(supported)
|
||||
if lwtExt != nil {
|
||||
exts = append(exts, lwtExt)
|
||||
}
|
||||
|
||||
rateLimitExt := newRateLimitExt(supported)
|
||||
if rateLimitExt != nil {
|
||||
exts = append(exts, rateLimitExt)
|
||||
}
|
||||
|
||||
tabletsExt := newTabletsRoutingV1Ext(supported)
|
||||
if tabletsExt != nil {
|
||||
exts = append(exts, tabletsExt)
|
||||
}
|
||||
|
||||
return exts
|
||||
}
|
||||
|
||||
// isScyllaConn checks if conn is suitable for scyllaConnPicker.
//
// A non-zero shard count is only produced by parsing a Scylla SUPPORTED
// frame (see parseSupported), so this doubles as a "talking to Scylla" test.
func (conn *Conn) isScyllaConn() bool {
	return conn.getScyllaSupported().nrShards != 0
}
|
||||
|
||||
// scyllaConnPicker is a specialised ConnPicker that selects connections based
// on token trying to get connection to a shard containing the given token.
// A list of excess connections is maintained to allow for lazy closing of
// connections to already opened shards. Keeping excess connections open helps
// reaching equilibrium faster since the likelihood of hitting the same shard
// decreases with the number of connections to the shard.
//
// scyllaConnPicker keeps track of the details about the shard-aware port.
// When used as a Dialer, it connects to the shard-aware port instead of the
// regular port (if the node supports it). For each subsequent connection
// it tries to make, the shard that it aims to connect to is chosen
// in a round-robin fashion.
type scyllaConnPicker struct {
	address           string
	hostId            string
	shardAwareAddress string
	conns             []*Conn // per-shard connections, indexed by shard ID
	excessConns       []*Conn // duplicate connections parked for lazy closing
	nrConns           int     // number of non-nil entries in conns
	nrShards          int
	msbIgnore         uint64
	pos               uint64 // rotating scan offset for leastBusyConn; updated atomically
	lastAttemptedShard     int
	shardAwarePortDisabled bool

	// Used to disable new connections to the shard-aware port temporarily
	disableShardAwarePortUntil *atomic.Value
}
|
||||
|
||||
// newScyllaConnPicker builds a picker from the first established connection
// to a node. Panics when the connection is not sharded — callers are
// expected to have checked isScyllaConn first.
func newScyllaConnPicker(conn *Conn) *scyllaConnPicker {
	addr := conn.Address()
	hostId := conn.host.hostId

	if conn.scyllaSupported.nrShards == 0 {
		panic(fmt.Sprintf("scylla: %s not a sharded connection", addr))
	}

	if gocqlDebug {
		Logger.Printf("scylla: %s new conn picker sharding options %+v", addr, conn.scyllaSupported)
	}

	// TLS and plaintext sessions use different shard-aware ports.
	var shardAwarePort uint16
	if conn.session.connCfg.tlsConfig != nil {
		shardAwarePort = conn.scyllaSupported.shardAwarePortSSL
	} else {
		shardAwarePort = conn.scyllaSupported.shardAwarePort
	}

	var shardAwareAddress string
	if shardAwarePort != 0 {
		// Run the shard-aware endpoint through the configured address
		// translator, like any other endpoint.
		tIP, tPort := conn.session.cfg.translateAddressPort(conn.host.UntranslatedConnectAddress(), int(shardAwarePort))
		shardAwareAddress = net.JoinHostPort(tIP.String(), strconv.Itoa(tPort))
	}

	return &scyllaConnPicker{
		address:                addr,
		hostId:                 hostId,
		shardAwareAddress:      shardAwareAddress,
		nrShards:               conn.scyllaSupported.nrShards,
		msbIgnore:              conn.scyllaSupported.msbIgnore,
		lastAttemptedShard:     0,
		shardAwarePortDisabled: conn.session.cfg.DisableShardAwarePort,

		disableShardAwarePortUntil: new(atomic.Value),
	}
}
|
||||
|
||||
// Pick returns a connection for the given token (and, when tablet routing
// is available, for the given query). With a nil token it falls back to the
// least busy connection; with a non-murmur3 token it returns nil.
func (p *scyllaConnPicker) Pick(t Token, qry ExecutableQuery) *Conn {
	if len(p.conns) == 0 {
		return nil
	}

	if t == nil {
		return p.leastBusyConn()
	}

	mmt, ok := t.(int64Token)
	// double check if that's murmur3 token
	if !ok {
		return nil
	}

	idx := -1

	// NOTE(review): this loop breaks unconditionally after the first non-nil
	// connection — it is a single lookup against one connection's tablet
	// metadata, not a scan over all connections.
	for _, conn := range p.conns {
		if conn == nil {
			continue
		}

		if qry != nil && conn.isTabletSupported() {
			tablets := conn.session.getTablets()

			// Search for tablets with Keyspace and Table from the Query
			l, r := tablets.findTablets(qry.Keyspace(), qry.Table())

			if l != -1 {
				tablet := tablets.findTabletForToken(mmt, l, r)

				// Use the replica's shard on this host, when present.
				for _, replica := range tablet.replicas {
					if replica.hostId.String() == p.hostId {
						idx = replica.shardId
					}
				}
			}
		}

		break
	}

	// Without tablet information, fall back to token-based shard mapping.
	if idx == -1 {
		idx = p.shardOf(mmt)
	}

	if c := p.conns[idx]; c != nil {
		// We have this shard's connection
		// so let's give it to the caller.
		// But only if it's not loaded too much and load is well distributed.
		if qry != nil && qry.IsLWT() {
			return c
		}
		return p.maybeReplaceWithLessBusyConnection(c)
	}
	return p.leastBusyConn()
}
|
||||
|
||||
func (p *scyllaConnPicker) maybeReplaceWithLessBusyConnection(c *Conn) *Conn {
|
||||
if !isHeavyLoaded(c) {
|
||||
return c
|
||||
}
|
||||
alternative := p.leastBusyConn()
|
||||
if alternative == nil || alternative.AvailableStreams()*120 > c.AvailableStreams()*100 {
|
||||
return c
|
||||
} else {
|
||||
return alternative
|
||||
}
|
||||
}
|
||||
|
||||
func isHeavyLoaded(c *Conn) bool {
|
||||
return c.streams.NumStreams/2 > c.AvailableStreams()
|
||||
}
|
||||
|
||||
// leastBusyConn returns the connection with the most available streams,
// scanning from a rotating start offset so ties don't always resolve to the
// same shard. Returns nil when every slot is nil or has zero free streams.
func (p *scyllaConnPicker) leastBusyConn() *Conn {
	var (
		leastBusyConn    *Conn
		streamsAvailable int
	)
	idx := int(atomic.AddUint64(&p.pos, 1))
	// find the conn which has the most available streams, this is racy
	for i := range p.conns {
		if conn := p.conns[(idx+i)%len(p.conns)]; conn != nil {
			if streams := conn.AvailableStreams(); streams > streamsAvailable {
				leastBusyConn = conn
				streamsAvailable = streams
			}
		}
	}
	return leastBusyConn
}
|
||||
|
||||
// shardOf maps a murmur3 token to a shard ID using Scylla's
// "biased-token-round-robin" scheme (the only algorithm accepted by
// parseSupported): the token is biased into the unsigned range, the
// configured number of most-significant bits is shifted off, and the result
// is scaled into [0, nrShards) via a 64x64->128-bit multiply performed in
// 32-bit halves.
func (p *scyllaConnPicker) shardOf(token int64Token) int {
	shards := uint64(p.nrShards)
	z := uint64(token+math.MinInt64) << p.msbIgnore
	lo := z & 0xffffffff
	hi := (z >> 32) & 0xffffffff
	mul1 := lo * shards
	mul2 := hi * shards
	sum := (mul1 >> 32) + mul2
	return int(sum >> 32)
}
|
||||
|
||||
// Put registers a newly established connection under its shard ID.
// A connection to an already-populated shard is either closed immediately
// (when it arrived via the shard-aware port, which indicates NAT/translator
// interference) or parked in excessConns for later closing.
// Panics on unsharded connections and on a shard-count mismatch.
func (p *scyllaConnPicker) Put(conn *Conn) {
	var (
		nrShards = conn.scyllaSupported.nrShards
		shard    = conn.scyllaSupported.shard
	)

	if nrShards == 0 {
		panic(fmt.Sprintf("scylla: %s not a sharded connection", p.address))
	}

	// Lazily size the per-shard slice on the first Put.
	if nrShards != len(p.conns) {
		if nrShards != p.nrShards {
			panic(fmt.Sprintf("scylla: %s invalid number of shards", p.address))
		}
		conns := p.conns
		p.conns = make([]*Conn, nrShards, nrShards)
		copy(p.conns, conns)
	}

	if c := p.conns[shard]; c != nil {
		if conn.addr == p.shardAwareAddress {
			// A connection made to the shard-aware port resulted in duplicate
			// connection to the same shard being made. Because this is never
			// intentional, it suggests that a NAT or AddressTranslator
			// changes the source port along the way, therefore we can't trust
			// the shard-aware port to return connection to the shard
			// that we requested. Fall back to non-shard-aware port for some time.
			Logger.Printf(
				"scylla: %s connection to shard-aware address %s resulted in wrong shard being assigned; please check that you are not behind a NAT or AddressTranslater which changes source ports; falling back to non-shard-aware port for %v",
				p.address,
				p.shardAwareAddress,
				scyllaShardAwarePortFallbackDuration,
			)
			until := time.Now().Add(scyllaShardAwarePortFallbackDuration)
			p.disableShardAwarePortUntil.Store(until)

			// Connections to shard-aware port do not influence how shards
			// are chosen for the non-shard-aware port, therefore it can be
			// closed immediately
			closeConns(conn)
		} else {
			p.excessConns = append(p.excessConns, conn)
			if gocqlDebug {
				Logger.Printf("scylla: %s put shard %d excess connection total: %d missing: %d excess: %d", p.address, shard, p.nrConns, p.nrShards-p.nrConns, len(p.excessConns))
			}
		}
	} else {
		p.conns[shard] = conn
		p.nrConns++
		if gocqlDebug {
			Logger.Printf("scylla: %s put shard %d connection total: %d missing: %d", p.address, shard, p.nrConns, p.nrShards-p.nrConns)
		}
	}

	if p.shouldCloseExcessConns() {
		p.closeExcessConns()
	}
}
|
||||
|
||||
func (p *scyllaConnPicker) shouldCloseExcessConns() bool {
|
||||
const maxExcessConnsFactor = 10
|
||||
|
||||
if p.nrConns >= p.nrShards {
|
||||
return true
|
||||
}
|
||||
return len(p.excessConns) > maxExcessConnsFactor*p.nrShards
|
||||
}
|
||||
|
||||
// Remove drops the connection's shard slot from the picker. Connections
// with unknown sharding state (possible when Remove races connection setup)
// are ignored.
func (p *scyllaConnPicker) Remove(conn *Conn) {
	shard := conn.scyllaSupported.shard

	if conn.scyllaSupported.nrShards == 0 {
		// It is possible for Remove to be called before the connection is added to the pool.
		// Ignoring these connections here is safe.
		if gocqlDebug {
			Logger.Printf("scylla: %s has unknown sharding state, ignoring it", p.address)
		}
		return
	}
	if gocqlDebug {
		Logger.Printf("scylla: %s remove shard %d connection", p.address, shard)
	}

	// Only decrement the count when this shard actually had a connection.
	if p.conns[shard] != nil {
		p.conns[shard] = nil
		p.nrConns--
	}
}
|
||||
|
||||
func (p *scyllaConnPicker) InFlight() int {
|
||||
result := 0
|
||||
for _, conn := range p.conns {
|
||||
if conn != nil {
|
||||
result = result + (conn.streams.InUse())
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Size returns the number of connected shards and the number still missing.
func (p *scyllaConnPicker) Size() (int, int) {
	return p.nrConns, p.nrShards - p.nrConns
}

// Close closes all per-shard connections and all parked excess connections.
func (p *scyllaConnPicker) Close() {
	p.closeConns()
	p.closeExcessConns()
}
|
||||
|
||||
// closeConns detaches every per-shard connection from the picker and closes
// them on a separate goroutine (Close must run outside pool locks — see the
// package-level closeConns).
func (p *scyllaConnPicker) closeConns() {
	if len(p.conns) == 0 {
		if gocqlDebug {
			Logger.Printf("scylla: %s no connections to close", p.address)
		}
		return
	}

	conns := p.conns
	p.conns = nil
	p.nrConns = 0

	if gocqlDebug {
		Logger.Printf("scylla: %s closing %d connections", p.address, len(conns))
	}
	go closeConns(conns...)
}

// closeExcessConns detaches all parked duplicate connections and closes
// them on a separate goroutine.
func (p *scyllaConnPicker) closeExcessConns() {
	if len(p.excessConns) == 0 {
		if gocqlDebug {
			Logger.Printf("scylla: %s no excess connections to close", p.address)
		}
		return
	}

	conns := p.excessConns
	p.excessConns = nil

	if gocqlDebug {
		Logger.Printf("scylla: %s closing %d excess connections", p.address, len(conns))
	}
	go closeConns(conns...)
}
|
||||
|
||||
// Closing must be done outside of hostConnPool lock. If holding a lock
|
||||
// a deadlock can occur when closing one of the connections returns error on close.
|
||||
// See scylladb/gocql#53.
|
||||
func closeConns(conns ...*Conn) {
|
||||
for _, conn := range conns {
|
||||
if conn != nil {
|
||||
conn.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// NextShard returns the shardID to connect to.
|
||||
// nrShard specifies how many shards the host has.
|
||||
// If nrShards is zero, the caller shouldn't use shard-aware port.
|
||||
func (p *scyllaConnPicker) NextShard() (shardID, nrShards int) {
|
||||
if p.shardAwarePortDisabled {
|
||||
return 0, 0
|
||||
}
|
||||
|
||||
disableUntil, _ := p.disableShardAwarePortUntil.Load().(time.Time)
|
||||
if time.Now().Before(disableUntil) {
|
||||
// There is suspicion that the shard-aware-port is not reachable
|
||||
// or misconfigured, fall back to the non-shard-aware port
|
||||
return 0, 0
|
||||
}
|
||||
|
||||
// Find the shard without a connection
|
||||
// It's important to start counting from 1 here because we want
|
||||
// to consider the next shard after the previously attempted one
|
||||
for i := 1; i <= p.nrShards; i++ {
|
||||
shardID := (p.lastAttemptedShard + i) % p.nrShards
|
||||
if p.conns == nil || p.conns[shardID] == nil {
|
||||
p.lastAttemptedShard = shardID
|
||||
return shardID, p.nrShards
|
||||
}
|
||||
}
|
||||
|
||||
// We did not find an unallocated shard
|
||||
// We will dial the non-shard-aware port
|
||||
return 0, 0
|
||||
}
|
||||
|
||||
// ShardDialer is like HostDialer but is shard-aware.
// If the driver wants to connect to a specific shard, it will call DialShard,
// otherwise it will call DialHost.
type ShardDialer interface {
	HostDialer

	// DialShard establishes a connection to the specified shard ID out of nrShards.
	// The returned connection must be directly usable for CQL protocol,
	// specifically DialShard is responsible also for setting up the TLS session if needed.
	// If nrShards is zero the implementation cannot target a shard (see
	// scyllaConnPicker.NextShard, which returns 0, 0 in that case).
	DialShard(ctx context.Context, host *HostInfo, shardID, nrShards int) (*DialedHost, error)
}
|
||||
|
||||
// A dialer which dials a particular shard. It implements ShardDialer:
// DialShard targets a shard via the shard-aware port, DialHost uses the
// regular native transport port.
type scyllaDialer struct {
	dialer    Dialer
	logger    StdLogger
	tlsConfig *tls.Config
	cfg       *ClusterConfig
}

// scyllaShardAwarePortFallbackDuration is the window during which the
// shard-aware port is avoided; presumably paired with the
// disableShardAwarePortUntil timestamp consulted by NextShard — the code
// that arms it is not visible here, so confirm before relying on it.
const scyllaShardAwarePortFallbackDuration time.Duration = 5 * time.Minute
|
||||
|
||||
func (sd *scyllaDialer) DialHost(ctx context.Context, host *HostInfo) (*DialedHost, error) {
|
||||
ip := host.ConnectAddress()
|
||||
port := host.Port()
|
||||
|
||||
if !validIpAddr(ip) {
|
||||
return nil, fmt.Errorf("host missing connect ip address: %v", ip)
|
||||
} else if port == 0 {
|
||||
return nil, fmt.Errorf("host missing port: %v", port)
|
||||
}
|
||||
|
||||
addr := host.HostnameAndPort()
|
||||
conn, err := sd.dialer.DialContext(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return WrapTLS(ctx, conn, addr, sd.tlsConfig)
|
||||
}
|
||||
|
||||
// DialShard connects to host, attempting to land on the given shardID (out
// of nrShards) by dialing Scylla's shard-aware port from shard-specific
// source ports; dialShardAware handles fallback to the regular port.
func (sd *scyllaDialer) DialShard(ctx context.Context, host *HostInfo, shardID, nrShards int) (*DialedHost, error) {
	ip := host.ConnectAddress()
	port := host.Port()

	if !validIpAddr(ip) {
		return nil, fmt.Errorf("host missing connect ip address: %v", ip)
	} else if port == 0 {
		return nil, fmt.Errorf("host missing port: %v", port)
	}

	// Iterator over local source ports that all map to shardID.
	iter := newScyllaPortIterator(shardID, nrShards)

	addr := host.HostnameAndPort()

	// Pick the TLS or plaintext flavor of the shard-aware port; 0 means the
	// host does not advertise one.
	var shardAwarePort uint16
	if sd.tlsConfig != nil {
		shardAwarePort = host.ScyllaShardAwarePortTLS()
	} else {
		shardAwarePort = host.ScyllaShardAwarePort()
	}

	// NOTE(review): when shardAwarePort is 0 shardAwareAddress stays empty,
	// so the shard-aware dials inside dialShardAware will fail and it falls
	// back to addr — confirm this is the intended path for such hosts.
	var shardAwareAddress string
	if shardAwarePort != 0 {
		// Run the shard-aware port through the cluster's address translator,
		// same as the regular address.
		tIP, tPort := sd.cfg.translateAddressPort(host.UntranslatedConnectAddress(), int(shardAwarePort))
		shardAwareAddress = net.JoinHostPort(tIP.String(), strconv.Itoa(tPort))
	}

	if gocqlDebug {
		sd.logger.Printf("scylla: connecting to shard %d", shardID)
	}

	conn, err := sd.dialShardAware(ctx, addr, shardAwareAddress, iter)
	if err != nil {
		return nil, err
	}

	return WrapTLS(ctx, conn, addr, sd.tlsConfig)
}
|
||||
|
||||
func (sd *scyllaDialer) dialShardAware(ctx context.Context, addr, shardAwareAddr string, iter *scyllaPortIterator) (net.Conn, error) {
|
||||
for {
|
||||
port, ok := iter.Next()
|
||||
if !ok {
|
||||
// We exhausted ports to connect from. Try the non-shard-aware port.
|
||||
return sd.dialer.DialContext(ctx, "tcp", addr)
|
||||
}
|
||||
|
||||
ctxWithPort := context.WithValue(ctx, scyllaSourcePortCtx{}, port)
|
||||
conn, err := sd.dialer.DialContext(ctxWithPort, "tcp", shardAwareAddr)
|
||||
|
||||
if isLocalAddrInUseErr(err) {
|
||||
// This indicates that the source port is already in use
|
||||
// We can immediately retry with another source port for this shard
|
||||
continue
|
||||
} else if err != nil {
|
||||
conn, err := sd.dialer.DialContext(ctx, "tcp", addr)
|
||||
if err == nil {
|
||||
// We failed to connect to the shard-aware port, but succeeded
|
||||
// in connecting to the non-shard-aware port. This might
|
||||
// indicate that the shard-aware port is just not reachable,
|
||||
// but we may also be unlucky and the node became reachable
|
||||
// just after we tried the first connection.
|
||||
// We can't avoid false positives here, so I'm putting it
|
||||
// behind a debug flag.
|
||||
if gocqlDebug {
|
||||
sd.logger.Printf(
|
||||
"scylla: %s couldn't connect to shard-aware address while the non-shard-aware address %s is available; this might be an issue with ",
|
||||
addr,
|
||||
shardAwareAddr,
|
||||
)
|
||||
}
|
||||
}
|
||||
return conn, err
|
||||
}
|
||||
return conn, err
|
||||
}
|
||||
}
|
||||
|
||||
// ErrScyllaSourcePortAlreadyInUse is an error value which can be returned from
// a custom dialer implementation to indicate that the requested source port
// to dial from is already in use.
var ErrScyllaSourcePortAlreadyInUse = errors.New("scylla: source port is already in use")

// isLocalAddrInUseErr reports whether err means the chosen local source port
// was already taken: either the OS-level EADDRINUSE or the sentinel above.
func isLocalAddrInUseErr(err error) bool {
	return errors.Is(err, syscall.EADDRINUSE) || errors.Is(err, ErrScyllaSourcePortAlreadyInUse)
}
|
||||
|
||||
// ScyllaShardAwareDialer wraps a net.Dialer, but uses a source port specified by gocql when connecting.
//
// Unlike in the case of standard native transport ports, gocql can choose which shard will handle
// a new connection by connecting from a specific source port. If you are using your own net.Dialer
// in ClusterConfig, you can use ScyllaShardAwareDialer to "upgrade" it so that it connects
// from the source port chosen by gocql.
//
// Please note that ScyllaShardAwareDialer overwrites the LocalAddr field in order to choose
// the right source port for connection.
type ScyllaShardAwareDialer struct {
	net.Dialer
}
|
||||
|
||||
func (d *ScyllaShardAwareDialer) DialContext(ctx context.Context, network, addr string) (conn net.Conn, err error) {
|
||||
sourcePort := ScyllaGetSourcePort(ctx)
|
||||
if sourcePort == 0 {
|
||||
return d.Dialer.DialContext(ctx, network, addr)
|
||||
}
|
||||
dialerWithLocalAddr := d.Dialer
|
||||
dialerWithLocalAddr.LocalAddr, err = net.ResolveTCPAddr(network, fmt.Sprintf(":%d", sourcePort))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return dialerWithLocalAddr.DialContext(ctx, network, addr)
|
||||
}
|
||||
|
||||
// scyllaPortIterator enumerates candidate local source ports for one shard:
// it starts at the first in-range port whose remainder modulo shardCount is
// the shard ID and steps by shardCount, so every yielded port maps to the
// same shard.
type scyllaPortIterator struct {
	currentPort int
	shardCount  int
}

// Source ports for shard-aware dialing are drawn from [0x8000, 0xFFFF] —
// presumably to stay clear of well-known and commonly reserved low ports;
// confirm against the upstream driver documentation.
const (
	scyllaPortBasedBalancingMin = 0x8000
	scyllaPortBasedBalancingMax = 0xFFFF
)
|
||||
|
||||
func newScyllaPortIterator(shardID, shardCount int) *scyllaPortIterator {
|
||||
if shardCount == 0 {
|
||||
panic("shardCount cannot be 0")
|
||||
}
|
||||
|
||||
// Find the smallest port p such that p >= min and p % shardCount == shardID
|
||||
port := scyllaPortBasedBalancingMin - scyllaShardForSourcePort(scyllaPortBasedBalancingMin, shardCount) + shardID
|
||||
if port < scyllaPortBasedBalancingMin {
|
||||
port += shardCount
|
||||
}
|
||||
|
||||
return &scyllaPortIterator{
|
||||
currentPort: port,
|
||||
shardCount: shardCount,
|
||||
}
|
||||
}
|
||||
|
||||
func (spi *scyllaPortIterator) Next() (uint16, bool) {
|
||||
if spi == nil {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
p := spi.currentPort
|
||||
|
||||
if p > scyllaPortBasedBalancingMax {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
spi.currentPort += spi.shardCount
|
||||
return uint16(p), true
|
||||
}
|
||||
|
||||
// scyllaShardForSourcePort returns the shard that a connection dialed from
// sourcePort is routed to on a host with shardCount shards.
func scyllaShardForSourcePort(sourcePort uint16, shardCount int) int {
	port := int(sourcePort)
	return port % shardCount
}
|
||||
|
||||
type scyllaSourcePortCtx struct{}
|
||||
|
||||
// ScyllaGetSourcePort returns the source port that should be used when connecting to a node.
|
||||
//
|
||||
// Unlike in the case standard native transport ports, gocql can choose which shard will handle
|
||||
// a new connection at the shard-aware port by connecting from a specific source port. Therefore,
|
||||
// if you are using a custom Dialer and your nodes expose shard-aware ports, your dialer should
|
||||
// use the source port specified by gocql.
|
||||
//
|
||||
// If this function returns 0, then your dialer can use any source port.
|
||||
//
|
||||
// If you aren't using a custom dialer, gocql will use a default one which uses appropriate source port.
|
||||
// If you are using net.Dialer, consider wrapping it in a gocql.ScyllaShardAwareDialer.
|
||||
func ScyllaGetSourcePort(ctx context.Context) uint16 {
|
||||
sourcePort, _ := ctx.Value(scyllaSourcePortCtx{}).(uint16)
|
||||
return sourcePort
|
||||
}
|
||||
|
||||
// Returns a partitioner specific to the table, or "nil"
|
||||
// if the cluster-global partitioner should be used
|
||||
func scyllaGetTablePartitioner(session *Session, keyspaceName, tableName string) (Partitioner, error) {
|
||||
isCdc, err := scyllaIsCdcTable(session, keyspaceName, tableName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if isCdc {
|
||||
return scyllaCDCPartitioner{}, nil
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
93
vendor/github.com/gocql/gocql/scylla_cdc.go
generated
vendored
Normal file
93
vendor/github.com/gocql/gocql/scylla_cdc.go
generated
vendored
Normal file
@@ -0,0 +1,93 @@
|
||||
package gocql
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"math"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// cdc partitioner
|
||||
|
||||
const (
	scyllaCDCPartitionerName     = "CDCPartitioner"
	scyllaCDCPartitionerFullName = "com.scylladb.dht.CDCPartitioner"

	// Layout constants for a CDC partition key: 16 bytes total, with the low
	// nibble of the second quadword carrying the key format version (see Hash).
	scyllaCDCPartitionKeyLength  = 16
	scyllaCDCVersionMask         = 0x0F
	scyllaCDCMinSupportedVersion = 1
	scyllaCDCMaxSupportedVersion = 1

	// Returned by Hash when the key is too short to carry a token.
	scyllaCDCMinToken           = int64Token(math.MinInt64)
	scyllaCDCLogTableNameSuffix = "_scylla_cdc_log"
	scyllaCDCExtensionName      = "cdc"
)
|
||||
|
||||
// scyllaCDCPartitioner computes tokens for Scylla CDC log tables, whose
// partition keys embed the token directly (see Hash).
type scyllaCDCPartitioner struct{}

// Compile-time check that scyllaCDCPartitioner satisfies Partitioner.
var _ Partitioner = scyllaCDCPartitioner{}

// Name returns the short name of the CDC partitioner.
func (p scyllaCDCPartitioner) Name() string {
	return scyllaCDCPartitionerName
}
|
||||
|
||||
// Hash extracts the token from a CDC log table partition key: the first
// quadword of the key is the token itself, so no hashing is performed.
// Keys shorter than 8 bytes cannot carry a token and map to the minimum
// token instead.
func (p scyllaCDCPartitioner) Hash(partitionKey []byte) Token {
	if len(partitionKey) < 8 {
		// The key is too short to extract any sensible token,
		// so return the min token instead
		if gocqlDebug {
			Logger.Printf("scylla: cdc partition key too short: %d < 8", len(partitionKey))
		}
		return scyllaCDCMinToken
	}

	upperQword := binary.BigEndian.Uint64(partitionKey[0:])

	if gocqlDebug {
		// In debug mode, do some more checks

		if len(partitionKey) != scyllaCDCPartitionKeyLength {
			// The token has unrecognized format, but the first quadword
			// should be the token value that we want
			Logger.Printf("scylla: wrong size of cdc partition key: %d", len(partitionKey))
		}

		// The low nibble of the second quadword carries the key format
		// version; versions outside the supported range are only warned
		// about, the token is still returned.
		lowerQword := binary.BigEndian.Uint64(partitionKey[8:])
		version := lowerQword & scyllaCDCVersionMask
		if version < scyllaCDCMinSupportedVersion || version > scyllaCDCMaxSupportedVersion {
			// We don't support this version yet,
			// the token may be wrong
			Logger.Printf(
				"scylla: unsupported version: %d is not in range [%d, %d]",
				version,
				scyllaCDCMinSupportedVersion,
				scyllaCDCMaxSupportedVersion,
			)
		}
	}

	// uint64 -> int64Token reinterprets the bits (two's complement
	// wraparound); presumably intentional, matching the server's token
	// representation — confirm against the CDC partitioner spec.
	return int64Token(upperQword)
}
|
||||
|
||||
// ParseString parses a textual token. CDC tokens are plain int64 tokens, so
// the common int64 token parser is reused.
func (p scyllaCDCPartitioner) ParseString(str string) Token {
	return parseInt64Token(str)
}
|
||||
|
||||
func scyllaIsCdcTable(session *Session, keyspaceName, tableName string) (bool, error) {
|
||||
if !strings.HasSuffix(tableName, scyllaCDCLogTableNameSuffix) {
|
||||
// Not a CDC table, use the default partitioner
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Get the table metadata to see if it has the cdc partitioner set
|
||||
keyspaceMeta, err := session.KeyspaceMetadata(keyspaceName)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
tableMeta, ok := keyspaceMeta.Tables[tableName]
|
||||
if !ok {
|
||||
return false, ErrNoMetadata
|
||||
}
|
||||
|
||||
return tableMeta.Options.Partitioner == scyllaCDCPartitionerFullName, nil
|
||||
}
|
28
vendor/github.com/gocql/gocql/serialization/ascii/marshal.go
generated
vendored
Normal file
28
vendor/github.com/gocql/gocql/serialization/ascii/marshal.go
generated
vendored
Normal file
@@ -0,0 +1,28 @@
|
||||
package ascii
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
)
|
||||
|
||||
func Marshal(value interface{}) ([]byte, error) {
|
||||
switch v := value.(type) {
|
||||
case nil:
|
||||
return nil, nil
|
||||
case string:
|
||||
return EncString(v)
|
||||
case *string:
|
||||
return EncStringR(v)
|
||||
case []byte:
|
||||
return EncBytes(v)
|
||||
case *[]byte:
|
||||
return EncBytesR(v)
|
||||
default:
|
||||
// Custom types (type MyString string) can be serialized only via `reflect` package.
|
||||
// Later, when generic-based serialization is introduced we can do that via generics.
|
||||
rv := reflect.ValueOf(value)
|
||||
if rv.Kind() != reflect.Ptr {
|
||||
return EncReflect(rv)
|
||||
}
|
||||
return EncReflectR(rv)
|
||||
}
|
||||
}
|
61
vendor/github.com/gocql/gocql/serialization/ascii/marshal_utils.go
generated
vendored
Normal file
61
vendor/github.com/gocql/gocql/serialization/ascii/marshal_utils.go
generated
vendored
Normal file
@@ -0,0 +1,61 @@
|
||||
package ascii
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// EncString encodes v as ascii bytes; the empty string becomes an empty
// (non-nil) slice.
func EncString(v string) ([]byte, error) {
	if len(v) == 0 {
		return []byte{}, nil
	}
	return []byte(v), nil
}

// EncStringR encodes the string v points to; a nil pointer encodes as nil.
func EncStringR(v *string) ([]byte, error) {
	if v == nil {
		return nil, nil
	}
	return EncString(*v)
}
|
||||
|
||||
// EncBytes passes v through unchanged (ascii values are raw bytes on the wire).
func EncBytes(v []byte) ([]byte, error) {
	return v, nil
}

// EncBytesR encodes the slice v points to; a nil pointer encodes as nil.
func EncBytesR(v *[]byte) ([]byte, error) {
	if v == nil {
		return nil, nil
	}
	return *v, nil
}
|
||||
|
||||
// EncReflect encodes a reflect.Value whose kind is string or []byte; any
// other kind is rejected, except gocql's "unset column" sentinel struct,
// which encodes as nil.
func EncReflect(v reflect.Value) ([]byte, error) {
	switch v.Kind() {
	case reflect.String:
		s := v.String()
		if len(s) == 0 {
			return make([]byte, 0), nil
		}
		return []byte(s), nil
	case reflect.Slice:
		if v.Type().Elem().Kind() != reflect.Uint8 {
			return nil, fmt.Errorf("failed to marshal ascii: unsupported value type (%T)(%[1]v)", v.Interface())
		}
		return v.Bytes(), nil
	case reflect.Struct:
		if v.Type().String() == "gocql.unsetColumn" {
			return nil, nil
		}
		return nil, fmt.Errorf("failed to marshal ascii: unsupported value type (%T)(%[1]v)", v.Interface())
	default:
		return nil, fmt.Errorf("failed to marshal ascii: unsupported value type (%T)(%[1]v)", v.Interface())
	}
}

// EncReflectR dereferences a pointer value and encodes its target; a nil
// pointer encodes as nil.
func EncReflectR(v reflect.Value) ([]byte, error) {
	if v.IsNil() {
		return nil, nil
	}
	return EncReflect(v.Elem())
}
|
||||
|
||||
// encString converts v to its byte form; the empty string maps to an empty,
// non-nil slice (the surrounding codecs keep nil and empty distinct).
func encString(v string) []byte {
	if len(v) == 0 {
		return []byte{}
	}
	return []byte(v)
}
|
33
vendor/github.com/gocql/gocql/serialization/ascii/unmarshal.go
generated
vendored
Normal file
33
vendor/github.com/gocql/gocql/serialization/ascii/unmarshal.go
generated
vendored
Normal file
@@ -0,0 +1,33 @@
|
||||
package ascii
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
func Unmarshal(data []byte, value interface{}) error {
|
||||
switch v := value.(type) {
|
||||
case nil:
|
||||
return nil
|
||||
case *string:
|
||||
return DecString(data, v)
|
||||
case **string:
|
||||
return DecStringR(data, v)
|
||||
case *[]byte:
|
||||
return DecBytes(data, v)
|
||||
case **[]byte:
|
||||
return DecBytesR(data, v)
|
||||
default:
|
||||
// Custom types (type MyString string) can be deserialized only via `reflect` package.
|
||||
// Later, when generic-based serialization is introduced we can do that via generics.
|
||||
rv := reflect.ValueOf(value)
|
||||
rt := rv.Type()
|
||||
if rt.Kind() != reflect.Ptr {
|
||||
return fmt.Errorf("failed to unmarshal ascii: unsupported value type (%T)(%[1]v)", v)
|
||||
}
|
||||
if rt.Elem().Kind() != reflect.Ptr {
|
||||
return DecReflect(data, rv)
|
||||
}
|
||||
return DecReflectR(data, rv)
|
||||
}
|
||||
}
|
166
vendor/github.com/gocql/gocql/serialization/ascii/unmarshal_utils.go
generated
vendored
Normal file
166
vendor/github.com/gocql/gocql/serialization/ascii/unmarshal_utils.go
generated
vendored
Normal file
@@ -0,0 +1,166 @@
|
||||
package ascii
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
func errInvalidData(p []byte) error {
|
||||
for i := range p {
|
||||
if p[i] > 127 {
|
||||
return fmt.Errorf("failed to unmarshal ascii: invalid charester %s", string(p[i]))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// errNilReference builds the error returned when a decoder is handed a nil
// destination pointer.
func errNilReference(v interface{}) error {
	return fmt.Errorf("failed to unmarshal ascii: can not unmarshal into nil reference(%T)(%[1]v)", v)
}
|
||||
|
||||
// DecString decodes p into *v and then validates that p is pure ASCII.
// Note the target is assigned before validation, so on error it still holds
// the decoded (non-ASCII) contents.
func DecString(p []byte, v *string) error {
	if v == nil {
		return errNilReference(v)
	}
	*v = decString(p)
	return errInvalidData(p)
}

// DecStringR decodes p into **v; nil data yields a nil *string (null),
// empty data a pointer to "".
func DecStringR(p []byte, v **string) error {
	if v == nil {
		return errNilReference(v)
	}
	*v = decStringR(p)
	return errInvalidData(p)
}

// DecBytes copies p into *v; nil data yields a nil slice.
func DecBytes(p []byte, v *[]byte) error {
	if v == nil {
		return errNilReference(v)
	}
	*v = decBytes(p)
	return errInvalidData(p)
}

// DecBytesR copies p into **v; nil data yields a nil *[]byte.
func DecBytesR(p []byte, v **[]byte) error {
	if v == nil {
		return errNilReference(v)
	}
	*v = decBytesR(p)
	return errInvalidData(p)
}
|
||||
|
||||
func DecReflect(p []byte, v reflect.Value) error {
|
||||
if v.IsNil() {
|
||||
return errNilReference(v)
|
||||
}
|
||||
|
||||
switch v = v.Elem(); v.Kind() {
|
||||
case reflect.String:
|
||||
v.SetString(decString(p))
|
||||
case reflect.Slice:
|
||||
if v.Type().Elem().Kind() != reflect.Uint8 {
|
||||
return fmt.Errorf("failed to marshal ascii: unsupported value type (%T)(%[1]v)", v.Interface())
|
||||
}
|
||||
v.SetBytes(decBytes(p))
|
||||
default:
|
||||
return fmt.Errorf("failed to unmarshal ascii: unsupported value type (%T)(%[1]v)", v.Interface())
|
||||
}
|
||||
return errInvalidData(p)
|
||||
}
|
||||
|
||||
func DecReflectR(p []byte, v reflect.Value) error {
|
||||
if v.IsNil() {
|
||||
return errNilReference(v)
|
||||
}
|
||||
|
||||
switch ev := v.Type().Elem().Elem(); ev.Kind() {
|
||||
case reflect.String:
|
||||
return decReflectStringR(p, v)
|
||||
case reflect.Slice:
|
||||
if ev.Elem().Kind() != reflect.Uint8 {
|
||||
return fmt.Errorf("failed to marshal ascii: unsupported value type (%T)(%[1]v)", v.Interface())
|
||||
}
|
||||
return decReflectBytesR(p, v)
|
||||
default:
|
||||
return fmt.Errorf("failed to unmarshal ascii: unsupported value type (%T)(%[1]v)", v.Interface())
|
||||
}
|
||||
}
|
||||
|
||||
// decReflectStringR stores p into a **string-shaped reflect target.
// nil data sets the inner pointer to nil (null); empty-but-non-nil data
// allocates a pointer to "". Non-empty data is stored first, then validated
// as ASCII.
func decReflectStringR(p []byte, v reflect.Value) error {
	if len(p) == 0 {
		if p == nil {
			v.Elem().Set(reflect.Zero(v.Elem().Type()))
		} else {
			v.Elem().Set(reflect.New(v.Type().Elem().Elem()))
		}
		return nil
	}
	val := reflect.New(v.Type().Elem().Elem())
	val.Elem().SetString(string(p))
	v.Elem().Set(val)
	return errInvalidData(p)
}

// decReflectBytesR stores an independent copy of p into a **[]byte-shaped
// reflect target; the same nil-vs-empty distinction as decReflectStringR
// applies.
func decReflectBytesR(p []byte, v reflect.Value) error {
	if len(p) == 0 {
		if p == nil {
			v.Elem().Set(reflect.Zero(v.Elem().Type()))
		} else {
			val := reflect.New(v.Type().Elem().Elem())
			val.Elem().SetBytes(make([]byte, 0))
			v.Elem().Set(val)
		}
		return nil
	}
	// Copy so the caller's value does not alias the wire buffer.
	tmp := make([]byte, len(p))
	copy(tmp, p)

	val := reflect.New(v.Type().Elem().Elem())
	val.Elem().SetBytes(tmp)
	v.Elem().Set(val)
	return errInvalidData(p)
}
|
||||
|
||||
// decString converts p to a string; both nil and empty input yield "".
func decString(p []byte) string {
	// string(nil) is already "", so no length check is needed.
	return string(p)
}

// decStringR converts p to *string, preserving the null-vs-empty
// distinction: nil data -> nil pointer, empty data -> pointer to "".
func decStringR(p []byte) *string {
	if p == nil {
		return nil
	}
	s := string(p)
	return &s
}

// decBytes returns an independent copy of p; nil stays nil and empty stays
// empty but non-nil.
func decBytes(p []byte) []byte {
	if p == nil {
		return nil
	}
	out := make([]byte, len(p))
	copy(out, p)
	return out
}

// decBytesR is decBytes with a pointer result so null can be represented as
// a nil *[]byte.
func decBytesR(p []byte) *[]byte {
	if p == nil {
		return nil
	}
	out := make([]byte, len(p))
	copy(out, p)
	return &out
}
|
74
vendor/github.com/gocql/gocql/serialization/bigint/marshal.go
generated
vendored
Normal file
74
vendor/github.com/gocql/gocql/serialization/bigint/marshal.go
generated
vendored
Normal file
@@ -0,0 +1,74 @@
|
||||
package bigint
|
||||
|
||||
import (
|
||||
"math/big"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// Marshal serializes value into the CQL bigint wire format (8 big-endian
// bytes). Supported inputs are nil, all Go signed/unsigned integer types and
// pointers to them, big.Int and *big.Int, string and *string (parsed as
// base-10 int64), and — via reflection — custom types with those underlying
// kinds.
func Marshal(value interface{}) ([]byte, error) {
	switch v := value.(type) {
	case nil:
		return nil, nil
	case int8:
		return EncInt8(v)
	case int16:
		return EncInt16(v)
	case int32:
		return EncInt32(v)
	case int64:
		return EncInt64(v)
	case int:
		return EncInt(v)

	case uint8:
		return EncUint8(v)
	case uint16:
		return EncUint16(v)
	case uint32:
		return EncUint32(v)
	case uint64:
		return EncUint64(v)
	case uint:
		return EncUint(v)

	case big.Int:
		return EncBigInt(v)
	case string:
		return EncString(v)

	case *int8:
		return EncInt8R(v)
	case *int16:
		return EncInt16R(v)
	case *int32:
		return EncInt32R(v)
	case *int64:
		return EncInt64R(v)
	case *int:
		return EncIntR(v)

	case *uint8:
		return EncUint8R(v)
	case *uint16:
		return EncUint16R(v)
	case *uint32:
		return EncUint32R(v)
	case *uint64:
		return EncUint64R(v)
	case *uint:
		return EncUintR(v)

	case *big.Int:
		return EncBigIntR(v)
	case *string:
		return EncStringR(v)
	default:
		// Custom types (type MyInt int) can be serialized only via `reflect` package.
		// Later, when generic-based serialization is introduced we can do that via generics.
		rv := reflect.TypeOf(value)
		if rv.Kind() != reflect.Ptr {
			return EncReflect(reflect.ValueOf(v))
		}
		return EncReflectR(reflect.ValueOf(v))
	}
}
206
vendor/github.com/gocql/gocql/serialization/bigint/marshal_utils.go
generated
vendored
Normal file
206
vendor/github.com/gocql/gocql/serialization/bigint/marshal_utils.go
generated
vendored
Normal file
@@ -0,0 +1,206 @@
|
||||
package bigint
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"math/big"
|
||||
"reflect"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// Bounds of the int64 range, used by EncBigInt/EncBigIntR to reject big.Int
// values that cannot be represented in a CQL bigint.
var (
	maxBigInt = big.NewInt(math.MaxInt64)
	minBigInt = big.NewInt(math.MinInt64)
)
||||
|
||||
// EncInt8 sign-extends v to 8 big-endian bytes.
func EncInt8(v int8) ([]byte, error) {
	pad := byte(0)
	if v < 0 {
		pad = 0xFF
	}
	return []byte{pad, pad, pad, pad, pad, pad, pad, byte(v)}, nil
}

// EncInt8R encodes the value v points to; a nil pointer encodes as nil.
func EncInt8R(v *int8) ([]byte, error) {
	if v == nil {
		return nil, nil
	}
	return EncInt8(*v)
}

// EncInt16 sign-extends v to 8 big-endian bytes.
func EncInt16(v int16) ([]byte, error) {
	pad := byte(0)
	if v < 0 {
		pad = 0xFF
	}
	return []byte{pad, pad, pad, pad, pad, pad, byte(v >> 8), byte(v)}, nil
}

// EncInt16R encodes the value v points to; a nil pointer encodes as nil.
func EncInt16R(v *int16) ([]byte, error) {
	if v == nil {
		return nil, nil
	}
	return EncInt16(*v)
}

// EncInt32 sign-extends v to 8 big-endian bytes.
func EncInt32(v int32) ([]byte, error) {
	pad := byte(0)
	if v < 0 {
		pad = 0xFF
	}
	return []byte{pad, pad, pad, pad, byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}, nil
}

// EncInt32R encodes the value v points to; a nil pointer encodes as nil.
func EncInt32R(v *int32) ([]byte, error) {
	if v == nil {
		return nil, nil
	}
	return EncInt32(*v)
}
||||
|
||||
// EncInt64 encodes v as 8 big-endian bytes (two's complement).
func EncInt64(v int64) ([]byte, error) {
	out := make([]byte, 8)
	for i := 7; i >= 0; i-- {
		out[i] = byte(v)
		v >>= 8
	}
	return out, nil
}

// EncInt64R encodes the value v points to; a nil pointer encodes as nil.
func EncInt64R(v *int64) ([]byte, error) {
	if v == nil {
		return nil, nil
	}
	return EncInt64(*v)
}

// EncInt encodes v as 8 big-endian bytes; on 32-bit platforms the value is
// sign-extended.
func EncInt(v int) ([]byte, error) {
	return EncInt64(int64(v))
}

// EncIntR encodes the value v points to; a nil pointer encodes as nil.
func EncIntR(v *int) ([]byte, error) {
	if v == nil {
		return nil, nil
	}
	return EncInt(*v)
}
|
||||
|
||||
// EncUint8 zero-extends v to 8 big-endian bytes.
func EncUint8(v uint8) ([]byte, error) {
	return EncUint64(uint64(v))
}

// EncUint8R encodes the value v points to; a nil pointer encodes as nil.
func EncUint8R(v *uint8) ([]byte, error) {
	if v == nil {
		return nil, nil
	}
	return EncUint8(*v)
}

// EncUint16 zero-extends v to 8 big-endian bytes.
func EncUint16(v uint16) ([]byte, error) {
	return EncUint64(uint64(v))
}

// EncUint16R encodes the value v points to; a nil pointer encodes as nil.
func EncUint16R(v *uint16) ([]byte, error) {
	if v == nil {
		return nil, nil
	}
	return EncUint16(*v)
}

// EncUint32 zero-extends v to 8 big-endian bytes.
func EncUint32(v uint32) ([]byte, error) {
	return EncUint64(uint64(v))
}

// EncUint32R encodes the value v points to; a nil pointer encodes as nil.
func EncUint32R(v *uint32) ([]byte, error) {
	if v == nil {
		return nil, nil
	}
	return EncUint32(*v)
}

// EncUint64 encodes v as 8 big-endian bytes.
func EncUint64(v uint64) ([]byte, error) {
	out := make([]byte, 8)
	for i := 7; i >= 0; i-- {
		out[i] = byte(v)
		v >>= 8
	}
	return out, nil
}

// EncUint64R encodes the value v points to; a nil pointer encodes as nil.
func EncUint64R(v *uint64) ([]byte, error) {
	if v == nil {
		return nil, nil
	}
	return EncUint64(*v)
}

// EncUint encodes v as 8 big-endian bytes.
func EncUint(v uint) ([]byte, error) {
	return EncUint64(uint64(v))
}

// EncUintR encodes the value v points to; a nil pointer encodes as nil.
func EncUintR(v *uint) ([]byte, error) {
	if v == nil {
		return nil, nil
	}
	return EncUint(*v)
}
||||
|
||||
func EncBigInt(v big.Int) ([]byte, error) {
|
||||
if v.Cmp(maxBigInt) == 1 || v.Cmp(minBigInt) == -1 {
|
||||
return nil, fmt.Errorf("failed to marshal bigint: value (%T)(%s) out of range", v, v.String())
|
||||
}
|
||||
return encInt64(v.Int64()), nil
|
||||
}
|
||||
|
||||
func EncBigIntR(v *big.Int) ([]byte, error) {
|
||||
if v == nil {
|
||||
return nil, nil
|
||||
}
|
||||
if v.Cmp(maxBigInt) == 1 || v.Cmp(minBigInt) == -1 {
|
||||
return nil, fmt.Errorf("failed to marshal bigint: value (%T)(%s) out of range", v, v.String())
|
||||
}
|
||||
return encInt64(v.Int64()), nil
|
||||
}
|
||||
|
||||
func EncString(v string) ([]byte, error) {
|
||||
if v == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
n, err := strconv.ParseInt(v, 10, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal bigint: can not marshal (%T)(%[1]v) %s", v, err)
|
||||
}
|
||||
return encInt64(n), nil
|
||||
}
|
||||
|
||||
func EncStringR(v *string) ([]byte, error) {
|
||||
if v == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return EncString(*v)
|
||||
}
|
||||
|
||||
// EncReflect encodes a reflect.Value whose kind is an integer, unsigned
// integer, or string (parsed as base-10 int64). The gocql "unset column"
// sentinel struct encodes as nil; everything else is rejected.
func EncReflect(v reflect.Value) ([]byte, error) {
	switch v.Kind() {
	case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8:
		return EncInt64(v.Int())
	case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8:
		return EncUint64(v.Uint())
	case reflect.String:
		// Mirrors EncString: empty encodes as nil, otherwise parse base-10.
		val := v.String()
		if val == "" {
			return nil, nil
		}
		n, err := strconv.ParseInt(val, 10, 64)
		if err != nil {
			return nil, fmt.Errorf("failed to marshal bigint: can not marshal (%T)(%[1]v) %s", v.Interface(), err)
		}
		return encInt64(n), nil
	case reflect.Struct:
		if v.Type().String() == "gocql.unsetColumn" {
			return nil, nil
		}
		return nil, fmt.Errorf("failed to marshal bigint: unsupported value type (%T)(%[1]v)", v.Interface())
	default:
		return nil, fmt.Errorf("failed to marshal bigint: unsupported value type (%T)(%[1]v)", v.Interface())
	}
}

// EncReflectR dereferences a pointer value and encodes its target; a nil
// pointer encodes as nil.
func EncReflectR(v reflect.Value) ([]byte, error) {
	if v.IsNil() {
		return nil, nil
	}
	return EncReflect(v.Elem())
}
|
||||
|
||||
func encInt64(v int64) []byte {
|
||||
return []byte{byte(v >> 56), byte(v >> 48), byte(v >> 40), byte(v >> 32), byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}
|
||||
}
|
81
vendor/github.com/gocql/gocql/serialization/bigint/unmarshal.go
generated
vendored
Normal file
81
vendor/github.com/gocql/gocql/serialization/bigint/unmarshal.go
generated
vendored
Normal file
@@ -0,0 +1,81 @@
|
||||
package bigint
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/big"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// Unmarshal decodes bigint wire data into value. Supported targets are
// pointers (and pointers-to-pointers, for null handling) to all Go
// signed/unsigned integer types, big.Int, and string, plus — via
// reflection — pointers to custom types with those underlying kinds.
// A nil value is a no-op.
func Unmarshal(data []byte, value interface{}) error {
	switch v := value.(type) {
	case nil:
		return nil

	case *int8:
		return DecInt8(data, v)
	case *int16:
		return DecInt16(data, v)
	case *int32:
		return DecInt32(data, v)
	case *int64:
		return DecInt64(data, v)
	case *int:
		return DecInt(data, v)

	case *uint8:
		return DecUint8(data, v)
	case *uint16:
		return DecUint16(data, v)
	case *uint32:
		return DecUint32(data, v)
	case *uint64:
		return DecUint64(data, v)
	case *uint:
		return DecUint(data, v)

	case *big.Int:
		return DecBigInt(data, v)
	case *string:
		return DecString(data, v)

	case **int8:
		return DecInt8R(data, v)
	case **int16:
		return DecInt16R(data, v)
	case **int32:
		return DecInt32R(data, v)
	case **int64:
		return DecInt64R(data, v)
	case **int:
		return DecIntR(data, v)

	case **uint8:
		return DecUint8R(data, v)
	case **uint16:
		return DecUint16R(data, v)
	case **uint32:
		return DecUint32R(data, v)
	case **uint64:
		return DecUint64R(data, v)
	case **uint:
		return DecUintR(data, v)

	case **big.Int:
		return DecBigIntR(data, v)
	case **string:
		return DecStringR(data, v)
	default:

		// Custom types (type MyInt int) can be deserialized only via `reflect` package.
		// Later, when generic-based serialization is introduced we can do that via generics.
		rv := reflect.ValueOf(value)
		rt := rv.Type()
		if rt.Kind() != reflect.Ptr {
			return fmt.Errorf("failed to unmarshal bigint: unsupported value type (%T)(%[1]v)", value)
		}
		if rt.Elem().Kind() != reflect.Ptr {
			return DecReflect(data, rv)
		}
		return DecReflectR(data, rv)
	}
}
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user