Compare commits

...

44 Commits

Author SHA1 Message Date
5ac529ce26 modify example client to be API compliant 2025-05-25 13:32:29 -04:00
fa5d1e2689 add example client 2025-05-25 13:03:16 -04:00
04c83cccb9 add provisional CORS support 2025-05-25 13:00:29 -04:00
8e4a336510 make jwt direct 2025-05-25 11:31:42 -04:00
cb28c07ff4 implement JWT tokens 2025-05-25 11:22:55 -04:00
d5db656ca2 resolve with indexing (#1) 2025-05-18 21:38:12 -04:00
369d445637 run goinputs 2025-05-18 20:20:40 -04:00
5709bfd21d format go files (gofmt) 2025-05-18 20:15:40 -04:00
606e85d467 cleanup go mod 2025-05-18 20:15:10 -04:00
72c0188071 expand logging 2025-05-18 18:25:17 -04:00
028c084cdd start logging -- also clarified env checks 2025-05-18 15:08:17 -04:00
f2b046056b put sessions in db 2025-05-18 13:44:30 -04:00
985ed9943a implement basic configuration 2025-05-18 11:47:46 -04:00
f8a550883d implement Scylla database 2025-05-17 21:45:18 -04:00
252b49ae6a chore: add debug bin to gitignore 2025-04-08 21:32:51 -04:00
4b3d64c5cd add logout method 2025-04-08 21:30:42 -04:00
799bf784aa allow LoginCtx to handle anonymous users 2025-04-08 21:00:33 -04:00
14c78536de use login context with newmessage 2025-04-07 22:18:54 -04:00
32bfd109b9 add user identification 2025-04-07 22:08:46 -04:00
a578beea0d change fake db package name to make way for db being split out 2025-04-07 16:38:32 -04:00
3ac7e488af fix username in login field -- should be 'username' not 'name' 2025-04-07 16:26:12 -04:00
253f3dcdac update routes docs 2025-04-07 16:18:11 -04:00
b44d59bd21 chore: add bashInteractive to buildInputs in dev shell 2025-04-06 23:34:06 -04:00
a3c1ae5615 chore: update flake, move to nixos-unstable 2025-04-06 21:48:58 -04:00
a601f1ceec split auth functions to auth.go 2025-04-06 21:45:01 -04:00
ccc0a58f88 implement authentication 2025-04-06 21:36:03 -04:00
824ca781d4 create password auth skeleton 2025-03-30 20:54:41 -04:00
cd4ebf9dc7 update vendor files 2025-03-30 12:57:39 -04:00
71164ee85a cleanup dotfiles that should never have been tracked 2025-03-27 22:05:05 -04:00
841f7aa0de add direnv and vscode dotfiles to gitignore 2025-03-27 22:04:27 -04:00
3a968df15b update routes doc 2025-03-27 21:21:34 -04:00
e8d8e8d70b list all users method 2025-03-27 21:20:46 -04:00
25ee1d3299 methods for creating and getting users 2025-03-27 21:07:33 -04:00
9d7ad260f2 change user type from int to uuid string 2025-03-27 20:33:48 -04:00
732fbacc61 add license 2025-03-27 19:11:38 -04:00
9478437262 new test data 2025-03-27 19:07:06 -04:00
d8878eba09 update api docs 2025-03-27 18:48:18 -04:00
c55052ad5b create db update message function 2025-03-27 14:50:44 -04:00
a7466e5c77 implement message updates 2025-03-27 14:41:02 -04:00
b86ee0dac4 refactor to have all DB executive functions be in api/ 2025-03-27 14:12:10 -04:00
ec90717ad7 implement editing messages 2025-03-27 14:06:40 -04:00
3f417b0088 use formdata instead of url encoded valus 2025-03-27 13:28:03 -04:00
9870b79854 implement message deletion 2025-03-25 13:26:19 -04:00
02643c1197 use form data for new messages 2025-03-24 21:33:56 -04:00
273 changed files with 78878 additions and 2311 deletions

View File

@@ -1,19 +0,0 @@
#!/usr/bin/env bash
set -e
if [[ ! -d "/home/williamp/chatservice_concept" ]]; then
echo "Cannot find source directory; Did you move it?"
echo "(Looking for "/home/williamp/chatservice_concept")"
echo 'Cannot force reload with this script - use "direnv reload" manually and then try again'
exit 1
fi
# rebuild the cache forcefully
_nix_direnv_force_reload=1 direnv exec "/home/williamp/chatservice_concept" true
# Update the mtime for .envrc.
# This will cause direnv to reload again - but without re-building.
touch "/home/williamp/chatservice_concept/.envrc"
# Also update the timestamp of whatever profile_rc we have.
# This makes sure that we know we are up to date.
touch -r "/home/williamp/chatservice_concept/.envrc" "/home/williamp/chatservice_concept/.direnv"/*.rc

View File

@@ -1 +0,0 @@
/nix/store/gks035qaj52pl3ygwlicprsbqxw0wvja-source

View File

@@ -1 +0,0 @@
/nix/store/salp9r9j3pj9cwqf06wchs16hy8g882k-source

View File

@@ -1 +0,0 @@
/nix/store/78sjjah7cnj7zyhh9kq3yj1440rx0h56-nix-shell-env

9
.gitignore vendored
View File

@@ -7,6 +7,7 @@
*.dll
*.so
*.dylib
__debug_bin*
# Test binary, built with `go test -c`
*.test
@@ -22,4 +23,10 @@ go.work
go.work.sum
# env file
.env
.env
# Direnv directory
.direnv/
# Vscode directory
.vscode/

15
.vscode/launch.json vendored
View File

@@ -1,15 +0,0 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Launch Package",
"type": "go",
"request": "launch",
"mode": "auto",
"program": "main.go"
}
]
}

7
LICENSE Normal file
View File

@@ -0,0 +1,7 @@
Copyright 2025 William Peebles
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE

View File

@@ -3,23 +3,49 @@ package api
import (
"flag"
"fmt"
"log/slog"
"net/http"
"git.dubyatp.xyz/chat-api-server/db"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/go-chi/cors"
"github.com/go-chi/docgen"
"github.com/go-chi/render"
)
var routes = flag.Bool("routes", false, "Generate API route documentation")
func RequestLog(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
slog.Debug("api: request received",
"request_uri", r.RequestURI,
"source_ip", r.RemoteAddr,
"user_agent", r.UserAgent())
next.ServeHTTP(w, r)
})
}
func Start() {
db.InitScyllaDB()
defer db.CloseScyllaDB()
flag.Parse()
r := chi.NewRouter()
r.Use(cors.Handler(cors.Options{
AllowedOrigins: []string{"http://localhost:5000"},
AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"},
ExposedHeaders: []string{"Link"},
AllowCredentials: true,
MaxAge: 300, // Maximum value for preflight request cache
}))
r.Use(middleware.RequestID)
r.Use(middleware.Logger)
r.Use(RequestLog)
r.Use(middleware.Recoverer)
r.Use(middleware.URLFormat)
r.Use(render.SetContentType(render.ContentTypeJSON))
@@ -36,13 +62,50 @@ func Start() {
panic("oh no")
})
r.Route("/whoami", func(r chi.Router) {
r.Use(SessionAuthMiddleware)
r.Use(LoginCtx)
r.Get("/", Whoami)
})
r.Route("/messages", func(r chi.Router) {
r.Use(SessionAuthMiddleware) // Protect with authentication
r.Get("/", ListMessages)
r.Route("/{messageID}", func(r chi.Router) {
r.Use(MessageCtx) // Load message
r.Get("/", GetMessage)
r.Delete("/", DeleteMessage)
r.Post("/edit", EditMessage)
})
r.Post("/new", NewMessage)
r.Route("/new", func(r chi.Router) {
r.Use(LoginCtx)
r.Post("/", NewMessage)
})
})
r.Route("/users", func(r chi.Router) {
r.Use(SessionAuthMiddleware) // Protect with authentication
r.Get("/", ListUsers)
r.Route("/{userID}", func(r chi.Router) {
r.Use(UserCtx) // Load user
r.Get("/", GetUser)
})
})
r.Route("/login", func(r chi.Router) {
r.Post("/", Login)
})
r.Route("/logout", func(r chi.Router) {
r.Use(SessionAuthMiddleware)
r.Post("/", Logout)
})
r.Route("/register", func(r chi.Router) {
r.Post("/", NewUser)
})
if *routes {

226
api/auth.go Normal file
View File

@@ -0,0 +1,226 @@
package api
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"log/slog"
"net/http"
"os"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"golang.org/x/crypto/bcrypt"
)
var jwtSecret = []byte(os.Getenv("JWT_SECRET"))
func hashToken(token string) string {
hash := sha256.Sum256([]byte(token))
return hex.EncodeToString(hash[:])
}
func Login(w http.ResponseWriter, r *http.Request) {
err := r.ParseMultipartForm(64 << 10)
if err != nil {
http.Error(w, "Unable to parse form", http.StatusBadRequest)
return
}
username := r.FormValue("username")
password := r.FormValue("password")
if username == "" || password == "" {
http.Error(w, "Username and password cannot be empty", http.StatusBadRequest)
return
}
user, err := dbGetUserByName(username)
if err != nil {
http.Error(w, "Invalid username or password", http.StatusUnauthorized)
return
}
err = validatePassword(user.Password, password)
if err != nil {
http.Error(w, "Invalid username or password", http.StatusUnauthorized)
return
}
sessionToken := CreateSession(user.ID)
http.SetCookie(w, &http.Cookie{
Name: "session_token",
Value: sessionToken,
Path: "/",
HttpOnly: true,
Secure: false,
})
slog.Info("auth: login successful", "userID", user.ID, "userName", user.Name)
w.Write([]byte("Login successful"))
}
func Logout(w http.ResponseWriter, r *http.Request) {
cookie, err := r.Cookie("session_token")
if err != nil {
http.Error(w, "No session cookie found. You are already logged out", http.StatusBadRequest)
return
}
sessionToken := cookie.Value
userID, valid := ValidateSession(sessionToken)
if !valid {
http.Error(w, "Session cookie could not be validated. You are already logged out", http.StatusBadRequest)
return
}
user, err := dbGetUser(userID.String())
if err != nil {
http.Error(w, "Session cookie validated but user could not be found", http.StatusInternalServerError)
return
}
DeleteSession(sessionToken)
cookie.Expires = time.Now()
http.SetCookie(w, cookie)
slog.Debug("auth: logout successful", "userID", user.ID, "userName", user.Name)
w.Write([]byte(fmt.Sprintf("%v has been logged out", user.Name)))
}
type Session struct {
Token string
UserID uuid.UUID
Expiry time.Time
}
func CreateSession(userID uuid.UUID) string {
expiry := time.Now().Add(7 * 24 * time.Hour)
claims := jwt.MapClaims{
"userID": userID.String(),
"exp": expiry.Unix(), // 7 day token expiry
"iat": time.Now().Unix(),
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenString, err := token.SignedString(jwtSecret)
if err != nil {
slog.Error("auth: failed to create JWT", "error", err)
return ""
}
hashedToken := hashToken(tokenString)
session := Session{
Token: hashedToken,
UserID: userID,
Expiry: expiry,
}
dbAddSession(&session)
slog.Debug("auth: new session created", "userID", session.UserID)
return tokenString
}
func ValidateSession(sessionToken string) (uuid.UUID, bool) {
token, err := jwt.Parse(sessionToken, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return jwtSecret, nil
})
if err != nil || !token.Valid {
slog.Debug("auth: session token invalid, rejecting")
return uuid.Nil, false
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
slog.Debug("auth: could not map claims from JWT")
return uuid.Nil, false
}
userIDStr, ok := claims["userID"].(string)
if !ok {
slog.Debug("auth: userID claim is not a string")
return uuid.Nil, false
}
userID, err := uuid.Parse(userIDStr)
if err != nil {
slog.Debug("auth: failed to parse userID as uuid", "error", err)
return uuid.Nil, false
}
hashedToken := hashToken(sessionToken)
session, err := dbGetSession(hashedToken)
if err != nil {
slog.Debug("auth: failed to retrieve session from db", "error", err)
return uuid.Nil, false
}
if time.Now().After(session.Expiry) {
slog.Debug("auth: session is expired (or otherwise invalid) in db")
return uuid.Nil, false
}
slog.Debug("auth: session validated", "userID", session.UserID)
return userID, true
}
func DeleteSession(sessionToken string) bool {
hashedToken := hashToken(sessionToken)
err := dbDeleteSession(hashedToken)
if err != nil {
slog.Error("auth: failed to delete session", "error", err)
return false
}
slog.Debug("auth: session deleted", "token", hashedToken)
return true
}
type contextKey string
const userIDKey contextKey = "userID"
func SessionAuthMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
cookie, err := r.Cookie("session_token")
if err != nil {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
sessionToken := cookie.Value
userID, valid := ValidateSession(sessionToken)
if !valid {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
// Add username to request context
ctx := context.WithValue(r.Context(), userIDKey, userID)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
func (u *UserPayload) Render(w http.ResponseWriter, r *http.Request) error {
return nil
}
func hashPassword(password string) (string, error) {
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
return string(hashedPassword), err
}
func validatePassword(hashedPassword, password string) error {
return bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password))
}

274
api/db.go
View File

@@ -3,97 +3,239 @@ package api
import (
"errors"
"fmt"
"time"
"log/slog"
"git.dubyatp.xyz/chat-api-server/db"
"github.com/gocql/gocql"
)
func dbGetUser(id int64) (*User, error) {
data := db.ExecDB("users")
if data == nil {
return nil, errors.New("failed to load users database")
func dbGetUser(id string) (*User, error) {
query := `SELECT id, name, password FROM users WHERE id = ?`
var user User
err := db.Session.Query(query, id).Scan(&user.ID, &user.Name, &user.Password)
if err == gocql.ErrNotFound {
slog.Debug("db: user not found", "userid", id)
return nil, errors.New("User not found")
} else if err != nil {
slog.Error("db: failed to query user", "error", err)
return nil, fmt.Errorf("failed to query user")
}
users := data["users"].([]interface{})
for _, u := range users {
user := u.(map[string]interface{})
if int64(user["ID"].(float64)) == id {
return &User{
ID: int64(user["ID"].(float64)),
Name: user["Name"].(string),
}, nil
}
slog.Debug("db: user found", "userid", user.ID, "username", user.Name)
return &user, nil
}
func dbGetUserByName(username string) (*User, error) {
query := `SELECT id, name, password FROM users WHERE name = ?`
var user User
err := db.Session.Query(query, username).Scan(&user.ID, &user.Name, &user.Password)
if err == gocql.ErrNotFound {
slog.Debug("db: user not found", "username", username)
return nil, errors.New("User not found")
} else if err != nil {
slog.Error("db: failed to query user", "error", err)
return nil, fmt.Errorf("failed to query user")
}
return nil, errors.New("User not found")
slog.Debug("db: user found", "userid", user.ID, "username", user.Name)
return &user, nil
}
func dbGetAllUsers() ([]*User, error) {
query := `SELECT id, name, password FROM users`
iter := db.Session.Query(query).Iter()
defer iter.Close()
var users []*User
for {
user := &User{}
if !iter.Scan(&user.ID, &user.Name, &user.Password) {
break
}
users = append(users, user)
}
if err := iter.Close(); err != nil {
slog.Error("db: failed to iterate users", "error", err)
return nil, fmt.Errorf("failed to iterate users")
}
if len(users) == 0 {
slog.Debug("db: no users found")
return nil, errors.New("no users found")
}
slog.Debug("db: user list returned")
return users, nil
}
func dbGetMessage(id string) (*Message, error) {
data := db.ExecDB("messages")
if data == nil {
return nil, errors.New("failed to load messages database")
query := `SELECT id, body, edited, timestamp, userid FROM messages WHERE id = ?`
var message Message
err := db.Session.Query(query, id).Scan(
&message.ID,
&message.Body,
&message.Edited,
&message.Timestamp,
&message.UserID)
if err == gocql.ErrNotFound {
slog.Debug("db: message not found", "messageid", id)
return nil, errors.New("Message not found")
} else if err != nil {
slog.Error("db: failed to query message", "error", err)
return nil, fmt.Errorf("failed to query message")
}
messages := data["messages"].([]interface{})
for _, m := range messages {
message := m.(map[string]interface{})
if message["ID"].(string) == id {
timestamp, err := time.Parse(time.RFC3339, message["Timestamp"].(string))
if err != nil {
return nil, fmt.Errorf("failed to parse timestamp: %v", err)
}
return &Message{
ID: message["ID"].(string),
UserID: int64(message["UserID"].(float64)),
Body: message["Body"].(string),
Timestamp: timestamp,
}, nil
}
}
return nil, errors.New("Message not found")
slog.Debug("db: message found", "messageid", message.ID)
return &message, nil
}
func dbGetAllMessages() ([]*Message, error) {
data := db.ExecDB("messages")
//println(data)
if data == nil {
return nil, errors.New("failed to load messages database")
query := `SELECT id, body, edited, timestamp, userid FROM messages`
iter := db.Session.Query(query).Iter()
defer iter.Close()
var messages []*Message
for {
message := &Message{}
if !iter.Scan(
&message.ID,
&message.Body,
&message.Edited,
&message.Timestamp,
&message.UserID) {
break
}
messages = append(messages, message)
}
messages := data["messages"].([]interface{})
var result []*Message
for _, m := range messages {
message := m.(map[string]interface{})
timestamp, err := time.Parse(time.RFC3339, message["Timestamp"].(string))
if err != nil {
return nil, fmt.Errorf("failed to parse timestamp: %v", err)
}
result = append(result, &Message{
ID: message["ID"].(string),
UserID: int64(message["UserID"].(float64)),
Body: message["Body"].(string),
Timestamp: timestamp,
})
if err := iter.Close(); err != nil {
slog.Error("db: failed to iterate messages", "error", err)
return nil, fmt.Errorf("failed to iterate messages")
}
if len(result) == 0 {
if len(messages) == 0 {
slog.Debug("db: no messages found")
return nil, errors.New("no messages found")
}
return result, nil
slog.Debug("db: message list returned")
return messages, nil
}
func dbAddUser(id int64, name string) error {
user := map[string]interface{}{
"ID": float64(id), // JSON numbers are float64 by default
"Name": name,
func dbAddSession(session *Session) error {
query := `INSERT INTO sessions (jwttoken, userid, expiry) VALUES (?, ?, ?)`
err := db.Session.Query(query, session.Token, session.UserID, session.Expiry).Exec()
if err != nil {
slog.Error("db: failed to add session", "error", err)
return fmt.Errorf("failed to add session")
}
return db.AddUser(user)
slog.Debug("db: session added", "userID", session.UserID)
return nil
}
func dbGetSession(jwtToken string) (*Session, error) {
query := `SELECT jwttoken, userid, expiry FROM sessions WHERE jwttoken = ?`
var session Session
err := db.Session.Query(query, jwtToken).Scan(
&session.Token,
&session.UserID,
&session.Expiry)
if err == gocql.ErrNotFound {
slog.Debug("db: session not found")
return nil, errors.New("Session not found")
} else if err != nil {
slog.Error("db: failed to query session", "error", err)
return nil, fmt.Errorf("failed to query session")
}
return &session, nil
}
func dbDeleteSession(jwtToken string) error {
query := `DELETE FROM sessions WHERE jwttoken = ?`
err := db.Session.Query(query, jwtToken).Exec()
if err != nil {
slog.Error("db: failed to delete session")
return fmt.Errorf("failed to delete session")
}
slog.Debug("db: session deleted")
return nil
}
func dbAddUser(user *User) error {
query := `INSERT INTO users (id, name, password) VALUES (?, ?, ?)`
err := db.Session.Query(query, user.ID, user.Name, user.Password).Exec()
if err != nil {
slog.Error("db: failed to add user", "error", err, "userid", user.ID, "username", user.Name)
return fmt.Errorf("failed to add user")
}
slog.Debug("db: user added", "userid", user.ID, "username", user.Name)
return nil
}
func dbAddMessage(message *Message) error {
dbMessage := map[string]interface{}{
"ID": message.ID,
"UserID": message.UserID, // JSON numbers are float64
"Body": message.Body,
"Timestamp": message.Timestamp,
query := `INSERT INTO messages (id, body, edited, timestamp, userid)
VALUES (?, ?, ?, ?, ?)`
err := db.Session.Query(query,
message.ID,
message.Body,
nil,
message.Timestamp,
message.UserID).Exec()
if err != nil {
slog.Error("db: failed to add message", "error", err, "messageid", message.ID)
return fmt.Errorf("failed to add message")
}
return db.AddMessage(dbMessage)
slog.Debug("db: message added", "messageid", message.ID)
return nil
}
func dbUpdateMessage(updatedMessage *Message) error {
var edited interface{}
if updatedMessage.Edited.IsZero() {
edited = nil
} else {
edited = updatedMessage.Edited
}
query := `UPDATE messages
SET body = ?, edited = ?, timestamp = ?
WHERE ID = ?`
err := db.Session.Query(query,
updatedMessage.Body,
edited,
updatedMessage.Timestamp,
updatedMessage.ID).Exec()
if err != nil {
slog.Error("db: failed to update message", "error", err, "messageid", updatedMessage.ID)
return fmt.Errorf("failed to update message")
}
slog.Debug("db: message updated", "messageid", updatedMessage.ID)
return nil
}
func dbDeleteMessage(id string) error {
query := `DELETE FROM messages WHERE ID = ?`
err := db.Session.Query(query, id).Exec()
if err != nil {
slog.Error("db: failed to delete message", "error", err, "messageid", id)
return fmt.Errorf("failed to delete message")
}
slog.Debug("db: message deleted", "messageid", id)
return nil
}

View File

@@ -3,6 +3,7 @@ package api
import (
"context"
"encoding/json"
"log/slog"
"net/http"
"github.com/go-chi/chi/v5"
@@ -15,86 +16,179 @@ import (
func MessageCtx(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
slog.Debug("message: entering MessageCtx middleware")
var message *Message
var err error
if messageID := chi.URLParam(r, "messageID"); messageID != "" {
slog.Debug("message: fetching message", "messageID", messageID)
message, err = dbGetMessage(messageID)
} else {
slog.Error("message: messageID not found in URL parameters")
render.Render(w, r, ErrNotFound)
return
}
if err != nil {
slog.Error("message: failed to fetch message", "messageID", chi.URLParam(r, "messageID"), "error", err)
render.Render(w, r, ErrNotFound)
return
}
slog.Debug("message: successfully fetched message", "messageID", message.ID)
ctx := context.WithValue(r.Context(), messageKey{}, message)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
func GetMessage(w http.ResponseWriter, r *http.Request) {
slog.Debug("message: entering GetMessage handler")
message, ok := r.Context().Value(messageKey{}).(*Message)
if !ok || message == nil {
slog.Error("message: message not found in context")
render.Render(w, r, ErrNotFound)
return
}
slog.Debug("message: rendering message", "messageID", message.ID)
if err := render.Render(w, r, NewMessageResponse(message)); err != nil {
slog.Error("message: failed to render message response", "messageID", message.ID, "error", err)
render.Render(w, r, ErrRender(err))
return
}
}
func EditMessage(w http.ResponseWriter, r *http.Request) {
slog.Debug("message: entering EditMessage handler")
message, ok := r.Context().Value(messageKey{}).(*Message)
if !ok || message == nil {
slog.Error("message: message not found in context")
render.Render(w, r, ErrNotFound)
return
}
err := r.ParseMultipartForm(64 << 10)
if err != nil {
slog.Error("message: failed to parse multipart form", "error", err)
http.Error(w, "Unable to parse form", http.StatusBadRequest)
return
}
body := r.FormValue("body")
if body == "" {
slog.Error("message: message body is empty")
http.Error(w, "Message body cannot be empty", http.StatusBadRequest)
return
}
slog.Debug("message: updating message", "messageID", message.ID)
message.Body = body
editedTime := time.Now()
message.Edited = &editedTime
err = dbUpdateMessage(message)
if err != nil {
slog.Error("message: failed to update message", "messageID", message.ID, "error", err)
render.Render(w, r, ErrRender(err))
return
}
slog.Debug("message: successfully updated message", "messageID", message.ID)
if err := render.Render(w, r, NewMessageResponse(message)); err != nil {
slog.Error("message: failed to render updated message response", "messageID", message.ID, "error", err)
render.Render(w, r, ErrRender(err))
return
}
}
func DeleteMessage(w http.ResponseWriter, r *http.Request) {
slog.Debug("message: entering DeleteMessage handler")
message, ok := r.Context().Value(messageKey{}).(*Message)
if !ok || message == nil {
slog.Error("message: message not found in context")
render.Render(w, r, ErrNotFound)
return
}
slog.Debug("message: deleting message", "messageID", message.ID)
err := dbDeleteMessage(message.ID.String())
if err != nil {
slog.Error("message: failed to delete message", "messageID", message.ID, "error", err)
render.Render(w, r, ErrRender(err))
return
}
slog.Debug("message: successfully deleted message", "messageID", message.ID)
if err := render.Render(w, r, NewMessageResponse(message)); err != nil {
slog.Error("message: failed to render deleted message response", "messageID", message.ID, "error", err)
render.Render(w, r, ErrRender(err))
return
}
}
func ListMessages(w http.ResponseWriter, r *http.Request) {
slog.Debug("message: entering ListMessages handler")
dbMessages, err := dbGetAllMessages()
if err != nil {
slog.Error("message: failed to fetch messages", "error", err)
render.Render(w, r, ErrRender(err))
return
}
slog.Debug("message: successfully fetched messages", "count", len(dbMessages))
if err := render.RenderList(w, r, NewMessageListResponse(dbMessages)); err != nil {
slog.Error("message: failed to render message list response", "error", err)
render.Render(w, r, ErrRender(err))
return
}
}
func newMessageID() string {
return "msg_" + uuid.New().String()
func newMessageID() uuid.UUID {
return uuid.New()
}
func NewMessage(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Invalid request method", http.StatusMethodNotAllowed)
return
}
var msg Message
err := json.NewDecoder(r.Body).Decode(&msg)
slog.Debug("message: entering NewMessage handler")
err := r.ParseMultipartForm(64 << 10)
if err != nil {
http.Error(w, "Invalid JSON", http.StatusBadRequest)
slog.Error("message: failed to parse multipart form", "error", err)
http.Error(w, "Unable to parse form", http.StatusBadRequest)
return
}
msg.ID = newMessageID()
var user = r.Context().Value(userKey{}).(*User)
body := r.FormValue("body")
msg.Timestamp = time.Now()
if body == "" {
slog.Error("message: message body is empty")
http.Error(w, "Invalid body", http.StatusBadRequest)
return
}
msg := Message{
ID: newMessageID(),
UserID: user.ID,
Body: body,
Timestamp: time.Now(),
}
slog.Debug("message: creating new message", "messageID", msg.ID)
err = dbAddMessage(&msg)
if err != nil {
slog.Error("message: failed to add new message", "messageID", msg.ID, "error", err)
render.Render(w, r, ErrRender(err))
return
}
slog.Debug("message: successfully created new message", "messageID", msg.ID)
render.Render(w, r, NewMessageResponse(&msg))
}
type messageKey struct{}
type Message struct {
ID string `json:"id"`
UserID int64 `json:"user_id"`
Body string `json:"body"`
Timestamp time.Time `json:"timestamp"`
ID uuid.UUID `json:"id"`
UserID uuid.UUID `json:"user_id"`
Body string `json:"body"`
Timestamp time.Time `json:"timestamp"`
Edited *time.Time `json:"edited"`
}
type MessageRequest struct {
@@ -115,19 +209,27 @@ type MessageResponse struct {
func (m MessageResponse) MarshalJSON() ([]byte, error) {
type OrderedMessageResponse struct {
ID string `json:"id"`
UserID int64 `json:"user_id"`
ID uuid.UUID `json:"id"`
UserID uuid.UUID `json:"user_id"`
Body string `json:"body"`
Timestamp string `json:"timestamp"`
Edited *string `json:"edited,omitempty"` // Use a pointer to allow null values
User *UserPayload `json:"user,omitempty"`
Elapsed int64 `json:"elapsed"`
}
var edited *string
if m.Message.Edited != nil { // Check if Edited is not the zero value
editedStr := m.Message.Edited.Format(time.RFC3339)
edited = &editedStr
}
ordered := OrderedMessageResponse{
ID: m.Message.ID,
UserID: m.Message.UserID,
Body: m.Message.Body,
Timestamp: m.Message.Timestamp.Format(time.RFC3339),
Edited: edited, // Null if Edited is zero
User: m.User,
Elapsed: m.Elapsed,
}

View File

@@ -10,7 +10,7 @@ func NewMessageResponse(message *Message) *MessageResponse {
resp := &MessageResponse{Message: message}
if resp.User == nil {
if user, _ := dbGetUser(resp.UserID); user != nil {
if user, _ := dbGetUser(resp.UserID.String()); user != nil {
resp.User = NewUserPayloadResponse(user)
}
}
@@ -30,6 +30,14 @@ func NewMessageListResponse(messages []*Message) []render.Renderer {
return list
}
func NewUserListResponse(users []*User) []render.Renderer {
list := []render.Renderer{}
for _, user := range users {
list = append(list, NewUserPayloadResponse(user))
}
return list
}
func NewUserPayloadResponse(user *User) *UserPayload {
return &UserPayload{User: user}
}

View File

@@ -1,8 +1,165 @@
package api
import (
"context"
"log/slog"
"net/http"
"github.com/go-chi/chi/v5"
"github.com/go-chi/render"
"github.com/google/uuid"
)
func UserCtx(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
slog.Debug("user: entering UserCtx middleware")
var user *User
var err error
if userID := chi.URLParam(r, "userID"); userID != "" {
slog.Debug("user: fetching user", "userID", userID)
user, err = dbGetUser(userID)
} else {
slog.Error("user: userID not found in URL parameters")
render.Render(w, r, ErrNotFound)
return
}
if err != nil {
slog.Error("user: failed to fetch user", "userID", user.ID, "error", err)
render.Render(w, r, ErrNotFound)
return
}
slog.Debug("user: successfully fetched user", "userID", user.ID)
ctx := context.WithValue(r.Context(), userKey{}, user)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
func Whoami(w http.ResponseWriter, r *http.Request) {
slog.Debug("user: entering Whoami handler")
user, ok := r.Context().Value(userKey{}).(*User)
if !ok || user == nil {
slog.Debug("user: anonymous user")
w.Write([]byte("anonymous"))
return
}
slog.Debug("user: returning user name", "userID", user.ID, "userName", user.Name)
w.Write([]byte(user.Name))
}
func LoginCtx(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
slog.Debug("user: entering LoginCtx middleware")
userID, ok := r.Context().Value(userIDKey).(uuid.UUID)
if !ok || userID == uuid.Nil {
slog.Debug("user: no user ID provided, assuming anonymous user")
next.ServeHTTP(w, r)
return
}
slog.Debug("user: fetching user by user ID", "userID", userID)
user, err := dbGetUser(userID.String())
if err != nil {
slog.Error("user: failed to fetch user by user ID", "userID", userID, "error", err)
render.Render(w, r, ErrNotFound)
return
}
slog.Debug("user: successfully fetched user", "userID", user.ID, "username", user.Name)
ctx := context.WithValue(r.Context(), userKey{}, user)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
func GetUser(w http.ResponseWriter, r *http.Request) {
slog.Debug("user: entering GetUser handler")
user, ok := r.Context().Value(userKey{}).(*User)
if !ok || user == nil {
slog.Error("user: user not found in context")
render.Render(w, r, ErrNotFound)
return
}
slog.Debug("user: rendering user", "userID", user.ID, "userName", user.Name)
if err := render.Render(w, r, NewUserPayloadResponse(user)); err != nil {
slog.Error("user: failed to render user response", "userID", user.ID, "error", err)
render.Render(w, r, ErrRender(err))
return
}
}
func ListUsers(w http.ResponseWriter, r *http.Request) {
slog.Debug("user: entering ListUsers handler")
dbUsers, err := dbGetAllUsers()
if err != nil {
slog.Error("user: failed to fetch users", "error", err)
render.Render(w, r, ErrRender(err))
return
}
slog.Debug("user: successfully fetched users", "count", len(dbUsers))
if err := render.RenderList(w, r, NewUserListResponse(dbUsers)); err != nil {
slog.Error("user: failed to render user list response", "error", err)
render.Render(w, r, ErrRender(err))
return
}
}
func newUserID() uuid.UUID {
return uuid.New()
}
func NewUser(w http.ResponseWriter, r *http.Request) {
slog.Debug("user: entering NewUser handler")
err := r.ParseMultipartForm(64 << 10)
if err != nil {
slog.Error("user: failed to parse multipart form", "error", err)
http.Error(w, "Unable to parse form", http.StatusBadRequest)
return
}
newUserName := r.FormValue("name")
password := r.FormValue("password")
if newUserName == "" || password == "" {
slog.Error("user: username or password is empty")
http.Error(w, "Username and password cannot be empty", http.StatusBadRequest)
return
}
slog.Debug("user: hashing password for new user", "userName", newUserName)
hashedPassword, err := hashPassword(password)
if err != nil {
slog.Error("user: failed to hash password", "error", err)
http.Error(w, "Unable to hash password", http.StatusInternalServerError)
return
}
newUser := User{
ID: newUserID(),
Name: newUserName,
Password: hashedPassword,
}
slog.Debug("user: adding new user to database", "userID", newUser.ID, "userName", newUser.Name)
err = dbAddUser(&newUser)
if err != nil {
slog.Error("user: failed to add new user", "userID", newUser.ID, "userName", newUser.Name, "error", err)
render.Render(w, r, ErrRender(err))
return
}
slog.Debug("user: successfully added new user", "userID", newUser.ID, "userName", newUser.Name)
render.Render(w, r, NewUserPayloadResponse(&newUser))
}
type userKey struct{}
type User struct {
ID int64 `json:"id"`
Name string `json:"name"`
ID uuid.UUID `json:"id"`
Name string `json:"name"`
Password string `json:"-"`
}
type UserPayload struct {

View File

@@ -1,95 +0,0 @@
package db
import (
"encoding/json"
"fmt"
"io"
"os"
)
func ExecDB(db_name string) map[string]interface{} {
var result map[string]interface{}
if db_name == "users" {
users_db, err := os.Open("./test_data/users.json")
if err != nil {
fmt.Println(err)
return nil
}
fmt.Println("Successfully opened Users DB")
defer users_db.Close()
byteValue, _ := io.ReadAll(users_db)
var users []interface{}
json.Unmarshal(byteValue, &users)
result = map[string]interface{}{"users": users}
} else if db_name == "messages" {
messages_db, err := os.Open("./test_data/messages.json")
if err != nil {
fmt.Println(err)
return nil
}
fmt.Println("Successfully opened Messages DB")
defer messages_db.Close()
byteValue, _ := io.ReadAll(messages_db)
var messages []interface{}
json.Unmarshal(byteValue, &messages)
result = map[string]interface{}{"messages": messages}
} else {
fmt.Println("Invalid DB name")
return nil
}
return result
}
func WriteDB(db_name string, data interface{}) error {
var filePath string
switch db_name {
case "users":
filePath = "./test_data/users.json"
case "messages":
filePath = "./test_data/messages.json"
default:
return fmt.Errorf("invalid database name: %s", db_name)
}
jsonData, err := json.MarshalIndent(data, "", " ")
if err != nil {
return fmt.Errorf("error marshaling data to JSON: %v", err)
}
err = os.WriteFile(filePath, jsonData, 0644)
if err != nil {
return fmt.Errorf("error writing to file: %v", err)
}
fmt.Printf("Successfully wrote to %s DB\n", db_name)
return nil
}
func AddUser(user map[string]interface{}) error {
currentData := ExecDB("users")
if currentData == nil {
return fmt.Errorf("error reading users database")
}
users := currentData["users"].([]interface{})
users = append(users, user)
return WriteDB("users", users)
}
func AddMessage(message map[string]interface{}) error {
currentData := ExecDB("messages")
if currentData == nil {
return fmt.Errorf("error reading messages database")
}
messages := currentData["messages"].([]interface{})
messages = append(messages, message)
return WriteDB("messages", messages)
}

30
db/scylla.go Normal file
View File

@@ -0,0 +1,30 @@
package db
import (
"log/slog"
"os"
"github.com/gocql/gocql"
)
var Session *gocql.Session
func InitScyllaDB() {
cluster := gocql.NewCluster(os.Getenv("SCYLLA_CLUSTER")) // Replace with your ScyllaDB cluster IPs
cluster.Keyspace = os.Getenv("SCYLLA_KEYSPACE") // Replace with your keyspace
cluster.Consistency = gocql.Quorum
session, err := cluster.CreateSession()
if err != nil {
slog.Error("Failed to connect to ScyllaDB", "error", err)
os.Exit(1)
}
Session = session
slog.Info("Connected to ScyllaDB")
}
func CloseScyllaDB() {
if Session != nil {
Session.Close()
}
}

View File

@@ -7,10 +7,24 @@
- [Logger]()
- [Recoverer]()
- [URLFormat]()
- [git.dubyatp.xyz/chat-api-server/api.Start.SetContentType.func5]()
- [git.dubyatp.xyz/chat-api-server/api.Start.SetContentType.func8]()
- **/**
- _GET_
- [Start.func1]()
- _GET_
- [Start.func1]()
</details>
<details>
<summary>`/login`</summary>
- [RequestID]()
- [Logger]()
- [Recoverer]()
- [URLFormat]()
- [git.dubyatp.xyz/chat-api-server/api.Start.SetContentType.func8]()
- **/login**
- **/**
- _POST_
- [Login]()
</details>
<details>
@@ -20,11 +34,27 @@
- [Logger]()
- [Recoverer]()
- [URLFormat]()
- [git.dubyatp.xyz/chat-api-server/api.Start.SetContentType.func5]()
- [git.dubyatp.xyz/chat-api-server/api.Start.SetContentType.func8]()
- **/messages**
- **/**
- _GET_
- [ListMessages]()
- [SessionAuthMiddleware]()
- **/**
- _GET_
- [ListMessages]()
</details>
<details>
<summary>`/messages/new`</summary>
- [RequestID]()
- [Logger]()
- [Recoverer]()
- [URLFormat]()
- [git.dubyatp.xyz/chat-api-server/api.Start.SetContentType.func8]()
- **/messages**
- [SessionAuthMiddleware]()
- **/new**
- _POST_
- [NewMessage]()
</details>
<details>
@@ -34,13 +64,33 @@
- [Logger]()
- [Recoverer]()
- [URLFormat]()
- [git.dubyatp.xyz/chat-api-server/api.Start.SetContentType.func5]()
- [git.dubyatp.xyz/chat-api-server/api.Start.SetContentType.func8]()
- **/messages**
- **/{messageID}**
- [MessageCtx]()
- **/**
- _GET_
- [GetMessage]()
- [SessionAuthMiddleware]()
- **/{messageID}**
- [MessageCtx]()
- **/**
- _GET_
- [GetMessage]()
- _DELETE_
- [DeleteMessage]()
</details>
<details>
<summary>`/messages/{messageID}/edit`</summary>
- [RequestID]()
- [Logger]()
- [Recoverer]()
- [URLFormat]()
- [git.dubyatp.xyz/chat-api-server/api.Start.SetContentType.func8]()
- **/messages**
- [SessionAuthMiddleware]()
- **/{messageID}**
- [MessageCtx]()
- **/edit**
- _POST_
- [EditMessage]()
</details>
<details>
@@ -50,10 +100,10 @@
- [Logger]()
- [Recoverer]()
- [URLFormat]()
- [git.dubyatp.xyz/chat-api-server/api.Start.SetContentType.func5]()
- [git.dubyatp.xyz/chat-api-server/api.Start.SetContentType.func8]()
- **/panic**
- _GET_
- [Start.func3]()
- _GET_
- [Start.func3]()
</details>
<details>
@@ -63,11 +113,58 @@
- [Logger]()
- [Recoverer]()
- [URLFormat]()
- [git.dubyatp.xyz/chat-api-server/api.Start.SetContentType.func5]()
- [git.dubyatp.xyz/chat-api-server/api.Start.SetContentType.func8]()
- **/ping**
- _GET_
- [Start.func2]()
- _GET_
- [Start.func2]()
</details>
<details>
<summary>`/register`</summary>
- [RequestID]()
- [Logger]()
- [Recoverer]()
- [URLFormat]()
- [git.dubyatp.xyz/chat-api-server/api.Start.SetContentType.func8]()
- **/register**
- **/**
- _POST_
- [NewUser]()
</details>
<details>
<summary>`/users`</summary>
- [RequestID]()
- [Logger]()
- [Recoverer]()
- [URLFormat]()
- [git.dubyatp.xyz/chat-api-server/api.Start.SetContentType.func8]()
- **/users**
- [SessionAuthMiddleware]()
- **/**
- _GET_
- [ListUsers]()
</details>
<details>
<summary>`/users/{userID}`</summary>
- [RequestID]()
- [Logger]()
- [Recoverer]()
- [URLFormat]()
- [git.dubyatp.xyz/chat-api-server/api.Start.SetContentType.func8]()
- **/users**
- [SessionAuthMiddleware]()
- **/{userID}**
- [UserCtx]()
- **/**
- _GET_
- [GetUser]()
</details>
Total # of routes: 5
Total # of routes: 11

View File

@@ -0,0 +1,4 @@
/node_modules/
/public/build/
.DS_Store

View File

@@ -0,0 +1,107 @@
# This repo is no longer maintained. Consider using `npm init vite` and selecting the `svelte` option or — if you want a full-fledged app framework — use [SvelteKit](https://kit.svelte.dev), the official application framework for Svelte.
---
# svelte app
This is a project template for [Svelte](https://svelte.dev) apps. It lives at https://github.com/sveltejs/template.
To create a new project based on this template using [degit](https://github.com/Rich-Harris/degit):
```bash
npx degit sveltejs/template svelte-app
cd svelte-app
```
*Note that you will need to have [Node.js](https://nodejs.org) installed.*
## Get started
Install the dependencies...
```bash
cd svelte-app
npm install
```
...then start [Rollup](https://rollupjs.org):
```bash
npm run dev
```
Navigate to [localhost:8080](http://localhost:8080). You should see your app running. Edit a component file in `src`, save it, and reload the page to see your changes.
By default, the server will only respond to requests from localhost. To allow connections from other computers, edit the `sirv` commands in package.json to include the option `--host 0.0.0.0`.
If you're using [Visual Studio Code](https://code.visualstudio.com/) we recommend installing the official extension [Svelte for VS Code](https://marketplace.visualstudio.com/items?itemName=svelte.svelte-vscode). If you are using other editors you may need to install a plugin in order to get syntax highlighting and intellisense.
## Building and running in production mode
To create an optimised version of the app:
```bash
npm run build
```
You can run the newly built app with `npm run start`. This uses [sirv](https://github.com/lukeed/sirv), which is included in your package.json's `dependencies` so that the app will work when you deploy to platforms like [Heroku](https://heroku.com).
## Single-page app mode
By default, sirv will only respond to requests that match files in `public`. This is to maximise compatibility with static fileservers, allowing you to deploy your app anywhere.
If you're building a single-page app (SPA) with multiple routes, sirv needs to be able to respond to requests for *any* path. You can make it so by editing the `"start"` command in package.json:
```js
"start": "sirv public --single"
```
## Using TypeScript
This template comes with a script to set up a TypeScript development environment, you can run it immediately after cloning the template with:
```bash
node scripts/setupTypeScript.js
```
Or remove the script via:
```bash
rm scripts/setupTypeScript.js
```
If you want to use `baseUrl` or `path` aliases within your `tsconfig`, you need to set up `@rollup/plugin-alias` to tell Rollup to resolve the aliases. For more info, see [this StackOverflow question](https://stackoverflow.com/questions/63427935/setup-tsconfig-path-in-svelte).
## Deploying to the web
### With [Vercel](https://vercel.com)
Install `vercel` if you haven't already:
```bash
npm install -g vercel
```
Then, from within your project folder:
```bash
cd public
vercel deploy --name my-project
```
### With [surge](https://surge.sh/)
Install `surge` if you haven't already:
```bash
npm install -g surge
```
Then, from within your project folder:
```bash
npm run build
surge public my-project.surge.sh
```

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,26 @@
{
"name": "svelte-app",
"version": "1.0.0",
"private": true,
"type": "module",
"scripts": {
"build": "rollup -c",
"dev": "rollup -c -w",
"start": "sirv public --single"
},
"devDependencies": {
"@rollup/plugin-commonjs": "^24.0.0",
"@rollup/plugin-node-resolve": "^15.0.0",
"@rollup/plugin-terser": "^0.4.0",
"rollup": "^3.15.0",
"rollup-plugin-css-only": "^4.3.0",
"rollup-plugin-livereload": "^2.0.0",
"rollup-plugin-svelte": "^7.1.2",
"svelte": "^3.55.0"
},
"dependencies": {
"axios": "^1.9.0",
"sirv-cli": "^2.0.0",
"svelte-spa-router": "^4.0.1"
}
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.1 KiB

View File

@@ -0,0 +1,63 @@
html, body {
position: relative;
width: 100%;
height: 100%;
}
body {
color: #333;
margin: 0;
padding: 8px;
box-sizing: border-box;
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Oxygen-Sans, Ubuntu, Cantarell, "Helvetica Neue", sans-serif;
}
a {
color: rgb(0,100,200);
text-decoration: none;
}
a:hover {
text-decoration: underline;
}
a:visited {
color: rgb(0,80,160);
}
label {
display: block;
}
input, button, select, textarea {
font-family: inherit;
font-size: inherit;
-webkit-padding: 0.4em 0;
padding: 0.4em;
margin: 0 0 0.5em 0;
box-sizing: border-box;
border: 1px solid #ccc;
border-radius: 2px;
}
input:disabled {
color: #ccc;
}
button {
color: #333;
background-color: #f4f4f4;
outline: none;
}
button:disabled {
color: #999;
}
button:not(:disabled):active {
background-color: #ddd;
}
button:focus {
border-color: #666;
}

View File

@@ -0,0 +1,18 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset='utf-8'>
<meta name='viewport' content='width=device-width,initial-scale=1'>
<title>Svelte app</title>
<link rel='icon' type='image/png' href='/favicon.png'>
<link rel='stylesheet' href='/global.css'>
<link rel='stylesheet' href='/build/bundle.css'>
<script defer src='/build/bundle.js'></script>
</head>
<body>
</body>
</html>

View File

@@ -0,0 +1,78 @@
import { spawn } from 'child_process';
import svelte from 'rollup-plugin-svelte';
import commonjs from '@rollup/plugin-commonjs';
import terser from '@rollup/plugin-terser';
import resolve from '@rollup/plugin-node-resolve';
import livereload from 'rollup-plugin-livereload';
import css from 'rollup-plugin-css-only';
const production = !process.env.ROLLUP_WATCH;
function serve() {
let server;
function toExit() {
if (server) server.kill(0);
}
return {
writeBundle() {
if (server) return;
server = spawn('npm', ['run', 'start', '--', '--dev'], {
stdio: ['ignore', 'inherit', 'inherit'],
shell: true
});
process.on('SIGTERM', toExit);
process.on('exit', toExit);
}
};
}
export default {
input: 'src/main.js',
output: {
sourcemap: true,
format: 'iife',
name: 'app',
file: 'public/build/bundle.js'
},
plugins: [
svelte({
compilerOptions: {
// enable run-time checks when not in production
dev: !production
}
}),
// we'll extract any component CSS out into
// a separate file - better for performance
css({ output: 'bundle.css' }),
// If you have external dependencies installed from
// npm, you'll most likely need these plugins. In
// some cases you'll need additional configuration -
// consult the documentation for details:
// https://github.com/rollup/plugins/tree/master/packages/commonjs
resolve({
browser: true,
dedupe: ['svelte'],
exportConditions: ['svelte']
}),
commonjs(),
// In dev mode, call `npm run start` once
// the bundle has been generated
!production && serve(),
// Watch the `public` directory and refresh the
// browser on changes when not in production
!production && livereload('public'),
// If we're building for production (npm run build
// instead of npm run dev), minify
production && terser()
],
watch: {
clearScreen: false
}
};

View File

@@ -0,0 +1,134 @@
// @ts-check
/** This script modifies the project to support TS code in .svelte files like:
<script lang="ts">
export let name: string;
</script>
As well as validating the code for CI.
*/
/** To work on this script:
rm -rf test-template template && git clone sveltejs/template test-template && node scripts/setupTypeScript.js test-template
*/
import fs from "fs"
import path from "path"
import { argv } from "process"
import url from 'url';
const __filename = url.fileURLToPath(import.meta.url);
const __dirname = url.fileURLToPath(new URL('.', import.meta.url));
const projectRoot = argv[2] || path.join(__dirname, "..")
// Add deps to pkg.json
const packageJSON = JSON.parse(fs.readFileSync(path.join(projectRoot, "package.json"), "utf8"))
packageJSON.devDependencies = Object.assign(packageJSON.devDependencies, {
"svelte-check": "^3.0.0",
"svelte-preprocess": "^5.0.0",
"@rollup/plugin-typescript": "^11.0.0",
"typescript": "^4.9.0",
"tslib": "^2.5.0",
"@tsconfig/svelte": "^3.0.0"
})
// Add script for checking
packageJSON.scripts = Object.assign(packageJSON.scripts, {
"check": "svelte-check"
})
// Write the package JSON
fs.writeFileSync(path.join(projectRoot, "package.json"), JSON.stringify(packageJSON, null, " "))
// mv src/main.js to main.ts - note, we need to edit rollup.config.js for this too
const beforeMainJSPath = path.join(projectRoot, "src", "main.js")
const afterMainTSPath = path.join(projectRoot, "src", "main.ts")
fs.renameSync(beforeMainJSPath, afterMainTSPath)
// Switch the app.svelte file to use TS
const appSveltePath = path.join(projectRoot, "src", "App.svelte")
let appFile = fs.readFileSync(appSveltePath, "utf8")
appFile = appFile.replace("<script>", '<script lang="ts">')
appFile = appFile.replace("export let name;", 'export let name: string;')
fs.writeFileSync(appSveltePath, appFile)
// Edit rollup config
const rollupConfigPath = path.join(projectRoot, "rollup.config.js")
let rollupConfig = fs.readFileSync(rollupConfigPath, "utf8")
// Edit imports
rollupConfig = rollupConfig.replace(`'rollup-plugin-css-only';`, `'rollup-plugin-css-only';
import sveltePreprocess from 'svelte-preprocess';
import typescript from '@rollup/plugin-typescript';`)
// Replace name of entry point
rollupConfig = rollupConfig.replace(`'src/main.js'`, `'src/main.ts'`)
// Add preprocessor
rollupConfig = rollupConfig.replace(
'compilerOptions:',
'preprocess: sveltePreprocess({ sourceMap: !production }),\n\t\t\tcompilerOptions:'
);
// Add TypeScript
rollupConfig = rollupConfig.replace(
'commonjs(),',
'commonjs(),\n\t\ttypescript({\n\t\t\tsourceMap: !production,\n\t\t\tinlineSources: !production\n\t\t}),'
);
fs.writeFileSync(rollupConfigPath, rollupConfig)
// Add svelte.config.js
const tsconfig = `{
"extends": "@tsconfig/svelte/tsconfig.json",
"include": ["src/**/*"],
"exclude": ["node_modules/*", "__sapper__/*", "public/*"]
}`
const tsconfigPath = path.join(projectRoot, "tsconfig.json")
fs.writeFileSync(tsconfigPath, tsconfig)
// Add TSConfig
const svelteConfig = `import sveltePreprocess from 'svelte-preprocess';
export default {
preprocess: sveltePreprocess()
};
`
const svelteConfigPath = path.join(projectRoot, "svelte.config.js")
fs.writeFileSync(svelteConfigPath, svelteConfig)
// Add global.d.ts
const dtsPath = path.join(projectRoot, "src", "global.d.ts")
fs.writeFileSync(dtsPath, `/// <reference types="svelte" />`)
// Delete this script, but not during testing
if (!argv[2]) {
// Remove the script
fs.unlinkSync(path.join(__filename))
// Check for Mac's DS_store file, and if it's the only one left remove it
const remainingFiles = fs.readdirSync(path.join(__dirname))
if (remainingFiles.length === 1 && remainingFiles[0] === '.DS_store') {
fs.unlinkSync(path.join(__dirname, '.DS_store'))
}
// Check if the scripts folder is empty
if (fs.readdirSync(path.join(__dirname)).length === 0) {
// Remove the scripts folder
fs.rmdirSync(path.join(__dirname))
}
}
// Adds the extension recommendation
fs.mkdirSync(path.join(projectRoot, ".vscode"), { recursive: true })
fs.writeFileSync(path.join(projectRoot, ".vscode", "extensions.json"), `{
"recommendations": ["svelte.svelte-vscode"]
}
`)
console.log("Converted to TypeScript.")
if (fs.existsSync(path.join(projectRoot, "node_modules"))) {
console.log("\nYou will need to re-run your dependency manager to get started.")
}

View File

@@ -0,0 +1,6 @@
<script>
import Router from 'svelte-spa-router';
import routes from './routes.js';
</script>
<Router {routes} />

View File

@@ -0,0 +1,36 @@
import axios from 'axios';
const API_BASE_URL = 'http://localhost:3000';
export const login = async (username, password) => {
const formData = new FormData();
formData.append('username', username);
formData.append('password', password);
const response = await axios.post(`${API_BASE_URL}/login`, formData, {
headers: {
'Content-Type': 'multipart/form-data',
},
withCredentials: true,
});
return response.data;
};
export const fetchMessages = async () => {
const response = await axios.get(`${API_BASE_URL}/messages`, { withCredentials: true });
return response.data;
};
export const createMessage = async (body) => {
const formData = new FormData();
formData.append('body', body);
const response = await axios.post(`${API_BASE_URL}/messages/new`, formData, {
headers: {
'Content-Type': 'multipart/form-data',
},
withCredentials: true,
});
return response.data;
};

View File

@@ -0,0 +1,10 @@
import App from './App.svelte';
const app = new App({
target: document.body,
props: {
name: 'world'
}
});
export default app;

View File

@@ -0,0 +1,7 @@
import Login from './routes/Login.svelte';
import Messages from './routes/Messages.svelte';
export default {
'/': Login,
'/messages': Messages,
};

View File

@@ -0,0 +1,22 @@
<script>
import { login } from '../api.js';
import { push } from 'svelte-spa-router';
let username = '';
let password = '';
let error = '';
const handleLogin = async () => {
try {
await login(username, password);
push('/messages');
} catch (err) {
error = err;
}
};
</script>
<h1>Login</h1>
<input type="text" bind:value={username} placeholder="Username" />
<input type="password" bind:value={password} placeholder="Password" />
<button on:click={handleLogin}>Login</button>
<p style="color: red">{error}</p>

View File

@@ -0,0 +1,26 @@
<script>
import { fetchMessages, createMessage } from '../api.js';
let messages = [];
let newMessage = '';
const loadMessages = async () => {
messages = await fetchMessages();
};
const handleCreateMessage = async () => {
await createMessage(newMessage);
newMessage = '';
await loadMessages();
};
loadMessages();
</script>
<h1>Messages</h1>
<ul>
{#each messages as message}
<li>{message.user.name} - {message.body} - {message.timestamp}</li>
{/each}
</ul>
<input type="text" bind:value={newMessage} placeholder="New message" />
<button on:click={handleCreateMessage}>Send</button>

8
flake.lock generated
View File

@@ -2,16 +2,16 @@
"nodes": {
"nixpkgs": {
"locked": {
"lastModified": 1735264675,
"narHash": "sha256-MgdXpeX2GuJbtlBrH9EdsUeWl/yXEubyvxM1G+yO4Ak=",
"lastModified": 1743827369,
"narHash": "sha256-rpqepOZ8Eo1zg+KJeWoq1HAOgoMCDloqv5r2EAa9TSA=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "d49da4c08359e3c39c4e27c74ac7ac9b70085966",
"rev": "42a1c966be226125b48c384171c44c651c236c22",
"type": "github"
},
"original": {
"id": "nixpkgs",
"ref": "nixos-24.11",
"ref": "nixos-unstable",
"type": "indirect"
}
},

View File

@@ -1,7 +1,9 @@
{
description = "Unnamed Chat Server API";
inputs.nixpkgs.url = "nixpkgs/nixos-24.11";
inputs = {
nixpkgs.url = "nixpkgs/nixos-unstable";
};
outputs = { self, nixpkgs }:
let
@@ -33,10 +35,11 @@
default = pkgs.mkShell {
hardeningDisable = [ "fortify" ];
buildInputs = [
pkgs.bashInteractive
pkgs.go
pkgs.delve
];
};
});
};
}
}

22
go.mod
View File

@@ -1,12 +1,30 @@
module git.dubyatp.xyz/chat-api-server
go 1.23
go 1.23.0
toolchain go1.23.3
require (
github.com/go-chi/chi/v5 v5.2.0
github.com/go-chi/docgen v1.3.0
github.com/go-chi/render v1.0.3
github.com/gocql/gocql v1.7.0
github.com/google/uuid v1.6.0
github.com/joho/godotenv v1.5.1
)
require github.com/ajg/form v1.5.1 // indirect
require github.com/go-chi/cors v1.2.1
require (
github.com/golang-jwt/jwt/v5 v5.2.2
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed // indirect
github.com/klauspost/compress v1.17.9 // indirect
gopkg.in/inf.v0 v0.9.1 // indirect
)
require (
github.com/ajg/form v1.5.1 // indirect
golang.org/x/crypto v0.36.0
)
replace github.com/gocql/gocql => github.com/scylladb/gocql v1.14.5

56
go.sum
View File

@@ -1,14 +1,70 @@
github.com/ajg/form v1.5.1 h1:t9c7v8JUKu/XxOGBU0yjNpaMloxGEJhUkqFRq0ibGeU=
github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY=
github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932 h1:mXoPYz/Ul5HYEDvkta6I8/rnYM5gSdSV2tJ6XbZuEtY=
github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932/go.mod h1:NOuUCSz6Q9T7+igc/hlvDOUdtWKryOrtFyIVABv/p7k=
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4YnC6+E63dPcxHo2sUxDIu8g3QgEJdRY=
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-chi/chi/v5 v5.0.1/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8=
github.com/go-chi/chi/v5 v5.2.0 h1:Aj1EtB0qR2Rdo2dG4O94RIU35w2lvQSj6BRA4+qwFL0=
github.com/go-chi/chi/v5 v5.2.0/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8=
github.com/go-chi/cors v1.2.1 h1:xEC8UT3Rlp2QuWNEr4Fs/c2EAGVKBwy/1vHx3bppil4=
github.com/go-chi/cors v1.2.1/go.mod h1:sSbTewc+6wYHBBCW7ytsFSn836hqM7JxpglAy2Vzc58=
github.com/go-chi/docgen v1.3.0 h1:dmDJ2I+EJfCTrxfgxQDwfR/OpZLTRFKe7EKB8v7yuxI=
github.com/go-chi/docgen v1.3.0/go.mod h1:G9W0G551cs2BFMSn/cnGwX+JBHEloAgo17MBhyrnhPI=
github.com/go-chi/render v1.0.1/go.mod h1:pq4Rr7HbnsdaeHagklXub+p6Wd16Af5l9koip1OvJns=
github.com/go-chi/render v1.0.3 h1:AsXqd2a1/INaIfUSKq3G5uA8weYx20FOsM7uSoCyyt4=
github.com/go-chi/render v1.0.3/go.mod h1:/gr3hVkmYR0YlEy3LxCuVRFzEu9Ruok+gFqbIofjao0=
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4=
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed h1:5upAirOpQc1Q53c0bnx2ufif5kANL7bfZWcc6VJWJd8=
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed/go.mod h1:tMWxXQ9wFIaZeTI9F+hmhFiGpFmhOHzyShyFUhRm0H4=
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA=
github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/scylladb/gocql v1.14.5 h1:lyJKf0m/Vate+8MGiVeRhQNpLVVsL21gvp89zEZdltI=
github.com/scylladb/gocql v1.14.5/go.mod h1:1efi3H0Gr72WCR0W+i+d63FmwmJhDL/zfAC0gMJHVlM=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
golang.org/x/net v0.0.0-20220526153639-5463443f8c37/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc=
gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
sigs.k8s.io/yaml v1.3.0 h1:a2VclLzOGrwOHDiV8EfBGhvjHvP46CtW5j6POvhYGGo=
sigs.k8s.io/yaml v1.3.0/go.mod h1:GeOyir5tyXNByN85N/dRIT9es5UQNerPYEKK56eTBm8=

31
log/log.go Normal file
View File

@@ -0,0 +1,31 @@
package log
import (
"flag"
"log/slog"
"os"
)
func Logger() {
// Get logger arguments
var loglevelStr string
flag.StringVar(&loglevelStr, "loglevel", "ERROR", "set log level")
flag.Parse()
loglevel := new(slog.LevelVar)
if loglevelStr == "DEBUG" {
loglevel.Set(slog.LevelDebug)
} else if loglevelStr == "INFO" {
loglevel.Set(slog.LevelInfo)
} else if loglevelStr == "WARN" {
loglevel.Set(slog.LevelWarn)
} else {
loglevel.Set(slog.LevelError)
}
// Start logger
logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{
Level: loglevel,
}))
slog.SetDefault(logger)
slog.Debug("Logging started", "level", loglevel)
}

34
main.go
View File

@@ -1,9 +1,43 @@
package main
import (
"log/slog"
"os"
"git.dubyatp.xyz/chat-api-server/api"
"git.dubyatp.xyz/chat-api-server/log"
"github.com/joho/godotenv"
)
func checkEnvVars(keys []string) (bool, []string) {
var missing []string
for _, key := range keys {
if _, ok := os.LookupEnv(key); !ok {
missing = append(missing, key)
}
}
return len(missing) == 0, missing
}
func main() {
log.Logger() // initialize logger
requiredEnvVars := []string{"SCYLLA_CLUSTER", "SCYLLA_KEYSPACE"}
// Initialize dotenv file
err := godotenv.Load()
if err != nil {
slog.Info("No .env file loaded, will try OS environment variables")
}
// Check if environment variables were defined by the OS before erroring out
exists, missingVars := checkEnvVars(requiredEnvVars)
if !exists {
slog.Error("Missing environment variables", "missing", missingVars)
os.Exit(1)
}
slog.Info("Starting the API Server...")
api.Start()
}

View File

@@ -1,44 +0,0 @@
[
{
"Body": "hello",
"ID": "1",
"Timestamp": "2024-12-25T05:00:40Z",
"UserID": 1
},
{
"Body": "world",
"ID": "2",
"Timestamp": "2024-12-25T05:00:43Z",
"UserID": 2
},
{
"Body": "abababa",
"ID": "3",
"Timestamp": "2024-12-25T05:01:20Z",
"UserID": 1
},
{
"Body": "bitch",
"ID": "4",
"Timestamp": "2024-12-25T05:05:55Z",
"UserID": 2
},
{
"Body": "NIBBA",
"ID": "5",
"Timestamp": "2025-03-24T14:48:28.249221047-04:00",
"UserID": 1
},
{
"Body": "nibby",
"ID": "6",
"Timestamp": "2025-03-24T14:49:03.246929039-04:00",
"UserID": 1
},
{
"Body": "aaaaababananana",
"ID": "msg_60f70a47-3be2-4315-869a-d6f151ec262a",
"Timestamp": "2025-03-24T15:01:07.14371835-04:00",
"UserID": 1
}
]

View File

@@ -1,10 +0,0 @@
[
{
"ID": 1,
"Name": "duby"
},
{
"ID": 2,
"Name": "astolfo"
}
]

21
vendor/github.com/go-chi/cors/LICENSE generated vendored Normal file
View File

@@ -0,0 +1,21 @@
Copyright (c) 2014 Olivier Poitrey <rs@dailymotion.com>
Copyright (c) 2016-Present https://github.com/go-chi authors
MIT License
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software is furnished to do so,
subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

39
vendor/github.com/go-chi/cors/README.md generated vendored Normal file
View File

@@ -0,0 +1,39 @@
# CORS net/http middleware
[go-chi/cors](https://github.com/go-chi/cors) is a fork of [github.com/rs/cors](https://github.com/rs/cors) that
provides a `net/http` compatible middleware for performing preflight CORS checks on the server side. These headers
are required for using the browser native [Fetch API](https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API).
This middleware is designed to be used as a top-level middleware on the [chi](https://github.com/go-chi/chi) router.
Applying with within a `r.Group()` or using `With()` will not work without routes matching `OPTIONS` added.
## Usage
```go
func main() {
r := chi.NewRouter()
// Basic CORS
// for more ideas, see: https://developer.github.com/v3/#cross-origin-resource-sharing
r.Use(cors.Handler(cors.Options{
// AllowedOrigins: []string{"https://foo.com"}, // Use this to allow specific origin hosts
AllowedOrigins: []string{"https://*", "http://*"},
// AllowOriginFunc: func(r *http.Request, origin string) bool { return true },
AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"},
ExposedHeaders: []string{"Link"},
AllowCredentials: false,
MaxAge: 300, // Maximum value not ignored by any of major browsers
}))
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("welcome"))
})
http.ListenAndServe(":3000", r)
}
```
## Credits
All credit for the original work of this middleware goes out to [github.com/rs](github.com/rs).

400
vendor/github.com/go-chi/cors/cors.go generated vendored Normal file
View File

@@ -0,0 +1,400 @@
// cors package is net/http handler to handle CORS related requests
// as defined by http://www.w3.org/TR/cors/
//
// You can configure it by passing an option struct to cors.New:
//
// c := cors.New(cors.Options{
// AllowedOrigins: []string{"foo.com"},
// AllowedMethods: []string{"GET", "POST", "DELETE"},
// AllowCredentials: true,
// })
//
// Then insert the handler in the chain:
//
// handler = c.Handler(handler)
//
// See Options documentation for more options.
//
// The resulting handler is a standard net/http handler.
package cors
import (
"log"
"net/http"
"os"
"strconv"
"strings"
)
// Options is a configuration container to setup the CORS middleware.
type Options struct {
// AllowedOrigins is a list of origins a cross-domain request can be executed from.
// If the special "*" value is present in the list, all origins will be allowed.
// An origin may contain a wildcard (*) to replace 0 or more characters
// (i.e.: http://*.domain.com). Usage of wildcards implies a small performance penalty.
// Only one wildcard can be used per origin.
// Default value is ["*"]
AllowedOrigins []string
// AllowOriginFunc is a custom function to validate the origin. It takes the origin
// as argument and returns true if allowed or false otherwise. If this option is
// set, the content of AllowedOrigins is ignored.
AllowOriginFunc func(r *http.Request, origin string) bool
// AllowedMethods is a list of methods the client is allowed to use with
// cross-domain requests. Default value is simple methods (HEAD, GET and POST).
AllowedMethods []string
// AllowedHeaders is list of non simple headers the client is allowed to use with
// cross-domain requests.
// If the special "*" value is present in the list, all headers will be allowed.
// Default value is [] but "Origin" is always appended to the list.
AllowedHeaders []string
// ExposedHeaders indicates which headers are safe to expose to the API of a CORS
// API specification
ExposedHeaders []string
// AllowCredentials indicates whether the request can include user credentials like
// cookies, HTTP authentication or client side SSL certificates.
AllowCredentials bool
// MaxAge indicates how long (in seconds) the results of a preflight request
// can be cached
MaxAge int
// OptionsPassthrough instructs preflight to let other potential next handlers to
// process the OPTIONS method. Turn this on if your application handles OPTIONS.
OptionsPassthrough bool
// Debugging flag adds additional output to debug server side CORS issues
Debug bool
}
// Logger generic interface for logger
type Logger interface {
Printf(string, ...interface{})
}
// Cors http handler
type Cors struct {
// Debug logger
Log Logger
// Normalized list of plain allowed origins
allowedOrigins []string
// List of allowed origins containing wildcards
allowedWOrigins []wildcard
// Optional origin validator function
allowOriginFunc func(r *http.Request, origin string) bool
// Normalized list of allowed headers
allowedHeaders []string
// Normalized list of allowed methods
allowedMethods []string
// Normalized list of exposed headers
exposedHeaders []string
maxAge int
// Set to true when allowed origins contains a "*"
allowedOriginsAll bool
// Set to true when allowed headers contains a "*"
allowedHeadersAll bool
allowCredentials bool
optionPassthrough bool
}
// New creates a new Cors handler with the provided options.
func New(options Options) *Cors {
c := &Cors{
exposedHeaders: convert(options.ExposedHeaders, http.CanonicalHeaderKey),
allowOriginFunc: options.AllowOriginFunc,
allowCredentials: options.AllowCredentials,
maxAge: options.MaxAge,
optionPassthrough: options.OptionsPassthrough,
}
if options.Debug && c.Log == nil {
c.Log = log.New(os.Stdout, "[cors] ", log.LstdFlags)
}
// Normalize options
// Note: for origins and methods matching, the spec requires a case-sensitive matching.
// As it may error prone, we chose to ignore the spec here.
// Allowed Origins
if len(options.AllowedOrigins) == 0 {
if options.AllowOriginFunc == nil {
// Default is all origins
c.allowedOriginsAll = true
}
} else {
c.allowedOrigins = []string{}
c.allowedWOrigins = []wildcard{}
for _, origin := range options.AllowedOrigins {
// Normalize
origin = strings.ToLower(origin)
if origin == "*" {
// If "*" is present in the list, turn the whole list into a match all
c.allowedOriginsAll = true
c.allowedOrigins = nil
c.allowedWOrigins = nil
break
} else if i := strings.IndexByte(origin, '*'); i >= 0 {
// Split the origin in two: start and end string without the *
w := wildcard{origin[0:i], origin[i+1:]}
c.allowedWOrigins = append(c.allowedWOrigins, w)
} else {
c.allowedOrigins = append(c.allowedOrigins, origin)
}
}
}
// Allowed Headers
if len(options.AllowedHeaders) == 0 {
// Use sensible defaults
c.allowedHeaders = []string{"Origin", "Accept", "Content-Type"}
} else {
// Origin is always appended as some browsers will always request for this header at preflight
c.allowedHeaders = convert(append(options.AllowedHeaders, "Origin"), http.CanonicalHeaderKey)
for _, h := range options.AllowedHeaders {
if h == "*" {
c.allowedHeadersAll = true
c.allowedHeaders = nil
break
}
}
}
// Allowed Methods
if len(options.AllowedMethods) == 0 {
// Default is spec's "simple" methods
c.allowedMethods = []string{http.MethodGet, http.MethodPost, http.MethodHead}
} else {
c.allowedMethods = convert(options.AllowedMethods, strings.ToUpper)
}
return c
}
// Handler creates a new Cors handler with passed options.
func Handler(options Options) func(next http.Handler) http.Handler {
c := New(options)
return c.Handler
}
// AllowAll create a new Cors handler with permissive configuration allowing all
// origins with all standard methods with any header and credentials.
func AllowAll() *Cors {
return New(Options{
AllowedOrigins: []string{"*"},
AllowedMethods: []string{
http.MethodHead,
http.MethodGet,
http.MethodPost,
http.MethodPut,
http.MethodPatch,
http.MethodDelete,
},
AllowedHeaders: []string{"*"},
AllowCredentials: false,
})
}
// Handler apply the CORS specification on the request, and add relevant CORS headers
// as necessary.
func (c *Cors) Handler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodOptions && r.Header.Get("Access-Control-Request-Method") != "" {
c.logf("Handler: Preflight request")
c.handlePreflight(w, r)
// Preflight requests are standalone and should stop the chain as some other
// middleware may not handle OPTIONS requests correctly. One typical example
// is authentication middleware ; OPTIONS requests won't carry authentication
// headers (see #1)
if c.optionPassthrough {
next.ServeHTTP(w, r)
} else {
w.WriteHeader(http.StatusOK)
}
} else {
c.logf("Handler: Actual request")
c.handleActualRequest(w, r)
next.ServeHTTP(w, r)
}
})
}
// handlePreflight handles pre-flight CORS requests
func (c *Cors) handlePreflight(w http.ResponseWriter, r *http.Request) {
headers := w.Header()
origin := r.Header.Get("Origin")
if r.Method != http.MethodOptions {
c.logf("Preflight aborted: %s!=OPTIONS", r.Method)
return
}
// Always set Vary headers
// see https://github.com/rs/cors/issues/10,
// https://github.com/rs/cors/commit/dbdca4d95feaa7511a46e6f1efb3b3aa505bc43f#commitcomment-12352001
headers.Add("Vary", "Origin")
headers.Add("Vary", "Access-Control-Request-Method")
headers.Add("Vary", "Access-Control-Request-Headers")
if origin == "" {
c.logf("Preflight aborted: empty origin")
return
}
if !c.isOriginAllowed(r, origin) {
c.logf("Preflight aborted: origin '%s' not allowed", origin)
return
}
reqMethod := r.Header.Get("Access-Control-Request-Method")
if !c.isMethodAllowed(reqMethod) {
c.logf("Preflight aborted: method '%s' not allowed", reqMethod)
return
}
reqHeaders := parseHeaderList(r.Header.Get("Access-Control-Request-Headers"))
if !c.areHeadersAllowed(reqHeaders) {
c.logf("Preflight aborted: headers '%v' not allowed", reqHeaders)
return
}
if c.allowedOriginsAll {
headers.Set("Access-Control-Allow-Origin", "*")
} else {
headers.Set("Access-Control-Allow-Origin", origin)
}
// Spec says: Since the list of methods can be unbounded, simply returning the method indicated
// by Access-Control-Request-Method (if supported) can be enough
headers.Set("Access-Control-Allow-Methods", strings.ToUpper(reqMethod))
if len(reqHeaders) > 0 {
// Spec says: Since the list of headers can be unbounded, simply returning supported headers
// from Access-Control-Request-Headers can be enough
headers.Set("Access-Control-Allow-Headers", strings.Join(reqHeaders, ", "))
}
if c.allowCredentials {
headers.Set("Access-Control-Allow-Credentials", "true")
}
if c.maxAge > 0 {
headers.Set("Access-Control-Max-Age", strconv.Itoa(c.maxAge))
}
c.logf("Preflight response headers: %v", headers)
}
// handleActualRequest handles simple cross-origin requests, actual request or redirects
func (c *Cors) handleActualRequest(w http.ResponseWriter, r *http.Request) {
headers := w.Header()
origin := r.Header.Get("Origin")
// Always set Vary, see https://github.com/rs/cors/issues/10
headers.Add("Vary", "Origin")
if origin == "" {
c.logf("Actual request no headers added: missing origin")
return
}
if !c.isOriginAllowed(r, origin) {
c.logf("Actual request no headers added: origin '%s' not allowed", origin)
return
}
// Note that spec does define a way to specifically disallow a simple method like GET or
// POST. Access-Control-Allow-Methods is only used for pre-flight requests and the
// spec doesn't instruct to check the allowed methods for simple cross-origin requests.
// We think it's a nice feature to be able to have control on those methods though.
if !c.isMethodAllowed(r.Method) {
c.logf("Actual request no headers added: method '%s' not allowed", r.Method)
return
}
if c.allowedOriginsAll {
headers.Set("Access-Control-Allow-Origin", "*")
} else {
headers.Set("Access-Control-Allow-Origin", origin)
}
if len(c.exposedHeaders) > 0 {
headers.Set("Access-Control-Expose-Headers", strings.Join(c.exposedHeaders, ", "))
}
if c.allowCredentials {
headers.Set("Access-Control-Allow-Credentials", "true")
}
c.logf("Actual response added headers: %v", headers)
}
// convenience method. checks if a logger is set.
func (c *Cors) logf(format string, a ...interface{}) {
if c.Log != nil {
c.Log.Printf(format, a...)
}
}
// isOriginAllowed checks if a given origin is allowed to perform cross-domain requests
// on the endpoint
func (c *Cors) isOriginAllowed(r *http.Request, origin string) bool {
if c.allowOriginFunc != nil {
return c.allowOriginFunc(r, origin)
}
if c.allowedOriginsAll {
return true
}
origin = strings.ToLower(origin)
for _, o := range c.allowedOrigins {
if o == origin {
return true
}
}
for _, w := range c.allowedWOrigins {
if w.match(origin) {
return true
}
}
return false
}
// isMethodAllowed checks if a given method can be used as part of a cross-domain request
// on the endpoint
func (c *Cors) isMethodAllowed(method string) bool {
if len(c.allowedMethods) == 0 {
// If no method allowed, always return false, even for preflight request
return false
}
method = strings.ToUpper(method)
if method == http.MethodOptions {
// Always allow preflight requests
return true
}
for _, m := range c.allowedMethods {
if m == method {
return true
}
}
return false
}
// areHeadersAllowed checks if a given list of headers are allowed to used within
// a cross-domain request.
func (c *Cors) areHeadersAllowed(requestedHeaders []string) bool {
if c.allowedHeadersAll || len(requestedHeaders) == 0 {
return true
}
for _, header := range requestedHeaders {
header = http.CanonicalHeaderKey(header)
found := false
for _, h := range c.allowedHeaders {
if h == header {
found = true
break
}
}
if !found {
return false
}
}
return true
}

70
vendor/github.com/go-chi/cors/utils.go generated vendored Normal file
View File

@@ -0,0 +1,70 @@
package cors
import "strings"
const toLower = 'a' - 'A'
type converter func(string) string
type wildcard struct {
prefix string
suffix string
}
func (w wildcard) match(s string) bool {
return len(s) >= len(w.prefix+w.suffix) && strings.HasPrefix(s, w.prefix) && strings.HasSuffix(s, w.suffix)
}
// convert converts a list of string using the passed converter function
func convert(s []string, c converter) []string {
out := []string{}
for _, i := range s {
out = append(out, c(i))
}
return out
}
// parseHeaderList tokenize + normalize a string containing a list of headers
func parseHeaderList(headerList string) []string {
l := len(headerList)
h := make([]byte, 0, l)
upper := true
// Estimate the number headers in order to allocate the right splice size
t := 0
for i := 0; i < l; i++ {
if headerList[i] == ',' {
t++
}
}
headers := make([]string, 0, t)
for i := 0; i < l; i++ {
b := headerList[i]
if b >= 'a' && b <= 'z' {
if upper {
h = append(h, b-toLower)
} else {
h = append(h, b)
}
} else if b >= 'A' && b <= 'Z' {
if !upper {
h = append(h, b+toLower)
} else {
h = append(h, b)
}
} else if b == '-' || b == '_' || b == '.' || (b >= '0' && b <= '9') {
h = append(h, b)
}
if b == ' ' || b == ',' || i == l-1 {
if len(h) > 0 {
// Flush the found header
headers = append(headers, string(h))
h = h[:0]
upper = true
}
} else {
upper = b == '-'
}
}
return headers
}

5
vendor/github.com/gocql/gocql/.gitignore generated vendored Normal file
View File

@@ -0,0 +1,5 @@
gocql-fuzz
fuzz-corpus
fuzz-work
gocql.test
.idea

148
vendor/github.com/gocql/gocql/AUTHORS generated vendored Normal file
View File

@@ -0,0 +1,148 @@
# This source file refers to The gocql Authors for copyright purposes.
Christoph Hack <christoph@tux21b.org>
Jonathan Rudenberg <jonathan@titanous.com>
Thorsten von Eicken <tve@rightscale.com>
Matt Robenolt <mattr@disqus.com>
Phillip Couto <phillip.couto@stemstudios.com>
Niklas Korz <korz.niklask@gmail.com>
Nimi Wariboko Jr <nimi@channelmeter.com>
Ghais Issa <ghais.issa@gmail.com>
Sasha Klizhentas <klizhentas@gmail.com>
Konstantin Cherkasov <k.cherkasoff@gmail.com>
Ben Hood <0x6e6562@gmail.com>
Pete Hopkins <phopkins@gmail.com>
Chris Bannister <c.bannister@gmail.com>
Maxim Bublis <b@codemonkey.ru>
Alex Zorin <git@zor.io>
Kasper Middelboe Petersen <me@phant.dk>
Harpreet Sawhney <harpreet.sawhney@gmail.com>
Charlie Andrews <charlieandrews.cwa@gmail.com>
Stanislavs Koikovs <stanislavs.koikovs@gmail.com>
Dan Forest <bonjour@dan.tf>
Miguel Serrano <miguelvps@gmail.com>
Stefan Radomski <gibheer@zero-knowledge.org>
Josh Wright <jshwright@gmail.com>
Jacob Rhoden <jacob.rhoden@gmail.com>
Ben Frye <benfrye@gmail.com>
Fred McCann <fred@sharpnoodles.com>
Dan Simmons <dan@simmons.io>
Muir Manders <muir@retailnext.net>
Sankar P <sankar.curiosity@gmail.com>
Julien Da Silva <julien.dasilva@gmail.com>
Dan Kennedy <daniel@firstcs.co.uk>
Nick Dhupia<nick.dhupia@gmail.com>
Yasuharu Goto <matope.ono@gmail.com>
Jeremy Schlatter <jeremy.schlatter@gmail.com>
Matthias Kadenbach <matthias.kadenbach@gmail.com>
Dean Elbaz <elbaz.dean@gmail.com>
Mike Berman <evencode@gmail.com>
Dmitriy Fedorenko <c0va23@gmail.com>
Zach Marcantel <zmarcantel@gmail.com>
James Maloney <jamessagan@gmail.com>
Ashwin Purohit <purohit@gmail.com>
Dan Kinder <dkinder.is.me@gmail.com>
Oliver Beattie <oliver@obeattie.com>
Justin Corpron <jncorpron@gmail.com>
Miles Delahunty <miles.delahunty@gmail.com>
Zach Badgett <zach.badgett@gmail.com>
Maciek Sakrejda <maciek@heroku.com>
Jeff Mitchell <jeffrey.mitchell@gmail.com>
Baptiste Fontaine <b@ptistefontaine.fr>
Matt Heath <matt@mattheath.com>
Jamie Cuthill <jamie.cuthill@gmail.com>
Adrian Casajus <adriancasajus@gmail.com>
John Weldon <johnweldon4@gmail.com>
Adrien Bustany <adrien@bustany.org>
Andrey Smirnov <smirnov.andrey@gmail.com>
Adam Weiner <adamsweiner@gmail.com>
Daniel Cannon <daniel@danielcannon.co.uk>
Johnny Bergström <johnny@joonix.se>
Adriano Orioli <orioli.adriano@gmail.com>
Claudiu Raveica <claudiu.raveica@gmail.com>
Artem Chernyshev <artem.0xD2@gmail.com>
Ference Fu <fym201@msn.com>
LOVOO <opensource@lovoo.com>
nikandfor <nikandfor@gmail.com>
Anthony Woods <awoods@raintank.io>
Alexander Inozemtsev <alexander.inozemtsev@gmail.com>
Rob McColl <rob@robmccoll.com>; <rmccoll@ionicsecurity.com>
Viktor Tönköl <viktor.toenkoel@motionlogic.de>
Ian Lozinski <ian.lozinski@gmail.com>
Michael Highstead <highstead@gmail.com>
Sarah Brown <esbie.is@gmail.com>
Caleb Doxsey <caleb@datadoghq.com>
Frederic Hemery <frederic.hemery@datadoghq.com>
Pekka Enberg <penberg@scylladb.com>
Mark M <m.mim95@gmail.com>
Bartosz Burclaf <burclaf@gmail.com>
Marcus King <marcusking01@gmail.com>
Andrew de Andrade <andrew@deandrade.com.br>
Robert Nix <robert@nicerobot.org>
Nathan Youngman <git@nathany.com>
Charles Law <charles.law@gmail.com>; <claw@conduce.com>
Nathan Davies <nathanjamesdavies@gmail.com>
Bo Blanton <bo.blanton@gmail.com>
Vincent Rischmann <me@vrischmann.me>
Jesse Claven <jesse.claven@gmail.com>
Derrick Wippler <thrawn01@gmail.com>
Leigh McCulloch <leigh@leighmcculloch.com>
Ron Kuris <swcafe@gmail.com>
Raphael Gavache <raphael.gavache@gmail.com>
Yasser Abdolmaleki <yasser@yasser.ca>
Krishnanand Thommandra <devtkrishna@gmail.com>
Blake Atkinson <me@blakeatkinson.com>
Dharmendra Parsaila <d4dharmu@gmail.com>
Nayef Ghattas <nayef.ghattas@datadoghq.com>
Michał Matczuk <mmatczuk@gmail.com>
Ben Krebsbach <ben.krebsbach@gmail.com>
Vivian Mathews <vivian.mathews.3@gmail.com>
Sascha Steinbiss <satta@debian.org>
Seth Rosenblum <seth.t.rosenblum@gmail.com>
Javier Zunzunegui <javier.zunzunegui.b@gmail.com>
Luke Hines <lukehines@protonmail.com>
Zhixin Wen <john.wenzhixin@hotmail.com>
Chang Liu <changliu.it@gmail.com>
Ingo Oeser <nightlyone@gmail.com>
Luke Hines <lukehines@protonmail.com>
Jacob Greenleaf <jacob@jacobgreenleaf.com>
Alex Lourie <alex@instaclustr.com>; <djay.il@gmail.com>
Marco Cadetg <cadetg@gmail.com>
Karl Matthias <karl@matthias.org>
Thomas Meson <zllak@hycik.org>
Martin Sucha <martin.sucha@kiwi.com>; <git@mm.ms47.eu>
Pavel Buchinchik <p.buchinchik@gmail.com>
Rintaro Okamura <rintaro.okamura@gmail.com>
Ivan Boyarkin <ivan.boyarkin@kiwi.com>; <mr.vanboy@gmail.com>
Yura Sokolov <y.sokolov@joom.com>; <funny.falcon@gmail.com>
Jorge Bay <jorgebg@apache.org>
Dmitriy Kozlov <hummerd@mail.ru>
Alexey Romanovsky <alexus1024+gocql@gmail.com>
Jaume Marhuenda Beltran <jaumemarhuenda@gmail.com>
Piotr Dulikowski <piodul@scylladb.com>
Árni Dagur <arni@dagur.eu>
Tushar Das <tushar.das5@gmail.com>
Maxim Vladimirskiy <horkhe@gmail.com>
Bogdan-Ciprian Rusu <bogdanciprian.rusu@crowdstrike.com>
Yuto Doi <yutodoi.seattle@gmail.com>
Krishna Vadali <tejavadali@gmail.com>
Jens-W. Schicke-Uffmann <drahflow@gmx.de>
Ondrej Polakovič <ondrej.polakovic@kiwi.com>
Sergei Karetnikov <sergei.karetnikov@gmail.com>
Stefan Miklosovic <smiklosovic@apache.org>
Adam Burk <amburk@gmail.com>
Valerii Ponomarov <kiparis.kh@gmail.com>
Neal Turett <neal.turett@datadoghq.com>
Doug Schaapveld <djschaap@gmail.com>
Steven Seidman <steven.seidman@datadoghq.com>
Wojciech Przytuła <wojciech.przytula@scylladb.com>
João Reis <joao.reis@datastax.com>
Lauro Ramos Venancio <lauro.venancio@incognia.com>
Dmitry Kropachev <dmitry.kropachev@gmail.com>
Oliver Boyle <pleasedontspamme4321+gocql@gmail.com>
Jackson Fleming <jackson.fleming@instaclustr.com>
Sylwia Szunejko <sylwia.szunejko@scylladb.com>
Karol Baryła <karol.baryla@scylladb.com>
Marcin Mazurek <marcinek.mazurek@gmail.com>
Moguchev Leonid Alekseevich <lmoguchev@ozon.ru>
Julien Lefevre <julien.lefevr@gmail.com>

78
vendor/github.com/gocql/gocql/CONTRIBUTING.md generated vendored Normal file
View File

@@ -0,0 +1,78 @@
# Contributing to gocql
**TL;DR** - this manifesto sets out the bare minimum requirements for submitting a patch to gocql.
This guide outlines the process of landing patches in gocql and the general approach to maintaining the code base.
## Background
The goal of the gocql project is to provide a stable and robust CQL driver for Go. gocql is a community driven project that is coordinated by a small team of core developers.
## Minimum Requirement Checklist
The following is a check list of requirements that need to be satisfied in order for us to merge your patch:
* You should raise a pull request to gocql/gocql on Github
* The pull request has a title that clearly summarizes the purpose of the patch
* The motivation behind the patch is clearly defined in the pull request summary
* Your name and email have been added to the `AUTHORS` file (for copyright purposes)
* The patch will merge cleanly
* The test coverage does not fall below the critical threshold (currently 64%)
* The merge commit passes the regression test suite on Travis
* `go fmt` has been applied to the submitted code
* Notable changes (i.e. new features or changed behavior, bugfixes) are appropriately documented in CHANGELOG.md, functional changes also in godoc
If there are any requirements that can't be reasonably satisfied, please state this either on the pull request or as part of discussion on the mailing list. Where appropriate, the core team may apply discretion and make an exception to these requirements.
## Beyond The Checklist
In addition to stating the hard requirements, there are a bunch of things that we consider when assessing changes to the library. These soft requirements are helpful pointers of how to get a patch landed quicker and with less fuss.
### General QA Approach
The gocql team needs to consider the ongoing maintainability of the library at all times. Patches that look like they will introduce maintenance issues for the team will not be accepted.
Your patch will get merged quicker if you have decent test cases that provide test coverage for the new behavior you wish to introduce.
Unit tests are good, integration tests are even better. An example of a unit test is `marshal_test.go` - this tests the serialization code in isolation. `cassandra_test.go` is an integration test suite that is executed against every version of Cassandra that gocql supports as part of the CI process on Travis.
That said, the point of writing tests is to provide a safety net to catch regressions, so there is no need to go overboard with tests. Remember that the more tests you write, the more code we will have to maintain. So there's a balance to strike there.
### When It's Too Difficult To Automate Testing
There are legitimate examples of where it is infeasible to write a regression test for a change. Never fear, we will still consider the patch and quite possibly accept the change without a test. The gocql team takes a pragmatic approach to testing. At the end of the day, you could be addressing an issue that is too difficult to reproduce in a test suite, but still occurs in a real production app. In this case, your production app is the test case, and we will have to trust that your change is good.
Examples of pull requests that have been accepted without tests include:
* https://github.com/gocql/gocql/pull/181 - this patch would otherwise require a multi-node cluster to be booted as part of the CI build
* https://github.com/gocql/gocql/pull/179 - this bug can only be reproduced under heavy load in certain circumstances
### Sign Off Procedure
Generally speaking, a pull request can get merged by any one of the core gocql team. If your change is minor, chances are that one team member will just go ahead and merge it there and then. As stated earlier, suitable test coverage will increase the likelihood that a single reviewer will assess and merge your change. If your change has no test coverage, or looks like it may have wider implications for the health and stability of the library, the reviewer may elect to refer the change to another team member to achieve consensus before proceeding. Therefore, the tighter and cleaner your patch is, the quicker it will go through the review process.
### Supported Features
gocql is a low level wire driver for Cassandra CQL. By and large, we would like to keep the functional scope of the library as narrow as possible. We think that gocql should be tight and focused, and we will be naturally skeptical of things that could just as easily be implemented in a higher layer. Inevitably you will come across something that could be implemented in a higher layer, save for a minor change to the core API. In this instance, please strike up a conversation with the gocql team. Chances are we will understand what you are trying to achieve and will try to accommodate this in a maintainable way.
### Longer Term Evolution
There are some long term plans for gocql that have to be taken into account when assessing changes. That said, gocql is ultimately a community driven project and we don't have a massive development budget, so sometimes the long term view might need to be de-prioritized ahead of short term changes.
## Officially Supported Server Versions
Currently, the officially supported versions of the Cassandra server include:
* 1.2.18
* 2.0.9
Chances are that gocql will work with many other versions. If you would like us to support a particular version of Cassandra, please start a conversation about what version you'd like us to consider. We are more likely to accept a new version if you help out by extending the regression suite to cover the new version to be supported.
## The Core Dev Team
The core development team includes:
* tux21b
* phillipCouto
* Zariel
* 0x6e6562

27
vendor/github.com/gocql/gocql/LICENSE generated vendored Normal file
View File

@@ -0,0 +1,27 @@
Copyright (c) 2016, The Gocql authors
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

5
vendor/github.com/gocql/gocql/Makefile generated vendored Normal file
View File

@@ -0,0 +1,5 @@
# Makefile to run the Docker cleanup script
clean-old-temporary-docker-images:
@echo "Running Docker Hub image cleanup script..."
python ci/clean-old-temporary-docker-images.py

242
vendor/github.com/gocql/gocql/README.md generated vendored Normal file
View File

@@ -0,0 +1,242 @@
<div align="center">
![Build Passing](https://github.com/scylladb/gocql/workflows/Build/badge.svg)
[![Read the Fork Driver Docs](https://img.shields.io/badge/Read_the_Docs-pkg_go-blue)](https://pkg.go.dev/github.com/scylladb/gocql#section-documentation)
[![Protocol Specs](https://img.shields.io/badge/Protocol_Specs-ScyllaDB_Docs-blue)](https://github.com/scylladb/scylladb/blob/master/docs/dev/protocol-extensions.md)
</div>
<h1 align="center">
Scylla Shard-Aware Fork of [apache/cassandra-gocql-driver](https://github.com/apache/cassandra-gocql-driver)
</h1>
<img src="./.github/assets/logo.svg" width="200" align="left" />
This is a fork of [apache/cassandra-gocql-driver](https://github.com/apache/cassandra-gocql-driver) package that we created at Scylla.
It contains extensions to tokenAwareHostPolicy supported by the Scylla 2.3 and onwards.
It allows driver to select a connection to a particular shard on a host based on the token.
This eliminates passing data between shards and significantly reduces latency.
There are open pull requests to merge the functionality to the upstream project:
* [gocql/gocql#1210](https://github.com/gocql/gocql/pull/1210)
* [gocql/gocql#1211](https://github.com/gocql/gocql/pull/1211).
It also provides support for shard aware ports, a faster way to connect to all shards, details available in [blogpost](https://www.scylladb.com/2021/04/27/connect-faster-to-scylla-with-a-shard-aware-port/).
---
### Table of Contents
- [1. Sunsetting Model](#1-sunsetting-model)
- [2. Installation](#2-installation)
- [3. Quick Start](#3-quick-start)
- [4. Data Types](#4-data-types)
- [5. Configuration](#5-configuration)
- [5.1 Shard-aware port](#51-shard-aware-port)
- [5.2 Iterator](#52-iterator)
- [6. Contributing](#6-contributing)
## 1. Sunsetting Model
> [!WARNING]
> In general, the gocql team will focus on supporting the current and previous versions of Go. gocql may still work with older versions of Go, but official support for these versions will have been sunset.
## 2. Installation
This is a drop-in replacement to gocql, it reuses the `github.com/gocql/gocql` import path.
Add the following line to your project `go.mod` file.
```mod
replace github.com/gocql/gocql => github.com/scylladb/gocql latest
```
and run
```sh
go mod tidy
```
to evaluate `latest` to a concrete tag.
Your project now uses the Scylla driver fork, make sure you are using the `TokenAwareHostPolicy` to enable the shard-awareness, continue reading for details.
## 3. Quick Start
Spawn a ScyllaDB Instance using Docker Run command:
```sh
docker run --name node1 --network your-network -p "9042:9042" -d scylladb/scylla:6.1.2 \
--overprovisioned 1 \
--smp 1
```
Then, create a new connection using ScyllaDB GoCQL following the example below:
```go
package main
import (
"fmt"
"github.com/gocql/gocql"
)
func main() {
var cluster = gocql.NewCluster("localhost:9042")
var session, err = cluster.CreateSession()
if err != nil {
panic("Failed to connect to cluster")
}
defer session.Close()
var query = session.Query("SELECT * FROM system.clients")
if rows, err := query.Iter().SliceMap(); err == nil {
for _, row := range rows {
fmt.Printf("%v\n", row)
}
} else {
panic("Query error: " + err.Error())
}
}
```
## 4. Data Types
Here's an list of all ScyllaDB Types reflected in the GoCQL environment:
| ScyllaDB Type | Go Type |
| ---------------- | ------------------ |
| `ascii` | `string` |
| `bigint` | `int64` |
| `blob` | `[]byte` |
| `boolean` | `bool` |
| `date` | `time.Time` |
| `decimal` | `inf.Dec` |
| `double` | `float64` |
| `duration` | `gocql.Duration` |
| `float` | `float32` |
| `uuid` | `gocql.UUID` |
| `int` | `int32` |
| `inet` | `string` |
| `list<int>` | `[]int32` |
| `map<int, text>` | `map[int32]string` |
| `set<int>` | `[]int32` |
| `smallint` | `int16` |
| `text` | `string` |
| `time` | `time.Duration` |
| `timestamp` | `time.Time` |
| `timeuuid` | `gocql.UUID` |
| `tinyint` | `int8` |
| `varchar` | `string` |
| `varint` | `int64` |
## 5. Configuration
In order to make shard-awareness work, token aware host selection policy has to be enabled.
Please make sure that the gocql configuration has `PoolConfig.HostSelectionPolicy` properly set like in the example below.
__When working with a Scylla cluster, `PoolConfig.NumConns` option has no effect - the driver opens one connection for each shard and completely ignores this option.__
```go
c := gocql.NewCluster(hosts...)
// Enable token aware host selection policy, if using multi-dc cluster set a local DC.
fallback := gocql.RoundRobinHostPolicy()
if localDC != "" {
fallback = gocql.DCAwareRoundRobinPolicy(localDC)
}
c.PoolConfig.HostSelectionPolicy = gocql.TokenAwareHostPolicy(fallback)
// If using multi-dc cluster use the "local" consistency levels.
if localDC != "" {
c.Consistency = gocql.LocalQuorum
}
// When working with a Scylla cluster the driver always opens one connection per shard, so `NumConns` is ignored.
// c.NumConns = 4
```
### 5.1 Shard-aware port
This version of gocql supports a more robust method of establishing connection for each shard by using _shard aware port_ for native transport.
It greatly reduces time and the number of connections needed to establish a connection per shard in some cases - ex. when many clients connect at once, or when there are non-shard-aware clients connected to the same cluster.
If you are using a custom Dialer and if your nodes expose the shard-aware port, it is highly recommended to update it so that it uses a specific source port when connecting.
* If you are using a custom `net.Dialer`, you can make your dialer honor the source port by wrapping it in a `gocql.ScyllaShardAwareDialer`:
```go
oldDialer := net.Dialer{...}
clusterConfig.Dialer := &gocql.ScyllaShardAwareDialer{oldDialer}
```
* If you are using a custom type implementing `gocql.Dialer`, you can get the source port by using the `gocql.ScyllaGetSourcePort` function.
An example:
```go
func (d *myDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
sourcePort := gocql.ScyllaGetSourcePort(ctx)
localAddr, err := net.ResolveTCPAddr(network, fmt.Sprintf(":%d", sourcePort))
if err != nil {
return nil, err
}
d := &net.Dialer{LocalAddr: localAddr}
return d.DialContext(ctx, network, addr)
}
```
The source port might be already bound by another connection on your system.
In such case, you should return an appropriate error so that the driver can retry with a different port suitable for the shard it tries to connect to.
* If you are using `net.Dialer.DialContext`, this function will return an error in case the source port is unavailable, and you can just return that error from your custom `Dialer`.
* Otherwise, if you detect that the source port is unavailable, you can either return `gocql.ErrScyllaSourcePortAlreadyInUse` or `syscall.EADDRINUSE`.
For this feature to work correctly, you need to make sure the following conditions are met:
* Your cluster nodes are configured to listen on the shard-aware port (`native_shard_aware_transport_port` option),
* Your cluster nodes are not behind a NAT which changes source ports,
* If you have a custom Dialer, it connects from the correct source port (see the guide above).
The feature is designed to gracefully fall back to the using the non-shard-aware port when it detects that some of the above conditions are not met.
The driver will print a warning about misconfigured address translation if it detects it.
Issues with shard-aware port not being reachable are not reported in non-debug mode, because there is no way to detect it without false positives.
If you suspect that this feature is causing you problems, you can completely disable it by setting the `ClusterConfig.DisableShardAwarePort` flag to true.
### 5.2 Iterator
Paging is a way to parse large result sets in smaller chunks.
The driver provides an iterator to simplify this process.
Use `Query.Iter()` to obtain iterator:
```go
iter := session.Query("SELECT id, value FROM my_table WHERE id > 100 AND id < 10000").Iter()
var results []int
var id, value int
for !iter.Scan(&id, &value) {
if id%2 == 0 {
results = append(results, value)
}
}
if err := iter.Close(); err != nil {
// handle error
}
```
In case of range and `ALLOW FILTERING` queries server can send empty responses for some pages.
That is why you should never consider empty response as the end of the result set.
Always check `iter.Scan()` result to know if there are more results, or `Iter.LastPage()` to know if the last page was reached.
## 6. Contributing
If you have any interest to be contributing in this GoCQL Fork, please read the [CONTRIBUTING.md](CONTRIBUTING.md) before initialize any Issue or Pull Request.

26
vendor/github.com/gocql/gocql/address_translators.go generated vendored Normal file
View File

@@ -0,0 +1,26 @@
package gocql
import "net"
// AddressTranslator provides a way to translate node addresses (and ports) that are
// discovered or received as a node event. This can be useful in an ec2 environment,
// for instance, to translate public IPs to private IPs.
type AddressTranslator interface {
// Translate will translate the provided address and/or port to another
// address and/or port. If no translation is possible, Translate will return the
// address and port provided to it.
Translate(addr net.IP, port int) (net.IP, int)
}
type AddressTranslatorFunc func(addr net.IP, port int) (net.IP, int)
func (fn AddressTranslatorFunc) Translate(addr net.IP, port int) (net.IP, int) {
return fn(addr, port)
}
// IdentityTranslator will do nothing but return what it was provided. It is essentially a no-op.
func IdentityTranslator() AddressTranslator {
return AddressTranslatorFunc(func(addr net.IP, port int) (net.IP, int) {
return addr, port
})
}

541
vendor/github.com/gocql/gocql/cluster.go generated vendored Normal file
View File

@@ -0,0 +1,541 @@
// Copyright (c) 2012 The gocql Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gocql
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io/ioutil"
"net"
"sync/atomic"
"time"
)
const defaultDriverName = "ScyllaDB GoCQL Driver"
// PoolConfig configures the connection pool used by the driver, it defaults to
// using a round-robin host selection policy and a round-robin connection selection
// policy for each host.
type PoolConfig struct {
// HostSelectionPolicy sets the policy for selecting which host to use for a
// given query (default: RoundRobinHostPolicy())
// It is not supported to use a single HostSelectionPolicy in multiple sessions
// (even if you close the old session before using in a new session).
HostSelectionPolicy HostSelectionPolicy
}
func (p PoolConfig) buildPool(session *Session) *policyConnPool {
return newPolicyConnPool(session)
}
// ClusterConfig is a struct to configure the default cluster implementation
// of gocql. It has a variety of attributes that can be used to modify the
// behavior to fit the most common use cases. Applications that require a
// different setup must implement their own cluster.
type ClusterConfig struct {
// addresses for the initial connections. It is recommended to use the value set in
// the Cassandra config for broadcast_address or listen_address, an IP address not
// a domain name. This is because events from Cassandra will use the configured IP
// address, which is used to index connected hosts. If the domain name specified
// resolves to more than 1 IP address then the driver may connect multiple times to
// the same host, and will not mark the node being down or up from events.
Hosts []string
// CQL version (default: 3.0.0)
CQLVersion string
// ProtoVersion sets the version of the native protocol to use, this will
// enable features in the driver for specific protocol versions, generally this
// should be set to a known version (2,3,4) for the cluster being connected to.
//
// If it is 0 or unset (the default) then the driver will attempt to discover the
// highest supported protocol for the cluster. In clusters with nodes of different
// versions the protocol selected is not defined (ie, it can be any of the supported in the cluster)
ProtoVersion int
// Timeout limits the time spent on the client side while executing a query.
// Specifically, query or batch execution will return an error if the client does not receive a response
// from the server within the Timeout period.
// Timeout is also used to configure the read timeout on the underlying network connection.
// Client Timeout should always be higher than the request timeouts configured on the server,
// so that retries don't overload the server.
// Timeout has a default value of 11 seconds, which is higher than default server timeout for most query types.
// Timeout is not applied to requests during initial connection setup, see ConnectTimeout.
Timeout time.Duration
// ConnectTimeout limits the time spent during connection setup.
// During initial connection setup, internal queries, AUTH requests will return an error if the client
// does not receive a response within the ConnectTimeout period.
// ConnectTimeout is applied to the connection setup queries independently.
// ConnectTimeout also limits the duration of dialing a new TCP connection
// in case there is no Dialer nor HostDialer configured.
// ConnectTimeout has a default value of 11 seconds.
ConnectTimeout time.Duration
// WriteTimeout limits the time the driver waits to write a request to a network connection.
// WriteTimeout should be lower than or equal to Timeout.
// WriteTimeout defaults to the value of Timeout.
WriteTimeout time.Duration
// Port used when dialing.
// Default: 9042
Port int
// Initial keyspace. Optional.
Keyspace string
// The size of the connection pool for each host.
// The pool filling runs in separate gourutine during the session initialization phase.
// gocql will always try to get 1 connection on each host pool
// during session initialization AND it will attempt
// to fill each pool afterward asynchronously if NumConns > 1.
// Notice: There is no guarantee that pool filling will be finished in the initialization phase.
// Also, it describes a maximum number of connections at the same time.
// Default: 2
NumConns int
// Maximum number of inflight requests allowed per connection.
// Default: 32768 for CQL v3 and newer
// Default: 128 for older CQL versions
MaxRequestsPerConn int
// Default consistency level.
// Default: Quorum
Consistency Consistency
// Compression algorithm.
// Default: nil
Compressor Compressor
// Default: nil
Authenticator Authenticator
WarningsHandlerBuilder WarningHandlerBuilder
// An Authenticator factory. Can be used to create alternative authenticators.
// Default: nil
AuthProvider func(h *HostInfo) (Authenticator, error)
// Default retry policy to use for queries.
// Default: no retries.
RetryPolicy RetryPolicy
// ConvictionPolicy decides whether to mark host as down based on the error and host info.
// Default: SimpleConvictionPolicy
ConvictionPolicy ConvictionPolicy
// Default reconnection policy to use for reconnecting before trying to mark host as down.
ReconnectionPolicy ReconnectionPolicy
// A reconnection policy to use for reconnecting when connecting to the cluster first time.
InitialReconnectionPolicy ReconnectionPolicy
// The keepalive period to use, enabled if > 0 (default: 15 seconds)
// SocketKeepalive is used to set up the default dialer and is ignored if Dialer or HostDialer is provided.
SocketKeepalive time.Duration
// Maximum cache size for prepared statements globally for gocql.
// Default: 1000
MaxPreparedStmts int
// Maximum cache size for query info about statements for each session.
// Default: 1000
MaxRoutingKeyInfo int
// Default page size to use for created sessions.
// Default: 5000
PageSize int
// Consistency for the serial part of queries, values can be either SERIAL or LOCAL_SERIAL.
// Default: unset
SerialConsistency SerialConsistency
// SslOpts configures TLS use when HostDialer is not set.
// SslOpts is ignored if HostDialer is set.
SslOpts *SslOptions
actualSslOpts atomic.Value
// Sends a client side timestamp for all requests which overrides the timestamp at which it arrives at the server.
// Default: true, only enabled for protocol 3 and above.
DefaultTimestamp bool
// The name of the driver that is going to be reported to the server.
// Default: "ScyllaDB GoLang Driver"
DriverName string
// The version of the driver that is going to be reported to the server.
// Defaulted to current library version
DriverVersion string
// PoolConfig configures the underlying connection pool, allowing the
// configuration of host selection and connection selection policies.
PoolConfig PoolConfig
// If not zero, gocql attempt to reconnect known DOWN nodes in every ReconnectInterval.
ReconnectInterval time.Duration
// The maximum amount of time to wait for schema agreement in a cluster after
// receiving a schema change frame. (default: 60s)
MaxWaitSchemaAgreement time.Duration
// HostFilter will filter all incoming events for host, any which don't pass
// the filter will be ignored. If set will take precedence over any options set
// via Discovery
HostFilter HostFilter
// AddressTranslator will translate addresses found on peer discovery and/or
// node change events.
AddressTranslator AddressTranslator
// If IgnorePeerAddr is true and the address in system.peers does not match
// the supplied host by either initial hosts or discovered via events then the
// host will be replaced with the supplied address.
//
// For example if an event comes in with host=10.0.0.1 but when looking up that
// address in system.local or system.peers returns 127.0.0.1, the peer will be
// set to 10.0.0.1 which is what will be used to connect to.
IgnorePeerAddr bool
// If DisableInitialHostLookup then the driver will not attempt to get host info
// from the system.peers table, this will mean that the driver will connect to
// hosts supplied and will not attempt to lookup the hosts information, this will
// mean that data_centre, rack and token information will not be available and as
// such host filtering and token aware query routing will not be available.
DisableInitialHostLookup bool
// Configure events the driver will register for
Events struct {
// disable registering for status events (node up/down)
DisableNodeStatusEvents bool
// disable registering for topology events (node added/removed/moved)
DisableTopologyEvents bool
// disable registering for schema events (keyspace/table/function removed/created/updated)
DisableSchemaEvents bool
}
// DisableSkipMetadata will override the internal result metadata cache so that the driver does not
// send skip_metadata for queries, this means that the result will always contain
// the metadata to parse the rows and will not reuse the metadata from the prepared
// statement.
//
// See https://issues.apache.org/jira/browse/CASSANDRA-10786
// See https://github.com/scylladb/scylladb/issues/20860
//
// Default: true
DisableSkipMetadata bool
// QueryObserver will set the provided query observer on all queries created from this session.
// Use it to collect metrics / stats from queries by providing an implementation of QueryObserver.
QueryObserver QueryObserver
// BatchObserver will set the provided batch observer on all queries created from this session.
// Use it to collect metrics / stats from batch queries by providing an implementation of BatchObserver.
BatchObserver BatchObserver
// ConnectObserver will set the provided connect observer on all queries
// created from this session.
ConnectObserver ConnectObserver
// FrameHeaderObserver will set the provided frame header observer on all frames' headers created from this session.
// Use it to collect metrics / stats from frames by providing an implementation of FrameHeaderObserver.
FrameHeaderObserver FrameHeaderObserver
// StreamObserver will be notified of stream state changes.
// This can be used to track in-flight protocol requests and responses.
StreamObserver StreamObserver
// Default idempotence for queries
DefaultIdempotence bool
// The time to wait for frames before flushing the frames connection to Cassandra.
// Can help reduce syscall overhead by making less calls to write. Set to 0 to
// disable.
//
// (default: 200 microseconds)
WriteCoalesceWaitTime time.Duration
// Dialer will be used to establish all connections created for this Cluster.
// If not provided, a default dialer configured with ConnectTimeout will be used.
// Dialer is ignored if HostDialer is provided.
Dialer Dialer
// HostDialer will be used to establish all connections for this Cluster.
// Unlike Dialer, HostDialer is responsible for setting up the entire connection, including the TLS session.
// To support shard-aware port, HostDialer should implement ShardDialer.
// If not provided, Dialer will be used instead.
HostDialer HostDialer
// DisableShardAwarePort will prevent the driver from connecting to Scylla's shard-aware port,
// even if there are nodes in the cluster that support it.
//
// It is generally recommended to leave this option turned off because gocql can use
// the shard-aware port to make the process of establishing more robust.
// However, if you have a cluster with nodes which expose shard-aware port
// but the port is unreachable due to network configuration issues, you can use
// this option to work around the issue. Set it to true only if you neither can fix
// your network nor disable shard-aware port on your nodes.
DisableShardAwarePort bool
// Logger for this ClusterConfig.
// If not specified, defaults to the global gocql.Logger.
Logger StdLogger
// The timeout for the requests to the schema tables. (default: 60s)
MetadataSchemaRequestTimeout time.Duration
// internal config for testing
disableControlConn bool
disableInit bool
}
type Dialer interface {
DialContext(ctx context.Context, network, addr string) (net.Conn, error)
}
// NewCluster generates a new config for the default cluster implementation.
//
// The supplied hosts are used to initially connect to the cluster then the rest of
// the ring will be automatically discovered. It is recommended to use the value set in
// the Cassandra config for broadcast_address or listen_address, an IP address not
// a domain name. This is because events from Cassandra will use the configured IP
// address, which is used to index connected hosts. If the domain name specified
// resolves to more than 1 IP address then the driver may connect multiple times to
// the same host, and will not mark the node being down or up from events.
func NewCluster(hosts ...string) *ClusterConfig {
cfg := &ClusterConfig{
Hosts: hosts,
CQLVersion: "3.0.0",
Timeout: 11 * time.Second,
ConnectTimeout: 11 * time.Second,
Port: 9042,
NumConns: 2,
Consistency: Quorum,
MaxPreparedStmts: defaultMaxPreparedStmts,
MaxRoutingKeyInfo: 1000,
PageSize: 5000,
DefaultTimestamp: true,
DriverName: defaultDriverName,
DriverVersion: defaultDriverVersion,
MaxWaitSchemaAgreement: 60 * time.Second,
ReconnectInterval: 60 * time.Second,
ConvictionPolicy: &SimpleConvictionPolicy{},
ReconnectionPolicy: &ConstantReconnectionPolicy{MaxRetries: 3, Interval: 1 * time.Second},
InitialReconnectionPolicy: &NoReconnectionPolicy{},
SocketKeepalive: 15 * time.Second,
WriteCoalesceWaitTime: 200 * time.Microsecond,
MetadataSchemaRequestTimeout: 60 * time.Second,
DisableSkipMetadata: true,
WarningsHandlerBuilder: DefaultWarningHandlerBuilder,
}
return cfg
}
func (cfg *ClusterConfig) logger() StdLogger {
if cfg.Logger == nil {
return Logger
}
return cfg.Logger
}
// CreateSession initializes the cluster based on this config and returns a
// session object that can be used to interact with the database.
func (cfg *ClusterConfig) CreateSession() (*Session, error) {
return NewSession(*cfg)
}
func (cfg *ClusterConfig) CreateSessionNonBlocking() (*Session, error) {
return NewSessionNonBlocking(*cfg)
}
// translateAddressPort is a helper method that will use the given AddressTranslator
// if defined, to translate the given address and port into a possibly new address
// and port, If no AddressTranslator or if an error occurs, the given address and
// port will be returned.
func (cfg *ClusterConfig) translateAddressPort(addr net.IP, port int) (net.IP, int) {
if cfg.AddressTranslator == nil || len(addr) == 0 {
return addr, port
}
newAddr, newPort := cfg.AddressTranslator.Translate(addr, port)
if gocqlDebug {
cfg.logger().Printf("gocql: translating address '%v:%d' to '%v:%d'", addr, port, newAddr, newPort)
}
return newAddr, newPort
}
func (cfg *ClusterConfig) filterHost(host *HostInfo) bool {
return !(cfg.HostFilter == nil || cfg.HostFilter.Accept(host))
}
func (cfg *ClusterConfig) ValidateAndInitSSL() error {
if cfg.SslOpts == nil {
return nil
}
actualTLSConfig, err := setupTLSConfig(cfg.SslOpts)
if err != nil {
return fmt.Errorf("failed to initialize ssl configuration: %s", err.Error())
}
cfg.actualSslOpts.Store(actualTLSConfig)
return nil
}
func (cfg *ClusterConfig) getActualTLSConfig() *tls.Config {
val, ok := cfg.actualSslOpts.Load().(*tls.Config)
if !ok {
return nil
}
return val.Clone()
}
func (cfg *ClusterConfig) Validate() error {
if len(cfg.Hosts) == 0 {
return ErrNoHosts
}
if cfg.Authenticator != nil && cfg.AuthProvider != nil {
return errors.New("Can't use both Authenticator and AuthProvider in cluster config.")
}
if cfg.InitialReconnectionPolicy == nil {
return errors.New("InitialReconnectionPolicy is nil")
}
if cfg.InitialReconnectionPolicy.GetMaxRetries() <= 0 {
return errors.New("InitialReconnectionPolicy.GetMaxRetries returns negative number")
}
if cfg.ReconnectionPolicy == nil {
return errors.New("ReconnectionPolicy is nil")
}
if cfg.InitialReconnectionPolicy.GetMaxRetries() <= 0 {
return errors.New("ReconnectionPolicy.GetMaxRetries returns negative number")
}
if cfg.PageSize < 0 {
return errors.New("PageSize should be positive number or zero")
}
if cfg.MaxRoutingKeyInfo < 0 {
return errors.New("MaxRoutingKeyInfo should be positive number or zero")
}
if cfg.MaxPreparedStmts < 0 {
return errors.New("MaxPreparedStmts should be positive number or zero")
}
if cfg.SocketKeepalive < 0 {
return errors.New("SocketKeepalive should be positive time.Duration or zero")
}
if cfg.MaxRequestsPerConn < 0 {
return errors.New("MaxRequestsPerConn should be positive number or zero")
}
if cfg.NumConns < 0 {
return errors.New("NumConns should be positive non-zero number or zero")
}
if cfg.Port <= 0 || cfg.Port > 65535 {
return errors.New("Port should be a valid port number: a number between 1 and 65535")
}
if cfg.WriteTimeout < 0 {
return errors.New("WriteTimeout should be positive time.Duration or zero")
}
if cfg.Timeout < 0 {
return errors.New("Timeout should be positive time.Duration or zero")
}
if cfg.ConnectTimeout < 0 {
return errors.New("ConnectTimeout should be positive time.Duration or zero")
}
if cfg.MetadataSchemaRequestTimeout < 0 {
return errors.New("MetadataSchemaRequestTimeout should be positive time.Duration or zero")
}
if cfg.WriteCoalesceWaitTime < 0 {
return errors.New("WriteCoalesceWaitTime should be positive time.Duration or zero")
}
if cfg.ReconnectInterval < 0 {
return errors.New("ReconnectInterval should be positive time.Duration or zero")
}
if cfg.MaxWaitSchemaAgreement < 0 {
return errors.New("MaxWaitSchemaAgreement should be positive time.Duration or zero")
}
if cfg.ProtoVersion < 0 {
return errors.New("ProtoVersion should be positive number or zero")
}
if !cfg.DisableSkipMetadata {
Logger.Println("warning: enabling skipping metadata can lead to unpredictible results when executing query and altering columns involved in the query.")
}
return cfg.ValidateAndInitSSL()
}
var (
ErrNoHosts = errors.New("no hosts provided")
ErrNoConnectionsStarted = errors.New("no connections were made when creating the session")
ErrHostQueryFailed = errors.New("unable to populate Hosts")
)
func setupTLSConfig(sslOpts *SslOptions) (*tls.Config, error) {
// Config.InsecureSkipVerify | EnableHostVerification | Result
// Config is nil | true | verify host
// Config is nil | false | do not verify host
// false | false | verify host
// true | false | do not verify host
// false | true | verify host
// true | true | verify host
var tlsConfig *tls.Config
if sslOpts.Config == nil {
tlsConfig = &tls.Config{
InsecureSkipVerify: !sslOpts.EnableHostVerification,
}
} else {
// use clone to avoid race.
tlsConfig = sslOpts.Config.Clone()
}
if tlsConfig.InsecureSkipVerify && sslOpts.EnableHostVerification {
tlsConfig.InsecureSkipVerify = false
}
// ca cert is optional
if sslOpts.CaPath != "" {
if tlsConfig.RootCAs == nil {
tlsConfig.RootCAs = x509.NewCertPool()
}
pem, err := ioutil.ReadFile(sslOpts.CaPath)
if err != nil {
return nil, fmt.Errorf("unable to open CA certs: %v", err)
}
if !tlsConfig.RootCAs.AppendCertsFromPEM(pem) {
return nil, errors.New("failed parsing or CA certs")
}
}
if sslOpts.CertPath != "" || sslOpts.KeyPath != "" {
mycert, err := tls.LoadX509KeyPair(sslOpts.CertPath, sslOpts.KeyPath)
if err != nil {
return nil, fmt.Errorf("unable to load X509 key pair: %v", err)
}
tlsConfig.Certificates = append(tlsConfig.Certificates, mycert)
}
return tlsConfig, nil
}

29
vendor/github.com/gocql/gocql/compressor.go generated vendored Normal file
View File

@@ -0,0 +1,29 @@
package gocql
import (
"github.com/klauspost/compress/s2"
)
type Compressor interface {
Name() string
Encode(data []byte) ([]byte, error)
Decode(data []byte) ([]byte, error)
}
// SnappyCompressor implements the Compressor interface and can be used to
// compress incoming and outgoing frames. It uses S2 compression algorithm
// that is compatible with snappy and aims for high throughput, which is why
// it features concurrent compression for bigger payloads.
type SnappyCompressor struct{}
func (s SnappyCompressor) Name() string {
return "snappy"
}
func (s SnappyCompressor) Encode(data []byte) ([]byte, error) {
return s2.EncodeSnappy(nil, data), nil
}
func (s SnappyCompressor) Decode(data []byte) ([]byte, error) {
return s2.Decode(nil, data)
}

1964
vendor/github.com/gocql/gocql/conn.go generated vendored Normal file

File diff suppressed because it is too large Load Diff

562
vendor/github.com/gocql/gocql/connectionpool.go generated vendored Normal file
View File

@@ -0,0 +1,562 @@
// Copyright (c) 2012 The gocql Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gocql
import (
"fmt"
"math/rand"
"net"
"sync"
"time"
"github.com/gocql/gocql/debounce"
)
// interface to implement to receive the host information
type SetHosts interface {
SetHosts(hosts []*HostInfo)
}
// interface to implement to receive the partitioner value
type SetPartitioner interface {
SetPartitioner(partitioner string)
}
// interface to implement to receive the tablets value
type SetTablets interface {
SetTablets(tablets TabletInfoList)
}
type policyConnPool struct {
session *Session
port int
numConns int
keyspace string
mu sync.RWMutex
hostConnPools map[string]*hostConnPool
}
func connConfig(cfg *ClusterConfig) (*ConnConfig, error) {
hostDialer := cfg.HostDialer
if hostDialer == nil {
dialer := cfg.Dialer
if dialer == nil {
d := net.Dialer{
Timeout: cfg.ConnectTimeout,
}
if cfg.SocketKeepalive > 0 {
d.KeepAlive = cfg.SocketKeepalive
}
dialer = &ScyllaShardAwareDialer{d}
}
hostDialer = &scyllaDialer{
dialer: dialer,
logger: cfg.logger(),
tlsConfig: cfg.getActualTLSConfig(),
cfg: cfg,
}
}
return &ConnConfig{
ProtoVersion: cfg.ProtoVersion,
CQLVersion: cfg.CQLVersion,
Timeout: cfg.Timeout,
WriteTimeout: cfg.WriteTimeout,
ConnectTimeout: cfg.ConnectTimeout,
Dialer: cfg.Dialer,
HostDialer: hostDialer,
Compressor: cfg.Compressor,
Authenticator: cfg.Authenticator,
AuthProvider: cfg.AuthProvider,
Keepalive: cfg.SocketKeepalive,
Logger: cfg.logger(),
tlsConfig: cfg.getActualTLSConfig(),
}, nil
}
func newPolicyConnPool(session *Session) *policyConnPool {
// create the pool
pool := &policyConnPool{
session: session,
port: session.cfg.Port,
numConns: session.cfg.NumConns,
keyspace: session.cfg.Keyspace,
hostConnPools: map[string]*hostConnPool{},
}
return pool
}
func (p *policyConnPool) SetHosts(hosts []*HostInfo) {
p.mu.Lock()
defer p.mu.Unlock()
toRemove := make(map[string]struct{})
for hostID := range p.hostConnPools {
toRemove[hostID] = struct{}{}
}
pools := make(chan *hostConnPool)
createCount := 0
for _, host := range hosts {
if !host.IsUp() {
// don't create a connection pool for a down host
continue
}
hostID := host.HostID()
if _, exists := p.hostConnPools[hostID]; exists {
// still have this host, so don't remove it
delete(toRemove, hostID)
continue
}
createCount++
go func(host *HostInfo) {
// create a connection pool for the host
pools <- newHostConnPool(
p.session,
host,
p.port,
p.numConns,
p.keyspace,
)
}(host)
}
// add created pools
for createCount > 0 {
pool := <-pools
createCount--
if pool.Size() > 0 {
// add pool only if there a connections available
p.hostConnPools[pool.host.HostID()] = pool
}
}
for addr := range toRemove {
pool := p.hostConnPools[addr]
delete(p.hostConnPools, addr)
go pool.Close()
}
}
func (p *policyConnPool) InFlight() int {
p.mu.RLock()
count := 0
for _, pool := range p.hostConnPools {
count += pool.InFlight()
}
p.mu.RUnlock()
return count
}
func (p *policyConnPool) Size() int {
p.mu.RLock()
count := 0
for _, pool := range p.hostConnPools {
count += pool.Size()
}
p.mu.RUnlock()
return count
}
func (p *policyConnPool) getPool(host *HostInfo) (pool *hostConnPool, ok bool) {
hostID := host.HostID()
p.mu.RLock()
pool, ok = p.hostConnPools[hostID]
p.mu.RUnlock()
return
}
func (p *policyConnPool) Close() {
p.mu.Lock()
defer p.mu.Unlock()
// close the pools
for addr, pool := range p.hostConnPools {
delete(p.hostConnPools, addr)
pool.Close()
}
}
func (p *policyConnPool) addHost(host *HostInfo) {
hostID := host.HostID()
p.mu.Lock()
pool, ok := p.hostConnPools[hostID]
if !ok {
pool = newHostConnPool(
p.session,
host,
host.Port(), // TODO: if port == 0 use pool.port?
p.numConns,
p.keyspace,
)
p.hostConnPools[hostID] = pool
}
p.mu.Unlock()
pool.fill_debounce()
}
func (p *policyConnPool) removeHost(hostID string) {
p.mu.Lock()
pool, ok := p.hostConnPools[hostID]
if !ok {
p.mu.Unlock()
return
}
delete(p.hostConnPools, hostID)
p.mu.Unlock()
go pool.Close()
}
// hostConnPool is a connection pool for a single host.
// Connection selection is based on a provided ConnSelectionPolicy
type hostConnPool struct {
session *Session
host *HostInfo
size int
keyspace string
// protection for connPicker, closed, filling
mu sync.RWMutex
connPicker ConnPicker
closed bool
filling bool
debouncer *debounce.SimpleDebouncer
logger StdLogger
}
func (h *hostConnPool) String() string {
h.mu.RLock()
defer h.mu.RUnlock()
size, _ := h.connPicker.Size()
return fmt.Sprintf("[filling=%v closed=%v conns=%v size=%v host=%v]",
h.filling, h.closed, size, h.size, h.host)
}
func newHostConnPool(session *Session, host *HostInfo, port, size int,
keyspace string) *hostConnPool {
pool := &hostConnPool{
session: session,
host: host,
size: size,
keyspace: keyspace,
connPicker: nopConnPicker{},
filling: false,
closed: false,
logger: session.logger,
debouncer: debounce.NewSimpleDebouncer(),
}
// the pool is not filled or connected
return pool
}
// Pick a connection from this connection pool for the given query.
func (pool *hostConnPool) Pick(token Token, qry ExecutableQuery) *Conn {
pool.mu.RLock()
defer pool.mu.RUnlock()
if pool.closed {
return nil
}
size, missing := pool.connPicker.Size()
if missing > 0 {
// try to fill the pool
go pool.fill_debounce()
if size == 0 {
return nil
}
}
return pool.connPicker.Pick(token, qry)
}
// Size returns the number of connections currently active in the pool
func (pool *hostConnPool) Size() int {
pool.mu.RLock()
defer pool.mu.RUnlock()
size, _ := pool.connPicker.Size()
return size
}
// Size returns the number of connections currently active in the pool
func (pool *hostConnPool) InFlight() int {
pool.mu.RLock()
defer pool.mu.RUnlock()
size := pool.connPicker.InFlight()
return size
}
// Close the connection pool
func (pool *hostConnPool) Close() {
pool.mu.Lock()
defer pool.mu.Unlock()
if !pool.closed {
pool.connPicker.Close()
}
pool.closed = true
}
// Fill the connection pool
func (pool *hostConnPool) fill() {
pool.mu.RLock()
// avoid filling a closed pool, or concurrent filling
if pool.closed || pool.filling {
pool.mu.RUnlock()
return
}
// determine the filling work to be done
startCount, fillCount := pool.connPicker.Size()
// avoid filling a full (or overfull) pool
if fillCount <= 0 {
pool.mu.RUnlock()
return
}
// switch from read to write lock
pool.mu.RUnlock()
pool.mu.Lock()
startCount, fillCount = pool.connPicker.Size()
if pool.closed || pool.filling || fillCount <= 0 {
// looks like another goroutine already beat this
// goroutine to the filling
pool.mu.Unlock()
return
}
// ok fill the pool
pool.filling = true
// allow others to access the pool while filling
pool.mu.Unlock()
// only this goroutine should make calls to fill/empty the pool at this
// point until after this routine or its subordinates calls
// fillingStopped
// fill only the first connection synchronously
if startCount == 0 {
err := pool.connect()
pool.logConnectErr(err)
if err != nil {
// probably unreachable host
pool.fillingStopped(err)
return
}
// notify the session that this node is connected
go pool.session.handleNodeConnected(pool.host)
// filled one, let's reload it to see if it has changed
pool.mu.RLock()
_, fillCount = pool.connPicker.Size()
pool.mu.RUnlock()
}
// fill the rest of the pool asynchronously
go func() {
err := pool.connectMany(fillCount)
// mark the end of filling
pool.fillingStopped(err)
if err == nil && startCount > 0 {
// notify the session that this node is connected again
go pool.session.handleNodeConnected(pool.host)
}
}()
}
func (pool *hostConnPool) fill_debounce() {
pool.debouncer.Debounce(pool.fill)
}
func (pool *hostConnPool) logConnectErr(err error) {
if opErr, ok := err.(*net.OpError); ok && (opErr.Op == "dial" || opErr.Op == "read") {
// connection refused
// these are typical during a node outage so avoid log spam.
if gocqlDebug {
pool.logger.Printf("unable to dial %q: %v\n", pool.host, err)
}
} else if err != nil {
// unexpected error
pool.logger.Printf("error: failed to connect to %q due to error: %v", pool.host, err)
}
}
// transition back to a not-filling state.
func (pool *hostConnPool) fillingStopped(err error) {
if err != nil {
if gocqlDebug {
pool.logger.Printf("gocql: filling stopped %q: %v\n", pool.host.ConnectAddress(), err)
}
// wait for some time to avoid back-to-back filling
// this provides some time between failed attempts
// to fill the pool for the host to recover
time.Sleep(time.Duration(rand.Int31n(100)+31) * time.Millisecond)
}
pool.mu.Lock()
pool.filling = false
count, _ := pool.connPicker.Size()
host := pool.host
port := pool.host.Port()
pool.mu.Unlock()
// if we errored and the size is now zero, make sure the host is marked as down
// see https://github.com/gocql/gocql/issues/1614
if gocqlDebug {
pool.logger.Printf("gocql: conns of pool after stopped %q: %v\n", host.ConnectAddress(), count)
}
if err != nil && count == 0 {
if pool.session.cfg.ConvictionPolicy.AddFailure(err, host) {
pool.session.handleNodeDown(host.ConnectAddress(), port)
}
}
}
// connectMany creates new connections concurrent.
func (pool *hostConnPool) connectMany(count int) error {
if count == 0 {
return nil
}
var (
wg sync.WaitGroup
mu sync.Mutex
connectErr error
)
wg.Add(count)
for i := 0; i < count; i++ {
go func() {
defer wg.Done()
err := pool.connect()
pool.logConnectErr(err)
if err != nil {
mu.Lock()
connectErr = err
mu.Unlock()
}
}()
}
// wait for all connections are done
wg.Wait()
return connectErr
}
// create a new connection to the host and add it to the pool
func (pool *hostConnPool) connect() (err error) {
pool.mu.Lock()
shardID, nrShards := pool.connPicker.NextShard()
pool.mu.Unlock()
// TODO: provide a more robust connection retry mechanism, we should also
// be able to detect hosts that come up by trying to connect to downed ones.
// try to connect
var conn *Conn
reconnectionPolicy := pool.session.cfg.ReconnectionPolicy
for i := 0; i < reconnectionPolicy.GetMaxRetries(); i++ {
conn, err = pool.session.connectShard(pool.session.ctx, pool.host, pool, shardID, nrShards)
if err == nil {
break
}
if opErr, isOpErr := err.(*net.OpError); isOpErr {
// if the error is not a temporary error (ex: network unreachable) don't
// retry
if !opErr.Temporary() {
break
}
}
if gocqlDebug {
pool.logger.Printf("gocql: connection failed %q: %v, reconnecting with %T\n",
pool.host.ConnectAddress(), err, reconnectionPolicy)
}
time.Sleep(reconnectionPolicy.GetInterval(i))
}
if err != nil {
return err
}
if pool.keyspace != "" {
// set the keyspace
if err = conn.UseKeyspace(pool.keyspace); err != nil {
conn.Close()
return err
}
}
// add the Conn to the pool
pool.mu.Lock()
defer pool.mu.Unlock()
if pool.closed {
conn.Close()
return nil
}
// lazily initialize the connPicker when we know the required type
pool.initConnPicker(conn)
pool.connPicker.Put(conn)
return nil
}
func (pool *hostConnPool) initConnPicker(conn *Conn) {
if _, ok := pool.connPicker.(nopConnPicker); !ok {
return
}
if conn.isScyllaConn() {
pool.connPicker = newScyllaConnPicker(conn)
return
}
pool.connPicker = newDefaultConnPicker(pool.size)
}
// handle any error from a Conn
func (pool *hostConnPool) HandleError(conn *Conn, err error, closed bool) {
if !closed {
// still an open connection, so continue using it
return
}
// TODO: track the number of errors per host and detect when a host is dead,
// then also have something which can detect when a host comes back.
pool.mu.Lock()
defer pool.mu.Unlock()
if pool.closed {
// pool closed
return
}
if gocqlDebug {
pool.logger.Printf("gocql: pool connection error %q: %v\n", conn.addr, err)
}
pool.connPicker.Remove(conn)
go pool.fill_debounce()
}

140
vendor/github.com/gocql/gocql/connpicker.go generated vendored Normal file
View File

@@ -0,0 +1,140 @@
package gocql
import (
"fmt"
"sync"
"sync/atomic"
)
type ConnPicker interface {
Pick(Token, ExecutableQuery) *Conn
Put(*Conn)
Remove(conn *Conn)
InFlight() int
Size() (int, int)
Close()
// NextShard returns the shardID to connect to.
// nrShard specifies how many shards the host has.
// If nrShards is zero, the caller shouldn't use shard-aware port.
NextShard() (shardID, nrShards int)
}
type defaultConnPicker struct {
conns []*Conn
pos uint32
size int
mu sync.RWMutex
}
func newDefaultConnPicker(size int) *defaultConnPicker {
if size <= 0 {
panic(fmt.Sprintf("invalid pool size %d", size))
}
return &defaultConnPicker{
size: size,
}
}
func (p *defaultConnPicker) Remove(conn *Conn) {
p.mu.Lock()
defer p.mu.Unlock()
for i, candidate := range p.conns {
if candidate == conn {
last := len(p.conns) - 1
p.conns[i], p.conns = p.conns[last], p.conns[:last]
break
}
}
}
func (p *defaultConnPicker) Close() {
p.mu.Lock()
defer p.mu.Unlock()
conns := p.conns
p.conns = nil
for _, conn := range conns {
if conn != nil {
conn.Close()
}
}
}
func (p *defaultConnPicker) InFlight() int {
size := len(p.conns)
return size
}
func (p *defaultConnPicker) Size() (int, int) {
size := len(p.conns)
return size, p.size - size
}
func (p *defaultConnPicker) Pick(Token, ExecutableQuery) *Conn {
pos := int(atomic.AddUint32(&p.pos, 1) - 1)
size := len(p.conns)
var (
leastBusyConn *Conn
streamsAvailable int
)
// find the conn which has the most available streams, this is racy
for i := 0; i < size; i++ {
conn := p.conns[(pos+i)%size]
if conn == nil {
continue
}
if streams := conn.AvailableStreams(); streams > streamsAvailable {
leastBusyConn = conn
streamsAvailable = streams
}
}
return leastBusyConn
}
func (p *defaultConnPicker) Put(conn *Conn) {
p.mu.Lock()
p.conns = append(p.conns, conn)
p.mu.Unlock()
}
func (*defaultConnPicker) NextShard() (shardID, nrShards int) {
return 0, 0
}
// nopConnPicker is a no-operation implementation of ConnPicker, it's used when
// hostConnPool is created to allow deferring creation of the actual ConnPicker
// to the point where we have first connection.
type nopConnPicker struct{}
func (nopConnPicker) Pick(Token, ExecutableQuery) *Conn {
return nil
}
func (nopConnPicker) Put(*Conn) {
}
func (nopConnPicker) Remove(conn *Conn) {
}
func (nopConnPicker) InFlight() int {
return 0
}
func (nopConnPicker) Size() (int, int) {
// Return 1 to make hostConnPool to try to establish a connection.
// When first connection is established hostConnPool replaces nopConnPicker
// with a different ConnPicker implementation.
return 0, 1
}
func (nopConnPicker) Close() {
}
func (nopConnPicker) NextShard() (shardID, nrShards int) {
return 0, 0
}

517
vendor/github.com/gocql/gocql/control.go generated vendored Normal file
View File

@@ -0,0 +1,517 @@
package gocql
import (
"context"
crand "crypto/rand"
"errors"
"fmt"
"math/rand"
"net"
"os"
"regexp"
"strconv"
"sync"
"sync/atomic"
"time"
)
var (
randr *rand.Rand
mutRandr sync.Mutex
)
func init() {
b := make([]byte, 4)
if _, err := crand.Read(b); err != nil {
panic(fmt.Sprintf("unable to seed random number generator: %v", err))
}
randr = rand.New(rand.NewSource(int64(readInt(b))))
}
const (
controlConnStarting = 0
controlConnStarted = 1
controlConnClosing = -1
)
type controlConnection interface {
getConn() *connHost
awaitSchemaAgreement() error
query(statement string, values ...interface{}) (iter *Iter)
discoverProtocol(hosts []*HostInfo) (int, error)
connect(hosts []*HostInfo) error
close()
getSession() *Session
}
// Ensure that the atomic variable is aligned to a 64bit boundary
// so that atomic operations can be applied on 32bit architectures.
type controlConn struct {
state int32
reconnecting int32
session *Session
conn atomic.Value
retry RetryPolicy
quit chan struct{}
}
func (c *controlConn) getSession() *Session {
return c.session
}
func createControlConn(session *Session) *controlConn {
control := &controlConn{
session: session,
quit: make(chan struct{}),
retry: &SimpleRetryPolicy{NumRetries: 3},
}
control.conn.Store((*connHost)(nil))
return control
}
func (c *controlConn) heartBeat() {
if !atomic.CompareAndSwapInt32(&c.state, controlConnStarting, controlConnStarted) {
return
}
sleepTime := 1 * time.Second
timer := time.NewTimer(sleepTime)
defer timer.Stop()
for {
timer.Reset(sleepTime)
select {
case <-c.quit:
return
case <-timer.C:
}
resp, err := c.writeFrame(&writeOptionsFrame{})
if err != nil {
goto reconn
}
switch resp.(type) {
case *supportedFrame:
// Everything ok
sleepTime = 30 * time.Second
continue
case error:
goto reconn
default:
panic(fmt.Sprintf("gocql: unknown frame in response to options: %T", resp))
}
reconn:
// try to connect a bit faster
sleepTime = 1 * time.Second
c.reconnect()
continue
}
}
var hostLookupPreferV4 = os.Getenv("GOCQL_HOST_LOOKUP_PREFER_V4") == "true"
func hostInfo(addr string, defaultPort int) ([]*HostInfo, error) {
var port int
host, portStr, err := net.SplitHostPort(addr)
if err != nil {
host = addr
port = defaultPort
} else {
port, err = strconv.Atoi(portStr)
if err != nil {
return nil, err
}
}
var hosts []*HostInfo
// Check if host is a literal IP address
if ip := net.ParseIP(host); ip != nil {
hosts = append(hosts, &HostInfo{hostname: host, connectAddress: ip, port: port})
return hosts, nil
}
// Look up host in DNS
ips, err := LookupIP(host)
if err != nil {
return nil, err
} else if len(ips) == 0 {
return nil, fmt.Errorf("no IP's returned from DNS lookup for %q", addr)
}
// Filter to v4 addresses if any present
if hostLookupPreferV4 {
var preferredIPs []net.IP
for _, v := range ips {
if v4 := v.To4(); v4 != nil {
preferredIPs = append(preferredIPs, v4)
}
}
if len(preferredIPs) != 0 {
ips = preferredIPs
}
}
for _, ip := range ips {
hosts = append(hosts, &HostInfo{hostname: host, connectAddress: ip, port: port})
}
return hosts, nil
}
func shuffleHosts(hosts []*HostInfo) []*HostInfo {
shuffled := make([]*HostInfo, len(hosts))
copy(shuffled, hosts)
mutRandr.Lock()
randr.Shuffle(len(hosts), func(i, j int) {
shuffled[i], shuffled[j] = shuffled[j], shuffled[i]
})
mutRandr.Unlock()
return shuffled
}
// this is going to be version dependant and a nightmare to maintain :(
var protocolSupportRe = regexp.MustCompile(`the lowest supported version is \d+ and the greatest is (\d+)$`)
func parseProtocolFromError(err error) int {
// I really wish this had the actual info in the error frame...
matches := protocolSupportRe.FindAllStringSubmatch(err.Error(), -1)
if len(matches) != 1 || len(matches[0]) != 2 {
if verr, ok := err.(*protocolError); ok {
return int(verr.frame.Header().version.version())
}
return 0
}
max, err := strconv.Atoi(matches[0][1])
if err != nil {
return 0
}
return max
}
func (c *controlConn) discoverProtocol(hosts []*HostInfo) (int, error) {
hosts = shuffleHosts(hosts)
connCfg := *c.session.connCfg
connCfg.ProtoVersion = 4 // TODO: define maxProtocol
handler := connErrorHandlerFn(func(c *Conn, err error, closed bool) {
// we should never get here, but if we do it means we connected to a
// host successfully which means our attempted protocol version worked
if !closed {
c.Close()
}
})
var err error
for _, host := range hosts {
var conn *Conn
conn, err = c.session.dial(c.session.ctx, host, &connCfg, handler)
if conn != nil {
conn.Close()
}
if err == nil {
return connCfg.ProtoVersion, nil
}
if proto := parseProtocolFromError(err); proto > 0 {
return proto, nil
}
}
return 0, err
}
func (c *controlConn) connect(hosts []*HostInfo) error {
if len(hosts) == 0 {
return errors.New("control: no endpoints specified")
}
// shuffle endpoints so not all drivers will connect to the same initial
// node.
hosts = shuffleHosts(hosts)
cfg := *c.session.connCfg
cfg.disableCoalesce = true
var conn *Conn
var err error
for _, host := range hosts {
conn, err = c.session.dial(c.session.ctx, host, &cfg, c)
if err != nil {
c.session.logger.Printf("gocql: unable to dial control conn %v:%v: %v\n", host.ConnectAddress(), host.Port(), err)
continue
}
err = c.setupConn(conn)
if err == nil {
break
}
c.session.logger.Printf("gocql: unable setup control conn %v:%v: %v\n", host.ConnectAddress(), host.Port(), err)
conn.Close()
conn = nil
}
if conn == nil {
return fmt.Errorf("unable to connect to initial hosts: %v", err)
}
// we could fetch the initial ring here and update initial host data. So that
// when we return from here we have a ring topology ready to go.
go c.heartBeat()
return nil
}
type connHost struct {
conn ConnInterface
host *HostInfo
}
func (c *controlConn) setupConn(conn *Conn) error {
// we need up-to-date host info for the filterHost call below
iter := conn.querySystem(context.TODO(), qrySystemLocal)
defaultPort := 9042
if tcpAddr, ok := conn.conn.RemoteAddr().(*net.TCPAddr); ok {
defaultPort = tcpAddr.Port
}
host, err := hostInfoFromIter(iter, conn.host.connectAddress, defaultPort, c.session.cfg.translateAddressPort)
if err != nil {
return err
}
host = c.session.hostSource.addOrUpdate(host)
if c.session.cfg.filterHost(host) {
return fmt.Errorf("host was filtered: %v", host.ConnectAddress())
}
if err := c.registerEvents(conn); err != nil {
return fmt.Errorf("register events: %v", err)
}
ch := &connHost{
conn: conn,
host: host,
}
c.conn.Store(ch)
if c.session.initialized() {
// We connected to control conn, so add the connect the host in pool as well.
// Notify session we can start trying to connect to the node.
// We can't start the fill before the session is initialized, otherwise the fill would interfere
// with the fill called by Session.init. Session.init needs to wait for its fill to finish and that
// would return immediately if we started the fill here.
// TODO(martin-sucha): Trigger pool refill for all hosts, like in reconnectDownedHosts?
go c.session.startPoolFill(host)
}
return nil
}
func (c *controlConn) registerEvents(conn *Conn) error {
var events []string
if !c.session.cfg.Events.DisableTopologyEvents {
events = append(events, "TOPOLOGY_CHANGE")
}
if !c.session.cfg.Events.DisableNodeStatusEvents {
events = append(events, "STATUS_CHANGE")
}
if !c.session.cfg.Events.DisableSchemaEvents {
events = append(events, "SCHEMA_CHANGE")
}
if len(events) == 0 {
return nil
}
framer, err := conn.exec(context.Background(),
&writeRegisterFrame{
events: events,
}, nil)
if err != nil {
return err
}
frame, err := framer.parseFrame()
if err != nil {
return err
} else if _, ok := frame.(*readyFrame); !ok {
return fmt.Errorf("unexpected frame in response to register: got %T: %v\n", frame, frame)
}
return nil
}
func (c *controlConn) reconnect() {
if atomic.LoadInt32(&c.state) == controlConnClosing {
return
}
if !atomic.CompareAndSwapInt32(&c.reconnecting, 0, 1) {
return
}
defer atomic.StoreInt32(&c.reconnecting, 0)
conn, err := c.attemptReconnect()
if conn == nil {
c.session.logger.Printf("gocql: unable to reconnect control connection: %v\n", err)
return
}
err = c.session.refreshRingNow()
if err != nil {
c.session.logger.Printf("gocql: unable to refresh ring: %v\n", err)
}
err = c.session.metadataDescriber.refreshAllSchema()
if err != nil {
c.session.logger.Printf("gocql: unable to refresh the schema: %v\n", err)
}
}
func (c *controlConn) attemptReconnect() (*Conn, error) {
hosts := c.session.hostSource.getHostsList()
hosts = shuffleHosts(hosts)
// keep the old behavior of connecting to the old host first by moving it to
// the front of the slice
ch := c.getConn()
if ch != nil {
for i := range hosts {
if hosts[i].Equal(ch.host) {
hosts[0], hosts[i] = hosts[i], hosts[0]
break
}
}
ch.conn.Close()
}
conn, err := c.attemptReconnectToAnyOfHosts(hosts)
if conn != nil {
return conn, err
}
c.session.logger.Printf("gocql: unable to connect to any ring node: %v\n", err)
c.session.logger.Printf("gocql: control falling back to initial contact points.\n")
// Fallback to initial contact points, as it may be the case that all known initialHosts
// changed their IPs while keeping the same hostname(s).
initialHosts, resolvErr := addrsToHosts(c.session.cfg.Hosts, c.session.cfg.Port, c.session.logger)
if resolvErr != nil {
return nil, fmt.Errorf("resolve contact points' hostnames: %v", resolvErr)
}
return c.attemptReconnectToAnyOfHosts(initialHosts)
}
func (c *controlConn) attemptReconnectToAnyOfHosts(hosts []*HostInfo) (*Conn, error) {
var conn *Conn
var err error
for _, host := range hosts {
conn, err = c.session.connect(c.session.ctx, host, c)
if err != nil {
if c.session.cfg.ConvictionPolicy.AddFailure(err, host) {
c.session.handleNodeDown(host.ConnectAddress(), host.Port())
}
c.session.logger.Printf("gocql: unable to dial control conn %v:%v: %v\n", host.ConnectAddress(), host.Port(), err)
continue
}
err = c.setupConn(conn)
if err == nil {
break
}
c.session.logger.Printf("gocql: unable setup control conn %v:%v: %v\n", host.ConnectAddress(), host.Port(), err)
conn.Close()
conn = nil
}
return conn, err
}
func (c *controlConn) HandleError(conn *Conn, err error, closed bool) {
if !closed {
return
}
oldConn := c.getConn()
// If connection has long gone, and not been attempted for awhile,
// it's possible to have oldConn as nil here (#1297).
if oldConn != nil && oldConn.conn != conn {
return
}
c.reconnect()
}
func (c *controlConn) getConn() *connHost {
return c.conn.Load().(*connHost)
}
func (c *controlConn) writeFrame(w frameBuilder) (frame, error) {
ch := c.getConn()
if ch == nil {
return nil, errNoControl
}
framer, err := ch.conn.exec(context.Background(), w, nil)
if err != nil {
return nil, err
}
return framer.parseFrame()
}
// query will return nil if the connection is closed or nil
func (c *controlConn) query(statement string, values ...interface{}) (iter *Iter) {
q := c.session.Query(statement, values...).Consistency(One).RoutingKey([]byte{}).Trace(nil)
for {
ch := c.getConn()
q.conn = ch.conn.(*Conn)
iter = ch.conn.executeQuery(context.TODO(), q)
if gocqlDebug && iter.err != nil {
c.session.logger.Printf("control: error executing %q: %v\n", statement, iter.err)
}
q.AddAttempts(1, c.getConn().host)
if iter.err == nil || !c.retry.Attempt(q) {
break
}
}
return
}
func (c *controlConn) awaitSchemaAgreement() error {
ch := c.getConn()
return (&Iter{err: ch.conn.awaitSchemaAgreement(context.TODO())}).err
}
func (c *controlConn) close() {
if atomic.CompareAndSwapInt32(&c.state, controlConnStarted, controlConnClosing) {
c.quit <- struct{}{}
}
ch := c.getConn()
if ch != nil {
ch.conn.Close()
}
}
var errNoControl = errors.New("gocql: no control connection available")

11
vendor/github.com/gocql/gocql/cqltypes.go generated vendored Normal file
View File

@@ -0,0 +1,11 @@
// Copyright (c) 2012 The gocql Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gocql
type Duration struct {
Months int32
Days int32
Nanoseconds int64
}

View File

@@ -0,0 +1,164 @@
package debounce
import (
"sync"
"time"
)
const (
RingRefreshDebounceTime = 1 * time.Second
)
// debounces requests to call a refresh function (currently used for ring refresh). It also supports triggering a refresh immediately.
type RefreshDebouncer struct {
mu sync.Mutex
stopped bool
broadcaster *errorBroadcaster
interval time.Duration
timer *time.Timer
refreshNowCh chan struct{}
quit chan struct{}
refreshFn func() error
}
func NewRefreshDebouncer(interval time.Duration, refreshFn func() error) *RefreshDebouncer {
d := &RefreshDebouncer{
stopped: false,
broadcaster: nil,
refreshNowCh: make(chan struct{}, 1),
quit: make(chan struct{}),
interval: interval,
timer: time.NewTimer(interval),
refreshFn: refreshFn,
}
d.timer.Stop()
go d.flusher()
return d
}
// debounces a request to call the refresh function
func (d *RefreshDebouncer) Debounce() {
d.mu.Lock()
defer d.mu.Unlock()
if d.stopped {
return
}
d.timer.Reset(d.interval)
}
// requests an immediate refresh which will cancel pending refresh requests
func (d *RefreshDebouncer) RefreshNow() <-chan error {
d.mu.Lock()
defer d.mu.Unlock()
if d.broadcaster == nil {
d.broadcaster = newErrorBroadcaster()
select {
case d.refreshNowCh <- struct{}{}:
default:
// already a refresh pending
}
}
return d.broadcaster.newListener()
}
func (d *RefreshDebouncer) flusher() {
for {
select {
case <-d.refreshNowCh:
case <-d.timer.C:
case <-d.quit:
}
d.mu.Lock()
if d.stopped {
if d.broadcaster != nil {
d.broadcaster.stop()
d.broadcaster = nil
}
d.timer.Stop()
d.mu.Unlock()
return
}
// make sure both request channels are cleared before we refresh
select {
case <-d.refreshNowCh:
default:
}
d.timer.Stop()
select {
case <-d.timer.C:
default:
}
curBroadcaster := d.broadcaster
d.broadcaster = nil
d.mu.Unlock()
err := d.refreshFn()
if curBroadcaster != nil {
curBroadcaster.broadcast(err)
}
}
}
func (d *RefreshDebouncer) Stop() {
d.mu.Lock()
if d.stopped {
d.mu.Unlock()
return
}
d.stopped = true
d.mu.Unlock()
d.quit <- struct{}{} // sync with flusher
close(d.quit)
}
// broadcasts an error to multiple channels (listeners)
type errorBroadcaster struct {
listeners []chan<- error
mu sync.Mutex
}
func newErrorBroadcaster() *errorBroadcaster {
return &errorBroadcaster{
listeners: nil,
mu: sync.Mutex{},
}
}
func (b *errorBroadcaster) newListener() <-chan error {
ch := make(chan error, 1)
b.mu.Lock()
defer b.mu.Unlock()
b.listeners = append(b.listeners, ch)
return ch
}
func (b *errorBroadcaster) broadcast(err error) {
b.mu.Lock()
defer b.mu.Unlock()
curListeners := b.listeners
if len(curListeners) > 0 {
b.listeners = nil
} else {
return
}
for _, listener := range curListeners {
listener <- err
close(listener)
}
}
func (b *errorBroadcaster) stop() {
b.mu.Lock()
defer b.mu.Unlock()
if len(b.listeners) == 0 {
return
}
for _, listener := range b.listeners {
close(listener)
}
b.listeners = nil
}

View File

@@ -0,0 +1,34 @@
package debounce
import (
"sync"
"sync/atomic"
)
// SimpleDebouncer is are tool for queuing immutable functions calls. It provides:
// 1. Blocking simultaneous calls
// 2. If there is no running call and no waiting call, then the current call go through
// 3. If there is running call and no waiting call, then the current call go waiting
// 4. If there is running call and waiting call, then the current call are voided
type SimpleDebouncer struct {
m sync.Mutex
count atomic.Int32
}
// NewSimpleDebouncer creates a new SimpleDebouncer.
func NewSimpleDebouncer() *SimpleDebouncer {
return &SimpleDebouncer{}
}
// Debounce attempts to execute the function if the logic of the SimpleDebouncer allows it.
func (d *SimpleDebouncer) Debounce(fn func()) bool {
if d.count.Add(1) > 2 {
d.count.Add(-1)
return false
}
d.m.Lock()
fn()
d.count.Add(-1)
d.m.Unlock()
return true
}

6
vendor/github.com/gocql/gocql/debug_off.go generated vendored Normal file
View File

@@ -0,0 +1,6 @@
//go:build !gocql_debug
// +build !gocql_debug
package gocql
const gocqlDebug = false

6
vendor/github.com/gocql/gocql/debug_on.go generated vendored Normal file
View File

@@ -0,0 +1,6 @@
//go:build gocql_debug
// +build gocql_debug
package gocql
const gocqlDebug = true

91
vendor/github.com/gocql/gocql/dial.go generated vendored Normal file
View File

@@ -0,0 +1,91 @@
package gocql
import (
"context"
"crypto/tls"
"fmt"
"net"
"strings"
)
// HostDialer allows customizing connection to cluster nodes.
type HostDialer interface {
// DialHost establishes a connection to the host.
// The returned connection must be directly usable for CQL protocol,
// specifically DialHost is responsible also for setting up the TLS session if needed.
// DialHost should disable write coalescing if the returned net.Conn does not support writev.
// As of Go 1.18, only plain TCP connections support writev, TLS sessions should disable coalescing.
// You can use WrapTLS helper function if you don't need to override the TLS setup.
DialHost(ctx context.Context, host *HostInfo) (*DialedHost, error)
}
// DialedHost contains information about established connection to a host.
type DialedHost struct {
// Conn used to communicate with the server.
Conn net.Conn
// DisableCoalesce disables write coalescing for the Conn.
// If true, the effect is the same as if WriteCoalesceWaitTime was configured to 0.
DisableCoalesce bool
}
// defaultHostDialer dials host in a default way.
type defaultHostDialer struct {
dialer Dialer
tlsConfig *tls.Config
}
func (hd *defaultHostDialer) DialHost(ctx context.Context, host *HostInfo) (*DialedHost, error) {
ip := host.ConnectAddress()
port := host.Port()
if !validIpAddr(ip) {
return nil, fmt.Errorf("host missing connect ip address: %v", ip)
} else if port == 0 {
return nil, fmt.Errorf("host missing port: %v", port)
}
connAddr := host.ConnectAddressAndPort()
conn, err := hd.dialer.DialContext(ctx, "tcp", connAddr)
if err != nil {
return nil, err
}
addr := host.HostnameAndPort()
return WrapTLS(ctx, conn, addr, hd.tlsConfig)
}
func tlsConfigForAddr(tlsConfig *tls.Config, addr string) *tls.Config {
// the TLS config is safe to be reused by connections but it must not
// be modified after being used.
if !tlsConfig.InsecureSkipVerify && tlsConfig.ServerName == "" {
colonPos := strings.LastIndex(addr, ":")
if colonPos == -1 {
colonPos = len(addr)
}
hostname := addr[:colonPos]
// clone config to avoid modifying the shared one.
tlsConfig = tlsConfig.Clone()
tlsConfig.ServerName = hostname
}
return tlsConfig
}
// WrapTLS optionally wraps a net.Conn connected to addr with the given tlsConfig.
// If the tlsConfig is nil, conn is not wrapped into a TLS session, so is insecure.
// If the tlsConfig does not have server name set, it is updated based on the default gocql rules.
func WrapTLS(ctx context.Context, conn net.Conn, addr string, tlsConfig *tls.Config) (*DialedHost, error) {
if tlsConfig != nil {
tlsConfig := tlsConfigForAddr(tlsConfig, addr)
tconn := tls.Client(conn, tlsConfig)
if err := tconn.HandshakeContext(ctx); err != nil {
conn.Close()
return nil, err
}
conn = tconn
}
return &DialedHost{
Conn: conn,
DisableCoalesce: tlsConfig != nil, // write coalescing can't use writev when the connection is wrapped.
}, nil
}

375
vendor/github.com/gocql/gocql/doc.go generated vendored Normal file
View File

@@ -0,0 +1,375 @@
// Copyright (c) 2012-2015 The gocql Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package gocql implements a fast and robust Cassandra driver for the
// Go programming language.
//
// # Connecting to the cluster
//
// Pass a list of initial node IP addresses to NewCluster to create a new cluster configuration:
//
// cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3")
//
// Port can be specified as part of the address, the above is equivalent to:
//
// cluster := gocql.NewCluster("192.168.1.1:9042", "192.168.1.2:9042", "192.168.1.3:9042")
//
// It is recommended to use the value set in the Cassandra config for broadcast_address or listen_address,
// an IP address not a domain name. This is because events from Cassandra will use the configured IP
// address, which is used to index connected hosts. If the domain name specified resolves to more than 1 IP address
// then the driver may connect multiple times to the same host, and will not mark the node being down or up from events.
//
// Then you can customize more options (see ClusterConfig):
//
// cluster.Keyspace = "example"
// cluster.Consistency = gocql.Quorum
// cluster.ProtoVersion = 4
//
// The driver tries to automatically detect the protocol version to use if not set, but you might want to set the
// protocol version explicitly, as it's not defined which version will be used in certain situations (for example
// during upgrade of the cluster when some of the nodes support different set of protocol versions than other nodes).
//
// The driver advertises the module name and version in the STARTUP message, so servers are able to detect the version.
// If you use replace directive in go.mod, the driver will send information about the replacement module instead.
//
// When ready, create a session from the configuration. Don't forget to Close the session once you are done with it:
//
// session, err := cluster.CreateSession()
// if err != nil {
// return err
// }
// defer session.Close()
//
// # Authentication
//
// CQL protocol uses a SASL-based authentication mechanism and so consists of an exchange of server challenges and
// client response pairs. The details of the exchanged messages depend on the authenticator used.
//
// To use authentication, set ClusterConfig.Authenticator or ClusterConfig.AuthProvider.
//
// PasswordAuthenticator is provided to use for username/password authentication:
//
// cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3")
// cluster.Authenticator = gocql.PasswordAuthenticator{
// Username: "user",
// Password: "password"
// }
// session, err := cluster.CreateSession()
// if err != nil {
// return err
// }
// defer session.Close()
//
// By default, PasswordAuthenticator will attempt to authenticate regardless of what implementation the server returns
// in its AUTHENTICATE message as its authenticator, (e.g. org.apache.cassandra.auth.PasswordAuthenticator). If you
// wish to restrict this you may use PasswordAuthenticator.AllowedAuthenticators:
//
// cluster.Authenticator = gocql.PasswordAuthenticator {
// Username: "user",
// Password: "password"
// AllowedAuthenticators: []string{"org.apache.cassandra.auth.PasswordAuthenticator"},
// }
//
// # Transport layer security
//
// It is possible to secure traffic between the client and server with TLS.
//
// To use TLS, set the ClusterConfig.SslOpts field. SslOptions embeds *tls.Config so you can set that directly.
// There are also helpers to load keys/certificates from files.
//
// Warning: Due to historical reasons, the SslOptions is insecure by default, so you need to set EnableHostVerification
// to true if no Config is set. Most users should set SslOptions.Config to a *tls.Config.
// SslOptions and Config.InsecureSkipVerify interact as follows:
//
// Config.InsecureSkipVerify | EnableHostVerification | Result
// Config is nil | false | do not verify host
// Config is nil | true | verify host
// false | false | verify host
// true | false | do not verify host
// false | true | verify host
// true | true | verify host
//
// For example:
//
// cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3")
// cluster.SslOpts = &gocql.SslOptions{
// EnableHostVerification: true,
// }
// session, err := cluster.CreateSession()
// if err != nil {
// return err
// }
// defer session.Close()
//
// # Data-center awareness and query routing
//
// To route queries to local DC first, use DCAwareRoundRobinPolicy. For example, if the datacenter you
// want to primarily connect is called dc1 (as configured in the database):
//
// cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3")
// cluster.PoolConfig.HostSelectionPolicy = gocql.DCAwareRoundRobinPolicy("dc1")
//
// The driver can route queries to nodes that hold data replicas based on partition key (preferring local DC).
//
// cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3")
// cluster.PoolConfig.HostSelectionPolicy = gocql.TokenAwareHostPolicy(gocql.DCAwareRoundRobinPolicy("dc1"))
//
// Note that TokenAwareHostPolicy can take options such as gocql.ShuffleReplicas and gocql.NonLocalReplicasFallback.
//
// We recommend running with a token aware host policy in production for maximum performance.
//
// The driver can only use token-aware routing for queries where all partition key columns are query parameters.
// For example, instead of
//
// session.Query("select value from mytable where pk1 = 'abc' AND pk2 = ?", "def")
//
// use
//
// session.Query("select value from mytable where pk1 = ? AND pk2 = ?", "abc", "def")
//
// # Rack-level awareness
//
// The DCAwareRoundRobinPolicy can be replaced with RackAwareRoundRobinPolicy, which takes two parameters, datacenter and rack.
//
// Instead of dividing hosts with two tiers (local datacenter and remote datacenters) it divides hosts into three
// (the local rack, the rest of the local datacenter, and everything else).
//
// RackAwareRoundRobinPolicy can be combined with TokenAwareHostPolicy in the same way as DCAwareRoundRobinPolicy.
//
// # Executing queries
//
// Create queries with Session.Query. Query values must not be reused between different executions and must not be
// modified after starting execution of the query.
//
// To execute a query without reading results, use Query.Exec:
//
// err := session.Query(`INSERT INTO tweet (timeline, id, text) VALUES (?, ?, ?)`,
// "me", gocql.TimeUUID(), "hello world").WithContext(ctx).Exec()
//
// Single row can be read by calling Query.Scan:
//
// err := session.Query(`SELECT id, text FROM tweet WHERE timeline = ? LIMIT 1`,
// "me").WithContext(ctx).Consistency(gocql.One).Scan(&id, &text)
//
// Multiple rows can be read using Iter.Scanner:
//
// scanner := session.Query(`SELECT id, text FROM tweet WHERE timeline = ?`,
// "me").WithContext(ctx).Iter().Scanner()
// for scanner.Next() {
// var (
// id gocql.UUID
// text string
// )
// err = scanner.Scan(&id, &text)
// if err != nil {
// log.Fatal(err)
// }
// fmt.Println("Tweet:", id, text)
// }
// // scanner.Err() closes the iterator, so scanner nor iter should be used afterwards.
// if err := scanner.Err(); err != nil {
// log.Fatal(err)
// }
//
// See Example for complete example.
//
// # Prepared statements
//
// The driver automatically prepares DML queries (SELECT/INSERT/UPDATE/DELETE/BATCH statements) and maintains a cache
// of prepared statements.
// CQL protocol does not support preparing other query types.
//
// When using CQL protocol >= 4, it is possible to use gocql.UnsetValue as the bound value of a column.
// This will cause the database to ignore writing the column.
// The main advantage is the ability to keep the same prepared statement even when you don't
// want to update some fields, where before you needed to make another prepared statement.
//
// # Executing multiple queries concurrently
//
// Session is safe to use from multiple goroutines, so to execute multiple concurrent queries, just execute them
// from several worker goroutines. Gocql provides synchronously-looking API (as recommended for Go APIs) and the queries
// are executed asynchronously at the protocol level.
//
// results := make(chan error, 2)
// go func() {
// results <- session.Query(`INSERT INTO tweet (timeline, id, text) VALUES (?, ?, ?)`,
// "me", gocql.TimeUUID(), "hello world 1").Exec()
// }()
// go func() {
// results <- session.Query(`INSERT INTO tweet (timeline, id, text) VALUES (?, ?, ?)`,
// "me", gocql.TimeUUID(), "hello world 2").Exec()
// }()
//
// # Nulls
//
// Null values are are unmarshalled as zero value of the type. If you need to distinguish for example between text
// column being null and empty string, you can unmarshal into *string variable instead of string.
//
// var text *string
// err := scanner.Scan(&text)
// if err != nil {
// // handle error
// }
// if text != nil {
// // not null
// }
// else {
// // null
// }
//
// See Example_nulls for full example.
//
// # Reusing slices
//
// The driver reuses backing memory of slices when unmarshalling. This is an optimization so that a buffer does not
// need to be allocated for every processed row. However, you need to be careful when storing the slices to other
// memory structures.
//
// scanner := session.Query(`SELECT myints FROM table WHERE pk = ?`, "key").WithContext(ctx).Iter().Scanner()
// var myInts []int
// for scanner.Next() {
// // This scan reuses backing store of myInts for each row.
// err = scanner.Scan(&myInts)
// if err != nil {
// log.Fatal(err)
// }
// }
//
// When you want to save the data for later use, pass a new slice every time. A common pattern is to declare the
// slice variable within the scanner loop:
//
// scanner := session.Query(`SELECT myints FROM table WHERE pk = ?`, "key").WithContext(ctx).Iter().Scanner()
// for scanner.Next() {
// var myInts []int
// // This scan always gets pointer to fresh myInts slice, so does not reuse memory.
// err = scanner.Scan(&myInts)
// if err != nil {
// log.Fatal(err)
// }
// }
//
// # Paging
//
// The driver supports paging of results with automatic prefetch, see ClusterConfig.PageSize, Session.SetPrefetch,
// Query.PageSize, and Query.Prefetch.
//
// It is also possible to control the paging manually with Query.PageState (this disables automatic prefetch).
// Manual paging is useful if you want to store the page state externally, for example in a URL to allow users
// browse pages in a result. You might want to sign/encrypt the paging state when exposing it externally since
// it contains data from primary keys.
//
// Paging state is specific to the CQL protocol version and the exact query used. It is meant as opaque state that
// should not be modified. If you send paging state from different query or protocol version, then the behaviour
// is not defined (you might get unexpected results or an error from the server). For example, do not send paging state
// returned by node using protocol version 3 to a node using protocol version 4. Also, when using protocol version 4,
// paging state between Cassandra 2.2 and 3.0 is incompatible (https://issues.apache.org/jira/browse/CASSANDRA-10880).
//
// The driver does not check whether the paging state is from the same protocol version/statement.
// You might want to validate yourself as this could be a problem if you store paging state externally.
// For example, if you store paging state in a URL, the URLs might become broken when you upgrade your cluster.
//
// Call Query.PageState(nil) to fetch just the first page of the query results. Pass the page state returned by
// Iter.PageState to Query.PageState of a subsequent query to get the next page. If the length of slice returned
// by Iter.PageState is zero, there are no more pages available (or an error occurred).
//
// Using too low values of PageSize will negatively affect performance, a value below 100 is probably too low.
// While Cassandra returns exactly PageSize items (except for last page) in a page currently, the protocol authors
// explicitly reserved the right to return smaller or larger amount of items in a page for performance reasons, so don't
// rely on the page having the exact count of items.
//
// See Example_paging for an example of manual paging.
//
// # Dynamic list of columns
//
// There are certain situations when you don't know the list of columns in advance, mainly when the query is supplied
// by the user. Iter.Columns, Iter.RowData, Iter.MapScan and Iter.SliceMap can be used to handle this case.
//
// See Example_dynamicColumns.
//
// # Batches
//
// The CQL protocol supports sending batches of DML statements (INSERT/UPDATE/DELETE) and so does gocql.
// Use Session.Batch to create a new batch and then fill-in details of individual queries.
// Then execute the batch with Session.ExecuteBatch.
//
// Logged batches ensure atomicity, either all or none of the operations in the batch will succeed, but they have
// overhead to ensure this property.
// Unlogged batches don't have the overhead of logged batches, but don't guarantee atomicity.
// Updates of counters are handled specially by Cassandra so batches of counter updates have to use CounterBatch type.
// A counter batch can only contain statements to update counters.
//
// For unlogged batches it is recommended to send only single-partition batches (i.e. all statements in the batch should
// involve only a single partition).
// Multi-partition batch needs to be split by the coordinator node and re-sent to
// correct nodes.
// With single-partition batches you can send the batch directly to the node for the partition without incurring the
// additional network hop.
//
// It is also possible to pass entire BEGIN BATCH .. APPLY BATCH statement to Query.Exec.
// There are differences how those are executed.
// BEGIN BATCH statement passed to Query.Exec is prepared as a whole in a single statement.
// Session.ExecuteBatch prepares individual statements in the batch.
// If you have variable-length batches using the same statement, using Session.ExecuteBatch is more efficient.
//
// See Example_batch for an example.
//
// # Lightweight transactions
//
// Query.ScanCAS or Query.MapScanCAS can be used to execute a single-statement lightweight transaction (an
// INSERT/UPDATE .. IF statement) and reading its result. See example for Query.MapScanCAS.
//
// Multiple-statement lightweight transactions can be executed as a logged batch that contains at least one conditional
// statement. All the conditions must return true for the batch to be applied. You can use Session.ExecuteBatchCAS and
// Session.MapExecuteBatchCAS when executing the batch to learn about the result of the LWT. See example for
// Session.MapExecuteBatchCAS.
//
// # Retries and speculative execution
//
// Queries can be marked as idempotent. Marking the query as idempotent tells the driver that the query can be executed
// multiple times without affecting its result. Non-idempotent queries are not eligible for retrying nor speculative
// execution.
//
// Idempotent queries are retried in case of errors based on the configured RetryPolicy.
// If the query is LWT and the configured RetryPolicy additionally implements LWTRetryPolicy
// interface, then the policy will be cast to LWTRetryPolicy and used this way.
//
// Queries can be retried even before they fail by setting a SpeculativeExecutionPolicy. The policy can
// cause the driver to retry on a different node if the query is taking longer than a specified delay even before the
// driver receives an error or timeout from the server. When a query is speculatively executed, the original execution
// is still executing. The two parallel executions of the query race to return a result, the first received result will
// be returned.
//
// # User-defined types
//
// UDTs can be mapped (un)marshaled from/to map[string]interface{} a Go struct (or a type implementing
// UDTUnmarshaler, UDTMarshaler, Unmarshaler or Marshaler interfaces).
//
// For structs, cql tag can be used to specify the CQL field name to be mapped to a struct field:
//
// type MyUDT struct {
// FieldA int32 `cql:"a"`
// FieldB string `cql:"b"`
// }
//
// See Example_userDefinedTypesMap, Example_userDefinedTypesStruct, ExampleUDTMarshaler, ExampleUDTUnmarshaler.
//
// # Metrics and tracing
//
// It is possible to provide observer implementations that could be used to gather metrics:
//
// - QueryObserver for monitoring individual queries.
// - BatchObserver for monitoring batch queries.
// - ConnectObserver for monitoring new connections from the driver to the database.
// - FrameHeaderObserver for monitoring individual protocol frames.
//
// CQL protocol also supports tracing of queries. When enabled, the database will write information about
// internal events that happened during execution of the query. You can use Query.Trace to request tracing and receive
// the session ID that the database used to store the trace information in system_traces.sessions and
// system_traces.events tables. NewTraceWriter returns an implementation of Tracer that writes the events to a writer.
// Gathering trace information might be essential for debugging and optimizing queries, but writing traces has overhead,
// so this feature should not be used on production systems with very high load unless you know what you are doing.
// There is also a new implementation of Tracer - TracerEnhanced, that is intended to be more reliable and convinient to use.
// It has a funcionality to check if trace is ready to be extracted and only actually gets it if requested which makes
// the impact on a performance smaller.
package gocql // import "github.com/gocql/gocql"

90
vendor/github.com/gocql/gocql/docker-compose.yml generated vendored Normal file
View File

@@ -0,0 +1,90 @@
version: "3.7"
services:
node_1:
image: ${SCYLLA_IMAGE}
privileged: true
command: |
--smp 2
--memory 768M
--seeds 192.168.100.11
--overprovisioned 1
--experimental-features udf
--enable-user-defined-functions true
networks:
public:
ipv4_address: 192.168.100.11
volumes:
- /tmp/scylla:/var/lib/scylla/
- type: bind
source: ./testdata/config/scylla.yaml
target: /etc/scylla/scylla.yaml
- type: bind
source: ./testdata/pki/ca.crt
target: /etc/scylla/ca.crt
- type: bind
source: ./testdata/pki/cassandra.crt
target: /etc/scylla/db.crt
- type: bind
source: ./testdata/pki/cassandra.key
target: /etc/scylla/db.key
healthcheck:
test: [ "CMD", "cqlsh", "-e", "select * from system.local" ]
interval: 5s
timeout: 5s
retries: 18
node_2:
image: ${SCYLLA_IMAGE}
command: |
--smp 2
--memory 1G
--seeds 192.168.100.12
networks:
public:
ipv4_address: 192.168.100.12
healthcheck:
test: [ "CMD", "cqlsh", "192.168.100.12", "-e", "select * from system.local" ]
interval: 5s
timeout: 5s
retries: 18
node_3:
image: ${SCYLLA_IMAGE}
command: |
--smp 2
--memory 1G
--seeds 192.168.100.12
networks:
public:
ipv4_address: 192.168.100.13
healthcheck:
test: [ "CMD", "cqlsh", "192.168.100.13", "-e", "select * from system.local" ]
interval: 5s
timeout: 5s
retries: 18
depends_on:
node_2:
condition: service_healthy
node_4:
image: ${SCYLLA_IMAGE}
command: |
--smp 2
--memory 1G
--seeds 192.168.100.12
networks:
public:
ipv4_address: 192.168.100.14
healthcheck:
test: [ "CMD", "cqlsh", "192.168.100.14", "-e", "select * from system.local" ]
interval: 5s
timeout: 5s
retries: 18
depends_on:
node_3:
condition: service_healthy
networks:
public:
driver: bridge
ipam:
driver: default
config:
- subnet: 192.168.100.0/24

227
vendor/github.com/gocql/gocql/errors.go generated vendored Normal file
View File

@@ -0,0 +1,227 @@
package gocql
import "fmt"
// See CQL Binary Protocol v5, section 8 for more details.
// https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec
const (
// ErrCodeServer indicates unexpected error on server-side.
//
// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1246-L1247
ErrCodeServer = 0x0000
// ErrCodeProtocol indicates a protocol violation by some client message.
//
// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1248-L1250
ErrCodeProtocol = 0x000A
// ErrCodeCredentials indicates missing required authentication.
//
// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1251-L1254
ErrCodeCredentials = 0x0100
// ErrCodeUnavailable indicates unavailable error.
//
// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1255-L1265
ErrCodeUnavailable = 0x1000
// ErrCodeOverloaded returned in case of request on overloaded node coordinator.
//
// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1266-L1267
ErrCodeOverloaded = 0x1001
// ErrCodeBootstrapping returned from the coordinator node in bootstrapping phase.
//
// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1268-L1269
ErrCodeBootstrapping = 0x1002
// ErrCodeTruncate indicates truncation exception.
//
// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1270
ErrCodeTruncate = 0x1003
// ErrCodeWriteTimeout returned in case of timeout during the request write.
//
// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1271-L1304
ErrCodeWriteTimeout = 0x1100
// ErrCodeReadTimeout returned in case of timeout during the request read.
//
// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1305-L1321
ErrCodeReadTimeout = 0x1200
// ErrCodeReadFailure indicates request read error which is not covered by ErrCodeReadTimeout.
//
// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1322-L1340
ErrCodeReadFailure = 0x1300
// ErrCodeFunctionFailure indicates an error in user-defined function.
//
// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1341-L1347
ErrCodeFunctionFailure = 0x1400
// ErrCodeWriteFailure indicates request write error which is not covered by ErrCodeWriteTimeout.
//
// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1348-L1385
ErrCodeWriteFailure = 0x1500
// ErrCodeCDCWriteFailure is defined, but not yet documented in CQLv5 protocol.
//
// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1386
ErrCodeCDCWriteFailure = 0x1600
// ErrCodeCASWriteUnknown indicates only partially completed CAS operation.
//
// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1387-L1397
ErrCodeCASWriteUnknown = 0x1700
// ErrCodeSyntax indicates the syntax error in the query.
//
// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1399
ErrCodeSyntax = 0x2000
// ErrCodeUnauthorized indicates access rights violation by user on performed operation.
//
// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1400-L1401
ErrCodeUnauthorized = 0x2100
// ErrCodeInvalid indicates invalid query error which is not covered by ErrCodeSyntax.
//
// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1402
ErrCodeInvalid = 0x2200
// ErrCodeConfig indicates the configuration error.
//
// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1403
ErrCodeConfig = 0x2300
// ErrCodeAlreadyExists is returned for the requests creating the existing keyspace/table.
//
// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1404-L1413
ErrCodeAlreadyExists = 0x2400
// ErrCodeUnprepared returned from the host for prepared statement which is unknown.
//
// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1414-L1417
ErrCodeUnprepared = 0x2500
)
type RequestError interface {
Code() int
Message() string
Error() string
}
type errorFrame struct {
frameHeader
code int
message string
}
func (e errorFrame) Code() int {
return e.code
}
func (e errorFrame) Message() string {
return e.message
}
func (e errorFrame) Error() string {
return e.Message()
}
func (e errorFrame) String() string {
return fmt.Sprintf("[error code=%x message=%q]", e.code, e.message)
}
type RequestErrUnavailable struct {
errorFrame
Consistency Consistency
Required int
Alive int
}
func (e *RequestErrUnavailable) String() string {
return fmt.Sprintf("[request_error_unavailable consistency=%s required=%d alive=%d]", e.Consistency, e.Required, e.Alive)
}
type ErrorMap map[string]uint16
type RequestErrWriteTimeout struct {
errorFrame
Consistency Consistency
Received int
BlockFor int
WriteType string
}
type RequestErrWriteFailure struct {
errorFrame
Consistency Consistency
Received int
BlockFor int
NumFailures int
WriteType string
ErrorMap ErrorMap
}
type RequestErrCDCWriteFailure struct {
errorFrame
}
type RequestErrReadTimeout struct {
errorFrame
Consistency Consistency
Received int
BlockFor int
DataPresent byte
}
type RequestErrAlreadyExists struct {
errorFrame
Keyspace string
Table string
}
type RequestErrUnprepared struct {
errorFrame
StatementId []byte
}
type RequestErrReadFailure struct {
errorFrame
Consistency Consistency
Received int
BlockFor int
NumFailures int
DataPresent bool
ErrorMap ErrorMap
}
type RequestErrFunctionFailure struct {
errorFrame
Keyspace string
Function string
ArgTypes []string
}
// RequestErrCASWriteUnknown is distinct error for ErrCodeCasWriteUnknown.
//
// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1387-L1397
type RequestErrCASWriteUnknown struct {
errorFrame
Consistency Consistency
Received int
BlockFor int
}
type UnknownServerError struct {
errorFrame
}
type OpType uint8
const (
OpTypeRead OpType = 0
OpTypeWrite OpType = 1
)
type RequestErrRateLimitReached struct {
errorFrame
OpType OpType
RejectedByCoordinator bool
}
func (e *RequestErrRateLimitReached) String() string {
var opType string
if e.OpType == OpTypeRead {
opType = "Read"
} else if e.OpType == OpTypeWrite {
opType = "Write"
} else {
opType = "Other"
}
return fmt.Sprintf("[request_error_rate_limit_reached OpType=%s RejectedByCoordinator=%t]", opType, e.RejectedByCoordinator)
}

256
vendor/github.com/gocql/gocql/events.go generated vendored Normal file
View File

@@ -0,0 +1,256 @@
package gocql
import (
"net"
"sync"
"time"
)
type eventDebouncer struct {
name string
timer *time.Timer
mu sync.Mutex
events []frame
callback func([]frame)
quit chan struct{}
logger StdLogger
}
func newEventDebouncer(name string, eventHandler func([]frame), logger StdLogger) *eventDebouncer {
e := &eventDebouncer{
name: name,
quit: make(chan struct{}),
timer: time.NewTimer(eventDebounceTime),
callback: eventHandler,
logger: logger,
}
e.timer.Stop()
go e.flusher()
return e
}
func (e *eventDebouncer) stop() {
e.quit <- struct{}{} // sync with flusher
close(e.quit)
}
func (e *eventDebouncer) flusher() {
for {
select {
case <-e.timer.C:
e.mu.Lock()
e.flush()
e.mu.Unlock()
case <-e.quit:
return
}
}
}
const (
eventBufferSize = 1000
eventDebounceTime = 1 * time.Second
)
// flush must be called with mu locked
func (e *eventDebouncer) flush() {
if len(e.events) == 0 {
return
}
// if the flush interval is faster than the callback then we will end up calling
// the callback multiple times, probably a bad idea. In this case we could drop
// frames?
go e.callback(e.events)
e.events = make([]frame, 0, eventBufferSize)
}
func (e *eventDebouncer) debounce(frame frame) {
e.mu.Lock()
e.timer.Reset(eventDebounceTime)
// TODO: probably need a warning to track if this threshold is too low
if len(e.events) < eventBufferSize {
e.events = append(e.events, frame)
} else {
e.logger.Printf("%s: buffer full, dropping event frame: %s", e.name, frame)
}
e.mu.Unlock()
}
func (s *Session) handleEvent(framer *framer) {
frame, err := framer.parseFrame()
if err != nil {
s.logger.Printf("gocql: unable to parse event frame: %v\n", err)
return
}
if gocqlDebug {
s.logger.Printf("gocql: handling frame: %v\n", frame)
}
switch f := frame.(type) {
case *schemaChangeKeyspace, *schemaChangeFunction,
*schemaChangeTable, *schemaChangeAggregate, *schemaChangeType:
s.schemaEvents.debounce(frame)
case *topologyChangeEventFrame, *statusChangeEventFrame:
s.nodeEvents.debounce(frame)
default:
s.logger.Printf("gocql: invalid event frame (%T): %v\n", f, f)
}
}
func (s *Session) handleSchemaEvent(frames []frame) {
// TODO: debounce events
for _, frame := range frames {
switch f := frame.(type) {
case *schemaChangeKeyspace:
s.metadataDescriber.clearSchema(f.keyspace)
s.handleKeyspaceChange(f.keyspace, f.change)
case *schemaChangeTable:
s.metadataDescriber.clearSchema(f.keyspace)
s.handleTableChange(f.keyspace, f.object, f.change)
case *schemaChangeAggregate:
s.metadataDescriber.clearSchema(f.keyspace)
case *schemaChangeFunction:
s.metadataDescriber.clearSchema(f.keyspace)
case *schemaChangeType:
s.metadataDescriber.clearSchema(f.keyspace)
}
}
}
func (s *Session) handleKeyspaceChange(keyspace, change string) {
s.control.awaitSchemaAgreement()
if change == "DROPPED" || change == "UPDATED" {
s.metadataDescriber.removeTabletsWithKeyspace(keyspace)
}
s.policy.KeyspaceChanged(KeyspaceUpdateEvent{Keyspace: keyspace, Change: change})
}
func (s *Session) handleTableChange(keyspace, table, change string) {
if change == "DROPPED" || change == "UPDATED" {
s.metadataDescriber.removeTabletsWithTable(keyspace, table)
}
}
// handleNodeEvent handles inbound status and topology change events.
//
// Status events are debounced by host IP; only the latest event is processed.
//
// Topology events are debounced by performing a single full topology refresh
// whenever any topology event comes in.
//
// Processing topology change events before status change events ensures
// that a NEW_NODE event is not dropped in favor of a newer UP event (which
// would itself be dropped/ignored, as the node is not yet known).
func (s *Session) handleNodeEvent(frames []frame) {
type nodeEvent struct {
change string
host net.IP
port int
}
topologyEventReceived := false
// status change events
sEvents := make(map[string]*nodeEvent)
for _, frame := range frames {
switch f := frame.(type) {
case *topologyChangeEventFrame:
topologyEventReceived = true
case *statusChangeEventFrame:
event, ok := sEvents[f.host.String()]
if !ok {
event = &nodeEvent{change: f.change, host: f.host, port: f.port}
sEvents[f.host.String()] = event
}
event.change = f.change
}
}
if topologyEventReceived && !s.cfg.Events.DisableTopologyEvents {
s.debounceRingRefresh()
}
for _, f := range sEvents {
if gocqlDebug {
s.logger.Printf("gocql: dispatching status change event: %+v\n", f)
}
// ignore events we received if they were disabled
// see https://github.com/gocql/gocql/issues/1591
switch f.change {
case "UP":
if !s.cfg.Events.DisableNodeStatusEvents {
s.handleNodeUp(f.host, f.port)
}
case "DOWN":
if !s.cfg.Events.DisableNodeStatusEvents {
s.handleNodeDown(f.host, f.port)
}
}
}
}
func (s *Session) handleNodeUp(eventIp net.IP, eventPort int) {
if gocqlDebug {
s.logger.Printf("gocql: Session.handleNodeUp: %s:%d\n", eventIp.String(), eventPort)
}
host, ok := s.hostSource.getHostByIP(eventIp.String())
if !ok {
s.debounceRingRefresh()
return
}
if s.cfg.filterHost(host) {
return
}
if d := host.Version().nodeUpDelay(); d > 0 {
time.Sleep(d)
}
s.startPoolFill(host)
}
func (s *Session) startPoolFill(host *HostInfo) {
// we let the pool call handleNodeConnected to change the host state
s.pool.addHost(host)
s.policy.AddHost(host)
}
func (s *Session) handleNodeConnected(host *HostInfo) {
if gocqlDebug {
s.logger.Printf("gocql: Session.handleNodeConnected: %s:%d\n", host.ConnectAddress(), host.Port())
}
host.setState(NodeUp)
if !s.cfg.filterHost(host) {
s.policy.HostUp(host)
}
}
func (s *Session) handleNodeDown(ip net.IP, port int) {
if gocqlDebug {
s.logger.Printf("gocql: Session.handleNodeDown: %s:%d\n", ip.String(), port)
}
host, ok := s.hostSource.getHostByIP(ip.String())
if ok {
host.setState(NodeDown)
if s.cfg.filterHost(host) {
return
}
s.policy.HostDown(host)
hostID := host.HostID()
s.pool.removeHost(hostID)
}
}

110
vendor/github.com/gocql/gocql/exec.go generated vendored Normal file
View File

@@ -0,0 +1,110 @@
package gocql
import (
"fmt"
)
// SingleHostQueryExecutor allows to quickly execute diagnostic queries while
// connected to only a single node.
// The executor opens only a single connection to a node and does not use
// connection pools.
// Consistency level used is ONE.
// Retry policy is applied, attempts are visible in query metrics but query
// observer is not notified.
type SingleHostQueryExecutor struct {
session *Session
control *controlConn
}
// Exec executes the query without returning any rows.
func (e SingleHostQueryExecutor) Exec(stmt string, values ...interface{}) error {
return e.control.query(stmt, values...).Close()
}
// Iter executes the query and returns an iterator capable of iterating
// over all results.
func (e SingleHostQueryExecutor) Iter(stmt string, values ...interface{}) *Iter {
return e.control.query(stmt, values...)
}
func (e SingleHostQueryExecutor) Close() {
if e.control != nil {
e.control.close()
}
if e.session != nil {
e.session.Close()
}
}
// NewSingleHostQueryExecutor creates a SingleHostQueryExecutor by connecting
// to one of the hosts specified in the ClusterConfig.
// If ProtoVersion is not specified version 4 is used.
// Caller is responsible for closing the executor after use.
func NewSingleHostQueryExecutor(cfg *ClusterConfig) (e SingleHostQueryExecutor, err error) {
// Check that hosts in the ClusterConfig is not empty
if len(cfg.Hosts) < 1 {
err = ErrNoHosts
return
}
c := *cfg
// If protocol version not set assume 4 and skip discovery
if c.ProtoVersion == 0 {
c.ProtoVersion = 4
}
// Close in case of error
defer func() {
if err != nil {
e.Close()
}
}()
// Create uninitialised session
c.disableInit = true
if e.session, err = NewSession(c); err != nil {
err = fmt.Errorf("new session: %w", err)
return
}
var hosts []*HostInfo
if hosts, err = addrsToHosts(c.Hosts, c.Port, c.Logger); err != nil {
err = fmt.Errorf("addrs to hosts: %w", err)
return
}
// Create control connection to one of the hosts
e.control = createControlConn(e.session)
// shuffle endpoints so not all drivers will connect to the same initial
// node.
hosts = shuffleHosts(hosts)
conncfg := *e.control.session.connCfg
conncfg.disableCoalesce = true
var conn *Conn
for _, host := range hosts {
conn, err = e.control.session.dial(e.control.session.ctx, host, &conncfg, e.control)
if err != nil {
e.control.session.logger.Printf("gocql: unable to dial control conn %v:%v: %v\n", host.ConnectAddress(), host.Port(), err)
continue
}
err = e.control.setupConn(conn)
if err == nil {
break
}
e.control.session.logger.Printf("gocql: unable setup control conn %v:%v: %v\n", host.ConnectAddress(), host.Port(), err)
conn.Close()
conn = nil
}
if conn == nil {
err = fmt.Errorf("setup: %w", err)
return
}
return
}

57
vendor/github.com/gocql/gocql/filters.go generated vendored Normal file
View File

@@ -0,0 +1,57 @@
package gocql
import "fmt"
// HostFilter interface is used when a host is discovered via server sent events.
type HostFilter interface {
// Called when a new host is discovered, returning true will cause the host
// to be added to the pools.
Accept(host *HostInfo) bool
}
// HostFilterFunc converts a func(host HostInfo) bool into a HostFilter
type HostFilterFunc func(host *HostInfo) bool
func (fn HostFilterFunc) Accept(host *HostInfo) bool {
return fn(host)
}
// AcceptAllFilter will accept all hosts
func AcceptAllFilter() HostFilter {
return HostFilterFunc(func(host *HostInfo) bool {
return true
})
}
func DenyAllFilter() HostFilter {
return HostFilterFunc(func(host *HostInfo) bool {
return false
})
}
// DataCentreHostFilter filters all hosts such that they are in the same data centre
// as the supplied data centre.
func DataCentreHostFilter(dataCentre string) HostFilter {
return HostFilterFunc(func(host *HostInfo) bool {
return host.DataCenter() == dataCentre
})
}
// WhiteListHostFilter filters incoming hosts by checking that their address is
// in the initial hosts whitelist.
func WhiteListHostFilter(hosts ...string) HostFilter {
hostInfos, err := addrsToHosts(hosts, 9042, nopLogger{})
if err != nil {
// dont want to panic here, but rather not break the API
panic(fmt.Errorf("unable to lookup host info from address: %v", err))
}
m := make(map[string]bool, len(hostInfos))
for _, host := range hostInfos {
m[host.ConnectAddress().String()] = true
}
return HostFilterFunc(func(host *HostInfo) bool {
return m[host.ConnectAddress().String()]
})
}

2119
vendor/github.com/gocql/gocql/frame.go generated vendored Normal file

File diff suppressed because it is too large Load Diff

34
vendor/github.com/gocql/gocql/fuzz.go generated vendored Normal file
View File

@@ -0,0 +1,34 @@
//go:build gofuzz
// +build gofuzz
package gocql
import "bytes"
func Fuzz(data []byte) int {
var bw bytes.Buffer
r := bytes.NewReader(data)
head, err := readHeader(r, make([]byte, 9))
if err != nil {
return 0
}
framer := newFramer(r, &bw, nil, byte(head.version))
err = framer.readFrame(&head)
if err != nil {
return 0
}
frame, err := framer.parseFrame()
if err != nil {
return 0
}
if frame != nil {
return 1
}
return 2
}

448
vendor/github.com/gocql/gocql/helpers.go generated vendored Normal file
View File

@@ -0,0 +1,448 @@
// Copyright (c) 2012 The gocql Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gocql
import (
"fmt"
"math/big"
"net"
"reflect"
"strings"
"time"
"gopkg.in/inf.v0"
)
type RowData struct {
Columns []string
Values []interface{}
}
func goType(t TypeInfo) (reflect.Type, error) {
switch t.Type() {
case TypeVarchar, TypeAscii, TypeInet, TypeText:
return reflect.TypeOf(*new(string)), nil
case TypeBigInt, TypeCounter:
return reflect.TypeOf(*new(int64)), nil
case TypeTime:
return reflect.TypeOf(*new(time.Duration)), nil
case TypeTimestamp:
return reflect.TypeOf(*new(time.Time)), nil
case TypeBlob:
return reflect.TypeOf(*new([]byte)), nil
case TypeBoolean:
return reflect.TypeOf(*new(bool)), nil
case TypeFloat:
return reflect.TypeOf(*new(float32)), nil
case TypeDouble:
return reflect.TypeOf(*new(float64)), nil
case TypeInt:
return reflect.TypeOf(*new(int)), nil
case TypeSmallInt:
return reflect.TypeOf(*new(int16)), nil
case TypeTinyInt:
return reflect.TypeOf(*new(int8)), nil
case TypeDecimal:
return reflect.TypeOf(*new(*inf.Dec)), nil
case TypeUUID, TypeTimeUUID:
return reflect.TypeOf(*new(UUID)), nil
case TypeList, TypeSet:
elemType, err := goType(t.(CollectionType).Elem)
if err != nil {
return nil, err
}
return reflect.SliceOf(elemType), nil
case TypeMap:
keyType, err := goType(t.(CollectionType).Key)
if err != nil {
return nil, err
}
valueType, err := goType(t.(CollectionType).Elem)
if err != nil {
return nil, err
}
return reflect.MapOf(keyType, valueType), nil
case TypeVarint:
return reflect.TypeOf(*new(*big.Int)), nil
case TypeTuple:
// what can we do here? all there is to do is to make a list of interface{}
tuple := t.(TupleTypeInfo)
return reflect.TypeOf(make([]interface{}, len(tuple.Elems))), nil
case TypeUDT:
return reflect.TypeOf(make(map[string]interface{})), nil
case TypeDate:
return reflect.TypeOf(*new(time.Time)), nil
case TypeDuration:
return reflect.TypeOf(*new(Duration)), nil
default:
return nil, fmt.Errorf("cannot create Go type for unknown CQL type %s", t)
}
}
func dereference(i interface{}) interface{} {
return reflect.Indirect(reflect.ValueOf(i)).Interface()
}
func getCassandraBaseType(name string) Type {
switch name {
case "ascii":
return TypeAscii
case "bigint":
return TypeBigInt
case "blob":
return TypeBlob
case "boolean":
return TypeBoolean
case "counter":
return TypeCounter
case "date":
return TypeDate
case "decimal":
return TypeDecimal
case "double":
return TypeDouble
case "duration":
return TypeDuration
case "float":
return TypeFloat
case "int":
return TypeInt
case "smallint":
return TypeSmallInt
case "tinyint":
return TypeTinyInt
case "time":
return TypeTime
case "timestamp":
return TypeTimestamp
case "uuid":
return TypeUUID
case "varchar":
return TypeVarchar
case "text":
return TypeText
case "varint":
return TypeVarint
case "timeuuid":
return TypeTimeUUID
case "inet":
return TypeInet
case "MapType":
return TypeMap
case "ListType":
return TypeList
case "SetType":
return TypeSet
case "TupleType":
return TypeTuple
default:
return TypeCustom
}
}
func getCassandraType(name string, logger StdLogger) TypeInfo {
if strings.HasPrefix(name, "frozen<") {
return getCassandraType(strings.TrimPrefix(name[:len(name)-1], "frozen<"), logger)
} else if strings.HasPrefix(name, "set<") {
return CollectionType{
NativeType: NativeType{typ: TypeSet},
Elem: getCassandraType(strings.TrimPrefix(name[:len(name)-1], "set<"), logger),
}
} else if strings.HasPrefix(name, "list<") {
return CollectionType{
NativeType: NativeType{typ: TypeList},
Elem: getCassandraType(strings.TrimPrefix(name[:len(name)-1], "list<"), logger),
}
} else if strings.HasPrefix(name, "map<") {
names := splitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "map<"))
if len(names) != 2 {
logger.Printf("Error parsing map type, it has %d subelements, expecting 2\n", len(names))
return NativeType{
typ: TypeCustom,
}
}
return CollectionType{
NativeType: NativeType{typ: TypeMap},
Key: getCassandraType(names[0], logger),
Elem: getCassandraType(names[1], logger),
}
} else if strings.HasPrefix(name, "tuple<") {
names := splitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "tuple<"))
types := make([]TypeInfo, len(names))
for i, name := range names {
types[i] = getCassandraType(name, logger)
}
return TupleTypeInfo{
NativeType: NativeType{typ: TypeTuple},
Elems: types,
}
} else {
return NativeType{
typ: getCassandraBaseType(name),
}
}
}
func splitCompositeTypes(name string) []string {
if !strings.Contains(name, "<") {
return strings.Split(name, ", ")
}
var parts []string
lessCount := 0
segment := ""
for _, char := range name {
if char == ',' && lessCount == 0 {
if segment != "" {
parts = append(parts, strings.TrimSpace(segment))
}
segment = ""
continue
}
segment += string(char)
if char == '<' {
lessCount++
} else if char == '>' {
lessCount--
}
}
if segment != "" {
parts = append(parts, strings.TrimSpace(segment))
}
return parts
}
func apacheToCassandraType(t string) string {
t = strings.Replace(t, apacheCassandraTypePrefix, "", -1)
t = strings.Replace(t, "(", "<", -1)
t = strings.Replace(t, ")", ">", -1)
types := strings.FieldsFunc(t, func(r rune) bool {
return r == '<' || r == '>' || r == ','
})
for _, typ := range types {
t = strings.Replace(t, typ, getApacheCassandraType(typ).String(), -1)
}
// This is done so it exactly matches what Cassandra returns
return strings.Replace(t, ",", ", ", -1)
}
func getApacheCassandraType(class string) Type {
switch strings.TrimPrefix(class, apacheCassandraTypePrefix) {
case "AsciiType":
return TypeAscii
case "LongType":
return TypeBigInt
case "BytesType":
return TypeBlob
case "BooleanType":
return TypeBoolean
case "CounterColumnType":
return TypeCounter
case "DecimalType":
return TypeDecimal
case "DoubleType":
return TypeDouble
case "FloatType":
return TypeFloat
case "Int32Type":
return TypeInt
case "ShortType":
return TypeSmallInt
case "ByteType":
return TypeTinyInt
case "TimeType":
return TypeTime
case "DateType", "TimestampType":
return TypeTimestamp
case "UUIDType", "LexicalUUIDType":
return TypeUUID
case "UTF8Type":
return TypeVarchar
case "IntegerType":
return TypeVarint
case "TimeUUIDType":
return TypeTimeUUID
case "InetAddressType":
return TypeInet
case "MapType":
return TypeMap
case "ListType":
return TypeList
case "SetType":
return TypeSet
case "TupleType":
return TypeTuple
case "DurationType":
return TypeDuration
default:
return TypeCustom
}
}
func (r *RowData) rowMap(m map[string]interface{}) {
for i, column := range r.Columns {
val := dereference(r.Values[i])
if valVal := reflect.ValueOf(val); valVal.Kind() == reflect.Slice {
valCopy := reflect.MakeSlice(valVal.Type(), valVal.Len(), valVal.Cap())
reflect.Copy(valCopy, valVal)
m[column] = valCopy.Interface()
} else {
m[column] = val
}
}
}
// TupeColumnName will return the column name of a tuple value in a column named
// c at index n. It should be used if a specific element within a tuple is needed
// to be extracted from a map returned from SliceMap or MapScan.
func TupleColumnName(c string, n int) string {
return fmt.Sprintf("%s[%d]", c, n)
}
func (iter *Iter) RowData() (RowData, error) {
if iter.err != nil {
return RowData{}, iter.err
}
columns := make([]string, 0, len(iter.Columns()))
values := make([]interface{}, 0, len(iter.Columns()))
for _, column := range iter.Columns() {
if c, ok := column.TypeInfo.(TupleTypeInfo); !ok {
val, err := column.TypeInfo.NewWithError()
if err != nil {
return RowData{}, err
}
columns = append(columns, column.Name)
values = append(values, val)
} else {
for i, elem := range c.Elems {
columns = append(columns, TupleColumnName(column.Name, i))
val, err := elem.NewWithError()
if err != nil {
return RowData{}, err
}
values = append(values, val)
}
}
}
rowData := RowData{
Columns: columns,
Values: values,
}
return rowData, nil
}
// TODO(zariel): is it worth exporting this?
func (iter *Iter) rowMap() (map[string]interface{}, error) {
if iter.err != nil {
return nil, iter.err
}
rowData, _ := iter.RowData()
iter.Scan(rowData.Values...)
m := make(map[string]interface{}, len(rowData.Columns))
rowData.rowMap(m)
return m, nil
}
// SliceMap is a helper function to make the API easier to use
// returns the data from the query in the form of []map[string]interface{}
func (iter *Iter) SliceMap() ([]map[string]interface{}, error) {
if iter.err != nil {
return nil, iter.err
}
// Not checking for the error because we just did
rowData, _ := iter.RowData()
dataToReturn := make([]map[string]interface{}, 0)
for iter.Scan(rowData.Values...) {
m := make(map[string]interface{}, len(rowData.Columns))
rowData.rowMap(m)
dataToReturn = append(dataToReturn, m)
}
if iter.err != nil {
return nil, iter.err
}
return dataToReturn, nil
}
// MapScan takes a map[string]interface{} and populates it with a row
// that is returned from cassandra.
//
// Each call to MapScan() must be called with a new map object.
// During the call to MapScan() any pointers in the existing map
// are replaced with non pointer types before the call returns
//
// iter := session.Query(`SELECT * FROM mytable`).Iter()
// for {
// // New map each iteration
// row := make(map[string]interface{})
// if !iter.MapScan(row) {
// break
// }
// // Do things with row
// if fullname, ok := row["fullname"]; ok {
// fmt.Printf("Full Name: %s\n", fullname)
// }
// }
//
// You can also pass pointers in the map before each call
//
// var fullName FullName // Implements gocql.Unmarshaler and gocql.Marshaler interfaces
// var address net.IP
// var age int
// iter := session.Query(`SELECT * FROM scan_map_table`).Iter()
// for {
// // New map each iteration
// row := map[string]interface{}{
// "fullname": &fullName,
// "age": &age,
// "address": &address,
// }
// if !iter.MapScan(row) {
// break
// }
// fmt.Printf("First: %s Age: %d Address: %q\n", fullName.FirstName, age, address)
// }
func (iter *Iter) MapScan(m map[string]interface{}) bool {
if iter.err != nil {
return false
}
// Not checking for the error because we just did
rowData, _ := iter.RowData()
for i, col := range rowData.Columns {
if dest, ok := m[col]; ok {
rowData.Values[i] = dest
}
}
if iter.Scan(rowData.Values...) {
rowData.rowMap(m)
return true
}
return false
}
func copyBytes(p []byte) []byte {
b := make([]byte, len(p))
copy(b, p)
return b
}
var failDNS = false
func LookupIP(host string) ([]net.IP, error) {
if failDNS {
return nil, &net.DNSError{}
}
return net.LookupIP(host)
}

722
vendor/github.com/gocql/gocql/host_source.go generated vendored Normal file
View File

@@ -0,0 +1,722 @@
package gocql
import (
"errors"
"fmt"
"net"
"strconv"
"strings"
"sync"
"time"
)
var ErrCannotFindHost = errors.New("cannot find host")
var ErrHostAlreadyExists = errors.New("host already exists")
type nodeState int32
func (n nodeState) String() string {
if n == NodeUp {
return "UP"
} else if n == NodeDown {
return "DOWN"
}
return fmt.Sprintf("UNKNOWN_%d", n)
}
const (
NodeUp nodeState = iota
NodeDown
)
type cassVersion struct {
Major, Minor, Patch int
}
func (c *cassVersion) Set(v string) error {
if v == "" {
return nil
}
return c.UnmarshalCQL(nil, []byte(v))
}
func (c *cassVersion) UnmarshalCQL(info TypeInfo, data []byte) error {
return c.unmarshal(data)
}
func (c *cassVersion) unmarshal(data []byte) error {
version := strings.TrimSuffix(string(data), "-SNAPSHOT")
version = strings.TrimPrefix(version, "v")
v := strings.Split(version, ".")
if len(v) < 2 {
return fmt.Errorf("invalid version string: %s", data)
}
var err error
c.Major, err = strconv.Atoi(v[0])
if err != nil {
return fmt.Errorf("invalid major version %v: %v", v[0], err)
}
c.Minor, err = strconv.Atoi(v[1])
if err != nil {
return fmt.Errorf("invalid minor version %v: %v", v[1], err)
}
if len(v) > 2 {
c.Patch, err = strconv.Atoi(v[2])
if err != nil {
return fmt.Errorf("invalid patch version %v: %v", v[2], err)
}
}
return nil
}
func (c cassVersion) Before(major, minor, patch int) bool {
// We're comparing us (cassVersion) with the provided version (major, minor, patch)
// We return true if our version is lower (comes before) than the provided one.
if c.Major < major {
return true
} else if c.Major == major {
if c.Minor < minor {
return true
} else if c.Minor == minor && c.Patch < patch {
return true
}
}
return false
}
func (c cassVersion) AtLeast(major, minor, patch int) bool {
return !c.Before(major, minor, patch)
}
func (c cassVersion) String() string {
return fmt.Sprintf("v%d.%d.%d", c.Major, c.Minor, c.Patch)
}
func (c cassVersion) nodeUpDelay() time.Duration {
if c.Major >= 2 && c.Minor >= 2 {
// CASSANDRA-8236
return 0
}
return 10 * time.Second
}
type HostInfo struct {
// TODO(zariel): reduce locking maybe, not all values will change, but to ensure
// that we are thread safe use a mutex to access all fields.
mu sync.RWMutex
hostname string
peer net.IP
broadcastAddress net.IP
listenAddress net.IP
rpcAddress net.IP
preferredIP net.IP
connectAddress net.IP
untranslatedConnectAddress net.IP
port int
dataCenter string
rack string
hostId string
workload string
graph bool
dseVersion string
partitioner string
clusterName string
version cassVersion
state nodeState
schemaVersion string
tokens []string
scyllaShardAwarePort uint16
scyllaShardAwarePortTLS uint16
}
func (h *HostInfo) Equal(host *HostInfo) bool {
if h == host {
// prevent rlock reentry
return true
}
return h.HostID() == host.HostID() && h.ConnectAddressAndPort() == host.ConnectAddressAndPort()
}
func (h *HostInfo) Peer() net.IP {
h.mu.RLock()
defer h.mu.RUnlock()
return h.peer
}
func (h *HostInfo) invalidConnectAddr() bool {
h.mu.RLock()
defer h.mu.RUnlock()
addr, _ := h.connectAddressLocked()
return !validIpAddr(addr)
}
func validIpAddr(addr net.IP) bool {
return addr != nil && !addr.IsUnspecified()
}
func (h *HostInfo) connectAddressLocked() (net.IP, string) {
if validIpAddr(h.connectAddress) {
return h.connectAddress, "connect_address"
} else if validIpAddr(h.rpcAddress) {
return h.rpcAddress, "rpc_adress"
} else if validIpAddr(h.preferredIP) {
// where does perferred_ip get set?
return h.preferredIP, "preferred_ip"
} else if validIpAddr(h.broadcastAddress) {
return h.broadcastAddress, "broadcast_address"
} else if validIpAddr(h.peer) {
return h.peer, "peer"
}
return net.IPv4zero, "invalid"
}
// nodeToNodeAddress returns address broadcasted between node to nodes.
// It's either `broadcast_address` if host info is read from system.local or `peer` if read from system.peers.
// This IP address is also part of CQL Event emitted on topology/status changes,
// but does not uniquely identify the node in case multiple nodes use the same IP address.
func (h *HostInfo) nodeToNodeAddress() net.IP {
h.mu.RLock()
defer h.mu.RUnlock()
if validIpAddr(h.broadcastAddress) {
return h.broadcastAddress
} else if validIpAddr(h.peer) {
return h.peer
}
return net.IPv4zero
}
// Returns the address that should be used to connect to the host.
// If you wish to override this, use an AddressTranslator or
// use a HostFilter to SetConnectAddress()
func (h *HostInfo) ConnectAddress() net.IP {
h.mu.RLock()
defer h.mu.RUnlock()
if addr, _ := h.connectAddressLocked(); validIpAddr(addr) {
return addr
}
panic(fmt.Sprintf("no valid connect address for host: %v. Is your cluster configured correctly?", h))
}
func (h *HostInfo) UntranslatedConnectAddress() net.IP {
h.mu.RLock()
defer h.mu.RUnlock()
if len(h.untranslatedConnectAddress) != 0 {
return h.untranslatedConnectAddress
}
if addr, _ := h.connectAddressLocked(); validIpAddr(addr) {
return addr
}
panic(fmt.Sprintf("no valid connect address for host: %v. Is your cluster configured correctly?", h))
}
func (h *HostInfo) SetConnectAddress(address net.IP) *HostInfo {
// TODO(zariel): should this not be exported?
h.mu.Lock()
defer h.mu.Unlock()
h.connectAddress = address
return h
}
func (h *HostInfo) BroadcastAddress() net.IP {
h.mu.RLock()
defer h.mu.RUnlock()
return h.broadcastAddress
}
func (h *HostInfo) ListenAddress() net.IP {
h.mu.RLock()
defer h.mu.RUnlock()
return h.listenAddress
}
func (h *HostInfo) RPCAddress() net.IP {
h.mu.RLock()
defer h.mu.RUnlock()
return h.rpcAddress
}
func (h *HostInfo) PreferredIP() net.IP {
h.mu.RLock()
defer h.mu.RUnlock()
return h.preferredIP
}
func (h *HostInfo) DataCenter() string {
h.mu.RLock()
dc := h.dataCenter
h.mu.RUnlock()
return dc
}
func (h *HostInfo) Rack() string {
h.mu.RLock()
rack := h.rack
h.mu.RUnlock()
return rack
}
func (h *HostInfo) HostID() string {
h.mu.RLock()
defer h.mu.RUnlock()
return h.hostId
}
func (h *HostInfo) SetHostID(hostID string) {
h.mu.Lock()
defer h.mu.Unlock()
h.hostId = hostID
}
func (h *HostInfo) WorkLoad() string {
h.mu.RLock()
defer h.mu.RUnlock()
return h.workload
}
func (h *HostInfo) Graph() bool {
h.mu.RLock()
defer h.mu.RUnlock()
return h.graph
}
func (h *HostInfo) DSEVersion() string {
h.mu.RLock()
defer h.mu.RUnlock()
return h.dseVersion
}
func (h *HostInfo) Partitioner() string {
h.mu.RLock()
defer h.mu.RUnlock()
return h.partitioner
}
func (h *HostInfo) ClusterName() string {
h.mu.RLock()
defer h.mu.RUnlock()
return h.clusterName
}
func (h *HostInfo) Version() cassVersion {
h.mu.RLock()
defer h.mu.RUnlock()
return h.version
}
func (h *HostInfo) State() nodeState {
h.mu.RLock()
defer h.mu.RUnlock()
return h.state
}
func (h *HostInfo) setState(state nodeState) *HostInfo {
h.mu.Lock()
defer h.mu.Unlock()
h.state = state
return h
}
func (h *HostInfo) Tokens() []string {
h.mu.RLock()
defer h.mu.RUnlock()
return h.tokens
}
func (h *HostInfo) Port() int {
h.mu.RLock()
defer h.mu.RUnlock()
return h.port
}
func (h *HostInfo) update(from *HostInfo) {
if h == from {
return
}
h.mu.Lock()
defer h.mu.Unlock()
from.mu.RLock()
defer from.mu.RUnlock()
// autogenerated do not update
if h.peer == nil {
h.peer = from.peer
}
if h.broadcastAddress == nil {
h.broadcastAddress = from.broadcastAddress
}
if h.listenAddress == nil {
h.listenAddress = from.listenAddress
}
if h.rpcAddress == nil {
h.rpcAddress = from.rpcAddress
}
if h.preferredIP == nil {
h.preferredIP = from.preferredIP
}
if h.connectAddress == nil {
h.connectAddress = from.connectAddress
}
if h.port == 0 {
h.port = from.port
}
if h.dataCenter == "" {
h.dataCenter = from.dataCenter
}
if h.rack == "" {
h.rack = from.rack
}
if h.hostId == "" {
h.hostId = from.hostId
}
if h.workload == "" {
h.workload = from.workload
}
if h.dseVersion == "" {
h.dseVersion = from.dseVersion
}
if h.partitioner == "" {
h.partitioner = from.partitioner
}
if h.clusterName == "" {
h.clusterName = from.clusterName
}
if h.version == (cassVersion{}) {
h.version = from.version
}
if h.tokens == nil {
h.tokens = from.tokens
}
}
func (h *HostInfo) IsUp() bool {
return h != nil && h.State() == NodeUp
}
func (h *HostInfo) IsBusy(s *Session) bool {
pool, ok := s.pool.getPool(h)
return ok && h != nil && pool.InFlight() >= MAX_IN_FLIGHT_THRESHOLD
}
func (h *HostInfo) HostnameAndPort() string {
// Fast path: in most cases hostname is not empty
var (
hostname string
port int
)
h.mu.RLock()
hostname = h.hostname
port = h.port
h.mu.RUnlock()
if hostname != "" {
return net.JoinHostPort(hostname, strconv.Itoa(port))
}
// Slow path: hostname is empty
h.mu.Lock()
defer h.mu.Unlock()
if h.hostname == "" { // recheck is hostname empty
// if yes - fill it
addr, _ := h.connectAddressLocked()
h.hostname = addr.String()
}
return net.JoinHostPort(h.hostname, strconv.Itoa(h.port))
}
func (h *HostInfo) Hostname() string {
// Fast path: in most cases hostname is not empty
var hostname string
h.mu.RLock()
hostname = h.hostname
h.mu.RUnlock()
if hostname != "" {
return hostname
}
// Slow path: hostname is empty
h.mu.Lock()
defer h.mu.Unlock()
if h.hostname == "" {
addr, _ := h.connectAddressLocked()
h.hostname = addr.String()
}
return h.hostname
}
func (h *HostInfo) ConnectAddressAndPort() string {
h.mu.Lock()
defer h.mu.Unlock()
addr, _ := h.connectAddressLocked()
return net.JoinHostPort(addr.String(), strconv.Itoa(h.port))
}
func (h *HostInfo) String() string {
h.mu.RLock()
defer h.mu.RUnlock()
connectAddr, source := h.connectAddressLocked()
return fmt.Sprintf("[HostInfo hostname=%q connectAddress=%q peer=%q rpc_address=%q broadcast_address=%q "+
"preferred_ip=%q connect_addr=%q connect_addr_source=%q "+
"port=%d data_centre=%q rack=%q host_id=%q version=%q state=%s num_tokens=%d]",
h.hostname, h.connectAddress, h.peer, h.rpcAddress, h.broadcastAddress, h.preferredIP,
connectAddr, source,
h.port, h.dataCenter, h.rack, h.hostId, h.version, h.state, len(h.tokens))
}
func (h *HostInfo) setScyllaSupported(s scyllaSupported) {
h.mu.Lock()
defer h.mu.Unlock()
h.scyllaShardAwarePort = s.shardAwarePort
h.scyllaShardAwarePortTLS = s.shardAwarePortSSL
}
// ScyllaShardAwarePort returns the shard aware port of this host.
// Returns zero if the shard aware port is not known.
func (h *HostInfo) ScyllaShardAwarePort() uint16 {
h.mu.RLock()
defer h.mu.RUnlock()
return h.scyllaShardAwarePort
}
// ScyllaShardAwarePortTLS returns the TLS-enabled shard aware port of this host.
// Returns zero if the shard aware port is not known.
func (h *HostInfo) ScyllaShardAwarePortTLS() uint16 {
h.mu.RLock()
defer h.mu.RUnlock()
return h.scyllaShardAwarePortTLS
}
// Returns true if we are using system_schema.keyspaces instead of system.schema_keyspaces
func checkSystemSchema(control controlConnection) (bool, error) {
iter := control.query("SELECT * FROM system_schema.keyspaces" + control.getSession().usingTimeoutClause)
if err := iter.err; err != nil {
if errf, ok := err.(*errorFrame); ok {
if errf.code == ErrCodeSyntax {
return false, nil
}
}
return false, err
}
return true, nil
}
// Given a map that represents a row from either system.local or system.peers
// return as much information as we can in *HostInfo
func hostInfoFromMap(row map[string]interface{}, host *HostInfo, translateAddressPort func(addr net.IP, port int) (net.IP, int)) (*HostInfo, error) {
const assertErrorMsg = "Assertion failed for %s"
var ok bool
// Default to our connected port if the cluster doesn't have port information
for key, value := range row {
switch key {
case "data_center":
host.dataCenter, ok = value.(string)
if !ok {
return nil, fmt.Errorf(assertErrorMsg, "data_center")
}
case "rack":
host.rack, ok = value.(string)
if !ok {
return nil, fmt.Errorf(assertErrorMsg, "rack")
}
case "host_id":
hostId, ok := value.(UUID)
if !ok {
return nil, fmt.Errorf(assertErrorMsg, "host_id")
}
host.hostId = hostId.String()
case "release_version":
version, ok := value.(string)
if !ok {
return nil, fmt.Errorf(assertErrorMsg, "release_version")
}
host.version.Set(version)
case "peer":
ip, ok := value.(string)
if !ok {
return nil, fmt.Errorf(assertErrorMsg, "peer")
}
host.peer = net.ParseIP(ip)
case "cluster_name":
host.clusterName, ok = value.(string)
if !ok {
return nil, fmt.Errorf(assertErrorMsg, "cluster_name")
}
case "partitioner":
host.partitioner, ok = value.(string)
if !ok {
return nil, fmt.Errorf(assertErrorMsg, "partitioner")
}
case "broadcast_address":
ip, ok := value.(string)
if !ok {
return nil, fmt.Errorf(assertErrorMsg, "broadcast_address")
}
host.broadcastAddress = net.ParseIP(ip)
case "preferred_ip":
ip, ok := value.(string)
if !ok {
return nil, fmt.Errorf(assertErrorMsg, "preferred_ip")
}
host.preferredIP = net.ParseIP(ip)
case "rpc_address":
ip, ok := value.(string)
if !ok {
return nil, fmt.Errorf(assertErrorMsg, "rpc_address")
}
host.rpcAddress = net.ParseIP(ip)
case "native_address":
ip, ok := value.(string)
if !ok {
return nil, fmt.Errorf(assertErrorMsg, "native_address")
}
host.rpcAddress = net.ParseIP(ip)
case "listen_address":
ip, ok := value.(string)
if !ok {
return nil, fmt.Errorf(assertErrorMsg, "listen_address")
}
host.listenAddress = net.ParseIP(ip)
case "native_port":
native_port, ok := value.(int)
if !ok {
return nil, fmt.Errorf(assertErrorMsg, "native_port")
}
host.port = native_port
case "workload":
host.workload, ok = value.(string)
if !ok {
return nil, fmt.Errorf(assertErrorMsg, "workload")
}
case "graph":
host.graph, ok = value.(bool)
if !ok {
return nil, fmt.Errorf(assertErrorMsg, "graph")
}
case "tokens":
host.tokens, ok = value.([]string)
if !ok {
return nil, fmt.Errorf(assertErrorMsg, "tokens")
}
case "dse_version":
host.dseVersion, ok = value.(string)
if !ok {
return nil, fmt.Errorf(assertErrorMsg, "dse_version")
}
case "schema_version":
schemaVersion, ok := value.(UUID)
if !ok {
return nil, fmt.Errorf(assertErrorMsg, "schema_version")
}
host.schemaVersion = schemaVersion.String()
}
// TODO(thrawn01): Add 'port'? once CASSANDRA-7544 is complete
// Not sure what the port field will be called until the JIRA issue is complete
}
host.untranslatedConnectAddress = host.ConnectAddress()
ip, port := translateAddressPort(host.untranslatedConnectAddress, host.port)
host.connectAddress = ip
host.port = port
return host, nil
}
func hostInfoFromIter(iter *Iter, connectAddress net.IP, defaultPort int, translateAddressPort func(addr net.IP, port int) (net.IP, int)) (*HostInfo, error) {
rows, err := iter.SliceMap()
if err != nil {
// TODO(zariel): make typed error
return nil, err
}
if len(rows) == 0 {
return nil, errors.New("query returned 0 rows")
}
host, err := hostInfoFromMap(rows[0], &HostInfo{connectAddress: connectAddress, port: defaultPort}, translateAddressPort)
if err != nil {
return nil, err
}
return host, nil
}
// debounceRingRefresh submits a ring refresh request to the ring refresh debouncer.
func (s *Session) debounceRingRefresh() {
s.ringRefresher.Debounce()
}
// refreshRing executes a ring refresh immediately and cancels pending debounce ring refresh requests.
func (s *Session) refreshRingNow() error {
err, ok := <-s.ringRefresher.RefreshNow()
if !ok {
return errors.New("could not refresh ring because stop was requested")
}
return err
}
func (s *Session) refreshRing() error {
hosts, partitioner, err := s.hostSource.GetHostsFromSystem()
if err != nil {
return err
}
prevHosts := s.hostSource.getHostsMap()
for _, h := range hosts {
if s.cfg.filterHost(h) {
continue
}
if host, ok := s.hostSource.addHostIfMissing(h); !ok {
s.startPoolFill(h)
} else {
// host (by hostID) already exists; determine if IP has changed
newHostID := h.HostID()
existing, ok := prevHosts[newHostID]
if !ok {
return fmt.Errorf("get existing host=%s from prevHosts: %w", h, ErrCannotFindHost)
}
if h.connectAddress.Equal(existing.connectAddress) && h.nodeToNodeAddress().Equal(existing.nodeToNodeAddress()) {
// no host IP change
host.update(h)
} else {
// host IP has changed
// remove old HostInfo (w/old IP)
s.removeHost(existing)
if _, alreadyExists := s.hostSource.addHostIfMissing(h); alreadyExists {
return fmt.Errorf("add new host=%s after removal: %w", h, ErrHostAlreadyExists)
}
// add new HostInfo (same hostID, new IP)
s.startPoolFill(h)
}
}
delete(prevHosts, h.HostID())
}
for _, host := range prevHosts {
s.metadataDescriber.removeTabletsWithHost(host)
s.removeHost(host)
}
s.policy.SetPartitioner(partitioner)
return nil
}

46
vendor/github.com/gocql/gocql/host_source_gen.go generated vendored Normal file
View File

@@ -0,0 +1,46 @@
//go:build genhostinfo
// +build genhostinfo
package main
import (
"fmt"
"reflect"
"sync"
"github.com/gocql/gocql"
)
func gen(clause, field string) {
fmt.Printf("if h.%s == %s {\n", field, clause)
fmt.Printf("\th.%s = from.%s\n", field, field)
fmt.Println("}")
}
func main() {
t := reflect.ValueOf(&gocql.HostInfo{}).Elem().Type()
mu := reflect.TypeOf(sync.RWMutex{})
for i := 0; i < t.NumField(); i++ {
f := t.Field(i)
if f.Type == mu {
continue
}
switch f.Type.Kind() {
case reflect.Slice:
gen("nil", f.Name)
case reflect.String:
gen(`""`, f.Name)
case reflect.Int:
gen("0", f.Name)
case reflect.Struct:
gen("("+f.Type.Name()+"{})", f.Name)
case reflect.Bool, reflect.Int32:
continue
default:
panic(fmt.Sprintf("unknown field: %s", f))
}
}
}

7
vendor/github.com/gocql/gocql/host_source_scylla.go generated vendored Normal file
View File

@@ -0,0 +1,7 @@
package gocql
func (h *HostInfo) SetDatacenter(dc string) {
h.mu.Lock()
defer h.mu.Unlock()
h.dataCenter = dc
}

0
vendor/github.com/gocql/gocql/install_test_deps.sh generated vendored Normal file
View File

54
vendor/github.com/gocql/gocql/integration.sh generated vendored Normal file
View File

@@ -0,0 +1,54 @@
#!/bin/bash
#
# Copyright (C) 2017 ScyllaDB
#
readonly SCYLLA_IMAGE=${SCYLLA_IMAGE}
set -eu -o pipefail
function scylla_up() {
local -r exec="docker compose exec -T"
echo "==> Running Scylla ${SCYLLA_IMAGE}"
docker pull ${SCYLLA_IMAGE}
docker compose up -d --wait || ( docker compose ps --format json | jq -M 'select(.Health == "unhealthy") | .Service' | xargs docker compose logs; exit 1 )
}
function scylla_down() {
echo "==> Stopping Scylla"
docker compose down
}
function scylla_restart() {
scylla_down
scylla_up
}
scylla_restart
sudo chmod 0777 /tmp/scylla/cql.m
readonly clusterSize=1
readonly multiNodeClusterSize=3
readonly scylla_liveset="192.168.100.11"
readonly scylla_tablet_liveset="192.168.100.12"
readonly cversion="3.11.4"
readonly proto=4
readonly args="-gocql.timeout=60s -proto=${proto} -rf=${clusterSize} -clusterSize=${clusterSize} -autowait=2000ms -compressor=snappy -gocql.cversion=${cversion} -cluster=${scylla_liveset}"
readonly tabletArgs="-gocql.timeout=60s -proto=${proto} -rf=1 -clusterSize=${multiNodeClusterSize} -autowait=2000ms -compressor=snappy -gocql.cversion=${cversion} -multiCluster=${scylla_tablet_liveset}"
if [[ "$*" == *"tablet"* ]];
then
echo "==> Running tablet tests with args: ${tabletArgs}"
go test -timeout=5m -race -tags="tablet" ${tabletArgs} ./...
fi
TAGS=$*
TAGS=${TAGS//"tablet"/}
if [ ! -z "$TAGS" ];
then
echo "==> Running ${TAGS} tests with args: ${args}"
go test -timeout=5m -race -tags="$TAGS" ${args} ./...
fi

127
vendor/github.com/gocql/gocql/internal/lru/lru.go generated vendored Normal file
View File

@@ -0,0 +1,127 @@
/*
Copyright 2015 To gocql authors
Copyright 2013 Google Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Package lru implements an LRU cache.
package lru
import "container/list"
// Cache is an LRU cache. It is not safe for concurrent access.
//
// This cache has been forked from github.com/golang/groupcache/lru, but
// specialized with string keys to avoid the allocations caused by wrapping them
// in interface{}.
type Cache struct {
// MaxEntries is the maximum number of cache entries before
// an item is evicted. Zero means no limit.
MaxEntries int
// OnEvicted optionally specifies a callback function to be
// executed when an entry is purged from the cache.
OnEvicted func(key string, value interface{})
ll *list.List
cache map[string]*list.Element
}
type entry struct {
key string
value interface{}
}
// New creates a new Cache.
// If maxEntries is zero, the cache has no limit and it's assumed
// that eviction is done by the caller.
func New(maxEntries int) *Cache {
return &Cache{
MaxEntries: maxEntries,
ll: list.New(),
cache: make(map[string]*list.Element),
}
}
// Add adds a value to the cache.
func (c *Cache) Add(key string, value interface{}) {
if c.cache == nil {
c.cache = make(map[string]*list.Element)
c.ll = list.New()
}
if ee, ok := c.cache[key]; ok {
c.ll.MoveToFront(ee)
ee.Value.(*entry).value = value
return
}
ele := c.ll.PushFront(&entry{key, value})
c.cache[key] = ele
if c.MaxEntries != 0 && c.ll.Len() > c.MaxEntries {
c.RemoveOldest()
}
}
// Get looks up a key's value from the cache.
func (c *Cache) Get(key string) (value interface{}, ok bool) {
if c.cache == nil {
return
}
if ele, hit := c.cache[key]; hit {
c.ll.MoveToFront(ele)
return ele.Value.(*entry).value, true
}
return
}
// Remove removes the provided key from the cache.
func (c *Cache) Remove(key string) bool {
if c.cache == nil {
return false
}
if ele, hit := c.cache[key]; hit {
c.removeElement(ele)
return true
}
return false
}
// RemoveOldest removes the oldest item from the cache.
func (c *Cache) RemoveOldest() {
if c.cache == nil {
return
}
ele := c.ll.Back()
if ele != nil {
c.removeElement(ele)
}
}
func (c *Cache) removeElement(e *list.Element) {
c.ll.Remove(e)
kv := e.Value.(*entry)
delete(c.cache, kv.key)
if c.OnEvicted != nil {
c.OnEvicted(kv.key, kv.value)
}
}
// Len returns the number of items in the cache.
func (c *Cache) Len() int {
if c.cache == nil {
return 0
}
return c.ll.Len()
}

135
vendor/github.com/gocql/gocql/internal/murmur/murmur.go generated vendored Normal file
View File

@@ -0,0 +1,135 @@
package murmur
const (
c1 int64 = -8663945395140668459 // 0x87c37b91114253d5
c2 int64 = 5545529020109919103 // 0x4cf5ad432745937f
fmix1 int64 = -49064778989728563 // 0xff51afd7ed558ccd
fmix2 int64 = -4265267296055464877 // 0xc4ceb9fe1a85ec53
)
func fmix(n int64) int64 {
// cast to unsigned for logical right bitshift (to match C* MM3 implementation)
n ^= int64(uint64(n) >> 33)
n *= fmix1
n ^= int64(uint64(n) >> 33)
n *= fmix2
n ^= int64(uint64(n) >> 33)
return n
}
func block(p byte) int64 {
return int64(int8(p))
}
func rotl(x int64, r uint8) int64 {
// cast to unsigned for logical right bitshift (to match C* MM3 implementation)
return (x << r) | (int64)((uint64(x) >> (64 - r)))
}
func Murmur3H1(data []byte) int64 {
length := len(data)
var h1, h2, k1, k2 int64
// body
nBlocks := length / 16
for i := 0; i < nBlocks; i++ {
k1, k2 = getBlock(data, i)
k1 *= c1
k1 = rotl(k1, 31)
k1 *= c2
h1 ^= k1
h1 = rotl(h1, 27)
h1 += h2
h1 = h1*5 + 0x52dce729
k2 *= c2
k2 = rotl(k2, 33)
k2 *= c1
h2 ^= k2
h2 = rotl(h2, 31)
h2 += h1
h2 = h2*5 + 0x38495ab5
}
// tail
tail := data[nBlocks*16:]
k1 = 0
k2 = 0
switch length & 15 {
case 15:
k2 ^= block(tail[14]) << 48
fallthrough
case 14:
k2 ^= block(tail[13]) << 40
fallthrough
case 13:
k2 ^= block(tail[12]) << 32
fallthrough
case 12:
k2 ^= block(tail[11]) << 24
fallthrough
case 11:
k2 ^= block(tail[10]) << 16
fallthrough
case 10:
k2 ^= block(tail[9]) << 8
fallthrough
case 9:
k2 ^= block(tail[8])
k2 *= c2
k2 = rotl(k2, 33)
k2 *= c1
h2 ^= k2
fallthrough
case 8:
k1 ^= block(tail[7]) << 56
fallthrough
case 7:
k1 ^= block(tail[6]) << 48
fallthrough
case 6:
k1 ^= block(tail[5]) << 40
fallthrough
case 5:
k1 ^= block(tail[4]) << 32
fallthrough
case 4:
k1 ^= block(tail[3]) << 24
fallthrough
case 3:
k1 ^= block(tail[2]) << 16
fallthrough
case 2:
k1 ^= block(tail[1]) << 8
fallthrough
case 1:
k1 ^= block(tail[0])
k1 *= c1
k1 = rotl(k1, 31)
k1 *= c2
h1 ^= k1
}
h1 ^= int64(length)
h2 ^= int64(length)
h1 += h2
h2 += h1
h1 = fmix(h1)
h2 = fmix(h2)
h1 += h2
// the following is extraneous since h2 is discarded
// h2 += h1
return h1
}

View File

@@ -0,0 +1,12 @@
//go:build appengine || s390x
// +build appengine s390x
package murmur
import "encoding/binary"
func getBlock(data []byte, n int) (int64, int64) {
k1 := int64(binary.LittleEndian.Uint64(data[n*16:]))
k2 := int64(binary.LittleEndian.Uint64(data[(n*16)+8:]))
return k1, k2
}

View File

@@ -0,0 +1,16 @@
//go:build !appengine && !s390x
// +build !appengine,!s390x
package murmur
import (
"unsafe"
)
func getBlock(data []byte, n int) (int64, int64) {
block := (*[2]int64)(unsafe.Pointer(&data[n*16]))
k1 := block[0]
k2 := block[1]
return k1, k2
}

View File

@@ -0,0 +1,146 @@
package streams
import (
"math"
"strconv"
"sync/atomic"
)
const bucketBits = 64
// IDGenerator tracks and allocates streams which are in use.
type IDGenerator struct {
NumStreams int
inuseStreams int32
numBuckets uint32
// streams is a bitset where each bit represents a stream, a 1 implies in use
streams []uint64
offset uint32
}
func New(protocol int) *IDGenerator {
maxStreams := 128
if protocol > 2 {
maxStreams = 32768
}
return NewLimited(maxStreams)
}
func NewLimited(maxStreams int) *IDGenerator {
// Round up maxStreams to a nearest
// multiple of 64
maxStreams = ((maxStreams + 63) / 64) * 64
buckets := maxStreams / 64
// reserve stream 0
streams := make([]uint64, buckets)
streams[0] = 1 << 63
return &IDGenerator{
NumStreams: maxStreams,
streams: streams,
numBuckets: uint32(buckets),
offset: uint32(buckets) - 1,
}
}
func streamFromBucket(bucket, streamInBucket int) int {
return (bucket * bucketBits) + streamInBucket
}
func (s *IDGenerator) GetStream() (int, bool) {
// Reduce collisions by offsetting the starting point
offset := atomic.AddUint32(&s.offset, 1)
for i := uint32(0); i < s.numBuckets; i++ {
pos := int((i + offset) % s.numBuckets)
bucket := atomic.LoadUint64(&s.streams[pos])
if bucket == math.MaxUint64 {
// all streams in use
continue
}
for j := 0; j < bucketBits; j++ {
mask := uint64(1 << streamOffset(j))
for bucket&mask == 0 {
if atomic.CompareAndSwapUint64(&s.streams[pos], bucket, bucket|mask) {
atomic.AddInt32(&s.inuseStreams, 1)
return streamFromBucket(int(pos), j), true
}
bucket = atomic.LoadUint64(&s.streams[pos])
}
}
}
return 0, false
}
func bitfmt(b uint64) string {
return strconv.FormatUint(b, 16)
}
// returns the bucket offset of a given stream
func bucketOffset(i int) int {
return i / bucketBits
}
func streamOffset(stream int) uint64 {
return bucketBits - uint64(stream%bucketBits) - 1
}
func isSet(bits uint64, stream int) bool {
return bits>>streamOffset(stream)&1 == 1
}
func (s *IDGenerator) isSet(stream int) bool {
bits := atomic.LoadUint64(&s.streams[bucketOffset(stream)])
return isSet(bits, stream)
}
func (s *IDGenerator) String() string {
size := s.numBuckets * (bucketBits + 1)
buf := make([]byte, 0, size)
for i := 0; i < int(s.numBuckets); i++ {
bits := atomic.LoadUint64(&s.streams[i])
buf = append(buf, bitfmt(bits)...)
buf = append(buf, ' ')
}
return string(buf[: size-1 : size-1])
}
func (s *IDGenerator) Clear(stream int) (inuse bool) {
offset := bucketOffset(stream)
bucket := atomic.LoadUint64(&s.streams[offset])
mask := uint64(1) << streamOffset(stream)
if bucket&mask != mask {
// already cleared
return false
}
for !atomic.CompareAndSwapUint64(&s.streams[offset], bucket, bucket & ^mask) {
bucket = atomic.LoadUint64(&s.streams[offset])
if bucket&mask != mask {
// already cleared
return false
}
}
// TODO: make this account for 0 stream being reserved
if atomic.AddInt32(&s.inuseStreams, -1) < 0 {
// TODO(zariel): remove this
panic("negative streams inuse")
}
return true
}
func (s *IDGenerator) Available() int {
return s.NumStreams - int(atomic.LoadInt32(&s.inuseStreams)) - 1
}
func (s *IDGenerator) InUse() int {
return int(atomic.LoadInt32(&s.inuseStreams))
}

40
vendor/github.com/gocql/gocql/logger.go generated vendored Normal file
View File

@@ -0,0 +1,40 @@
package gocql
import (
"bytes"
"fmt"
"log"
)
type StdLogger interface {
Print(v ...interface{})
Printf(format string, v ...interface{})
Println(v ...interface{})
}
type nopLogger struct{}
func (n nopLogger) Print(_ ...interface{}) {}
func (n nopLogger) Printf(_ string, _ ...interface{}) {}
func (n nopLogger) Println(_ ...interface{}) {}
type testLogger struct {
capture bytes.Buffer
}
func (l *testLogger) Print(v ...interface{}) { fmt.Fprint(&l.capture, v...) }
func (l *testLogger) Printf(format string, v ...interface{}) { fmt.Fprintf(&l.capture, format, v...) }
func (l *testLogger) Println(v ...interface{}) { fmt.Fprintln(&l.capture, v...) }
func (l *testLogger) String() string { return l.capture.String() }
type defaultLogger struct{}
func (l *defaultLogger) Print(v ...interface{}) { log.Print(v...) }
func (l *defaultLogger) Printf(format string, v ...interface{}) { log.Printf(format, v...) }
func (l *defaultLogger) Println(v ...interface{}) { log.Println(v...) }
// Logger for logging messages.
// Deprecated: Use ClusterConfig.Logger instead.
var Logger StdLogger = &defaultLogger{}

1846
vendor/github.com/gocql/gocql/marshal.go generated vendored Normal file

File diff suppressed because it is too large Load Diff

1466
vendor/github.com/gocql/gocql/metadata_cassandra.go generated vendored Normal file

File diff suppressed because it is too large Load Diff

1102
vendor/github.com/gocql/gocql/metadata_scylla.go generated vendored Normal file

File diff suppressed because it is too large Load Diff

1326
vendor/github.com/gocql/gocql/policies.go generated vendored Normal file

File diff suppressed because it is too large Load Diff

78
vendor/github.com/gocql/gocql/prepared_cache.go generated vendored Normal file
View File

@@ -0,0 +1,78 @@
package gocql
import (
"bytes"
"sync"
"github.com/gocql/gocql/internal/lru"
)
const defaultMaxPreparedStmts = 1000
// preparedLRU is the prepared statement cache
type preparedLRU struct {
mu sync.Mutex
lru *lru.Cache
}
func (p *preparedLRU) clear() {
p.mu.Lock()
defer p.mu.Unlock()
for p.lru.Len() > 0 {
p.lru.RemoveOldest()
}
}
func (p *preparedLRU) add(key string, val *inflightPrepare) {
p.mu.Lock()
defer p.mu.Unlock()
p.lru.Add(key, val)
}
func (p *preparedLRU) remove(key string) bool {
p.mu.Lock()
defer p.mu.Unlock()
return p.lru.Remove(key)
}
func (p *preparedLRU) execIfMissing(key string, fn func(lru *lru.Cache) *inflightPrepare) (*inflightPrepare, bool) {
p.mu.Lock()
defer p.mu.Unlock()
val, ok := p.lru.Get(key)
if ok {
return val.(*inflightPrepare), true
}
return fn(p.lru), false
}
func (p *preparedLRU) keyFor(hostID, keyspace, statement string) string {
// TODO: we should just use a struct for the key in the map
return hostID + keyspace + statement
}
func (p *preparedLRU) evictPreparedID(key string, id []byte) {
p.mu.Lock()
defer p.mu.Unlock()
val, ok := p.lru.Get(key)
if !ok {
return
}
ifp, ok := val.(*inflightPrepare)
if !ok {
return
}
select {
case <-ifp.done:
if bytes.Equal(id, ifp.preparedStatment.id) {
p.lru.Remove(key)
}
default:
}
}

238
vendor/github.com/gocql/gocql/query_executor.go generated vendored Normal file
View File

@@ -0,0 +1,238 @@
package gocql
import (
"context"
"errors"
"sync"
"time"
)
type ExecutableQuery interface {
borrowForExecution() // Used to ensure that the query stays alive for lifetime of a particular execution goroutine.
releaseAfterExecution() // Used when a goroutine finishes its execution attempts, either with ok result or an error.
execute(ctx context.Context, conn *Conn) *Iter
attempt(keyspace string, end, start time.Time, iter *Iter, host *HostInfo)
retryPolicy() RetryPolicy
speculativeExecutionPolicy() SpeculativeExecutionPolicy
GetRoutingKey() ([]byte, error)
Keyspace() string
Table() string
IsIdempotent() bool
IsLWT() bool
GetCustomPartitioner() Partitioner
withContext(context.Context) ExecutableQuery
RetryableQuery
GetSession() *Session
}
type queryExecutor struct {
pool *policyConnPool
policy HostSelectionPolicy
}
func (q *queryExecutor) attemptQuery(ctx context.Context, qry ExecutableQuery, conn *Conn) *Iter {
start := time.Now()
iter := qry.execute(ctx, conn)
end := time.Now()
qry.attempt(q.pool.keyspace, end, start, iter, conn.host)
return iter
}
func (q *queryExecutor) speculate(ctx context.Context, qry ExecutableQuery, sp SpeculativeExecutionPolicy,
hostIter NextHost, results chan *Iter) *Iter {
ticker := time.NewTicker(sp.Delay())
defer ticker.Stop()
for i := 0; i < sp.Attempts(); i++ {
select {
case <-ticker.C:
qry.borrowForExecution() // ensure liveness in case of executing Query to prevent races with Query.Release().
go q.run(ctx, qry, hostIter, results)
case <-ctx.Done():
return &Iter{err: ctx.Err()}
case iter := <-results:
return iter
}
}
return nil
}
func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) {
hostIter := q.policy.Pick(qry)
// check if the query is not marked as idempotent, if
// it is, we force the policy to NonSpeculative
sp := qry.speculativeExecutionPolicy()
if !qry.IsIdempotent() || sp.Attempts() == 0 {
return q.do(qry.Context(), qry, hostIter), nil
}
// When speculative execution is enabled, we could be accessing the host iterator from multiple goroutines below.
// To ensure we don't call it concurrently, we wrap the returned NextHost function here to synchronize access to it.
var mu sync.Mutex
origHostIter := hostIter
hostIter = func() SelectedHost {
mu.Lock()
defer mu.Unlock()
return origHostIter()
}
ctx, cancel := context.WithCancel(qry.Context())
defer cancel()
results := make(chan *Iter, 1)
// Launch the main execution
qry.borrowForExecution() // ensure liveness in case of executing Query to prevent races with Query.Release().
go q.run(ctx, qry, hostIter, results)
// The speculative executions are launched _in addition_ to the main
// execution, on a timer. So Speculation{2} would make 3 executions running
// in total.
if iter := q.speculate(ctx, qry, sp, hostIter, results); iter != nil {
return iter, nil
}
select {
case iter := <-results:
return iter, nil
case <-ctx.Done():
return &Iter{err: ctx.Err()}, nil
}
}
func (q *queryExecutor) do(ctx context.Context, qry ExecutableQuery, hostIter NextHost) *Iter {
rt := qry.retryPolicy()
if rt == nil {
rt = &SimpleRetryPolicy{3}
}
lwtRT, isRTSupportsLWT := rt.(LWTRetryPolicy)
var getShouldRetry func(qry RetryableQuery) bool
var getRetryType func(error) RetryType
if isRTSupportsLWT && qry.IsLWT() {
getShouldRetry = lwtRT.AttemptLWT
getRetryType = lwtRT.GetRetryTypeLWT
} else {
getShouldRetry = rt.Attempt
getRetryType = rt.GetRetryType
}
var potentiallyExecuted bool
execute := func(qry ExecutableQuery, selectedHost SelectedHost) (iter *Iter, retry RetryType) {
host := selectedHost.Info()
if host == nil || !host.IsUp() {
return &Iter{
err: &QueryError{
err: ErrHostDown,
potentiallyExecuted: potentiallyExecuted,
},
}, RetryNextHost
}
pool, ok := q.pool.getPool(host)
if !ok {
return &Iter{
err: &QueryError{
err: ErrNoPool,
potentiallyExecuted: potentiallyExecuted,
},
}, RetryNextHost
}
conn := pool.Pick(selectedHost.Token(), qry)
if conn == nil {
return &Iter{
err: &QueryError{
err: ErrNoConnectionsInPool,
potentiallyExecuted: potentiallyExecuted,
},
}, RetryNextHost
}
iter = q.attemptQuery(ctx, qry, conn)
iter.host = selectedHost.Info()
// Update host
if iter.err == nil {
return iter, RetryType(255)
}
switch {
case errors.Is(iter.err, context.Canceled),
errors.Is(iter.err, context.DeadlineExceeded):
selectedHost.Mark(nil)
potentiallyExecuted = true
retry = Rethrow
default:
selectedHost.Mark(iter.err)
retry = RetryType(255) // Don't enforce retry and get it from retry policy
}
var qErr *QueryError
if errors.As(iter.err, &qErr) {
potentiallyExecuted = potentiallyExecuted && qErr.PotentiallyExecuted()
qErr.potentiallyExecuted = potentiallyExecuted
qErr.isIdempotent = qry.IsIdempotent()
iter.err = qErr
} else {
iter.err = &QueryError{
err: iter.err,
potentiallyExecuted: potentiallyExecuted,
isIdempotent: qry.IsIdempotent(),
}
}
return iter, retry
}
var lastErr error
selectedHost := hostIter()
for selectedHost != nil {
iter, retryType := execute(qry, selectedHost)
if iter.err == nil {
return iter
}
lastErr = iter.err
// Exit if retry policy decides to not retry anymore
if retryType == RetryType(255) {
if !getShouldRetry(qry) {
return iter
}
retryType = getRetryType(iter.err)
}
// If query is unsuccessful, check the error with RetryPolicy to retry
switch retryType {
case Retry:
// retry on the same host
continue
case Rethrow, Ignore:
return iter
case RetryNextHost:
// retry on the next host
selectedHost = hostIter()
continue
default:
// Undefined? Return nil and error, this will panic in the requester
return &Iter{err: ErrUnknownRetryType}
}
}
if lastErr != nil {
return &Iter{err: lastErr}
}
return &Iter{err: ErrNoConnections}
}
func (q *queryExecutor) run(ctx context.Context, qry ExecutableQuery, hostIter NextHost, results chan<- *Iter) {
select {
case results <- q.do(ctx, qry, hostIter):
case <-ctx.Done():
}
qry.releaseAfterExecution()
}

544
vendor/github.com/gocql/gocql/recreate.go generated vendored Normal file
View File

@@ -0,0 +1,544 @@
//go:build !cassandra
// +build !cassandra
// Copyright (C) 2017 ScyllaDB
package gocql
import (
"encoding/binary"
"encoding/json"
"fmt"
"io"
"sort"
"strconv"
"strings"
"text/template"
)
// ToCQL returns a CQL query that ca be used to recreate keyspace with all
// user defined types, tables, indexes, functions, aggregates and views associated
// with this keyspace.
func (km *KeyspaceMetadata) ToCQL() (string, error) {
// Be aware that `CreateStmts` is not only a cache for ToCQL,
// but it also can be populated from response to `DESCRIBE KEYSPACE %s WITH INTERNALS`
if len(km.CreateStmts) != 0 {
return km.CreateStmts, nil
}
var sb strings.Builder
if err := km.keyspaceToCQL(&sb); err != nil {
return "", err
}
sortedTypes := km.typesSortedTopologically()
for _, tm := range sortedTypes {
if err := km.userTypeToCQL(&sb, tm); err != nil {
return "", err
}
}
for _, tm := range km.Tables {
if err := km.tableToCQL(&sb, km.Name, tm); err != nil {
return "", err
}
}
for _, im := range km.Indexes {
if err := km.indexToCQL(&sb, im); err != nil {
return "", err
}
}
for _, fm := range km.Functions {
if err := km.functionToCQL(&sb, km.Name, fm); err != nil {
return "", err
}
}
for _, am := range km.Aggregates {
if err := km.aggregateToCQL(&sb, am); err != nil {
return "", err
}
}
for _, vm := range km.Views {
if err := km.viewToCQL(&sb, vm); err != nil {
return "", err
}
}
km.CreateStmts = sb.String()
return km.CreateStmts, nil
}
func (km *KeyspaceMetadata) typesSortedTopologically() []*TypeMetadata {
sortedTypes := make([]*TypeMetadata, 0, len(km.Types))
for _, tm := range km.Types {
sortedTypes = append(sortedTypes, tm)
}
sort.Slice(sortedTypes, func(i, j int) bool {
for _, ft := range sortedTypes[j].FieldTypes {
if strings.Contains(ft, sortedTypes[i].Name) {
return true
}
}
return false
})
return sortedTypes
}
var tableCQLTemplate = template.Must(template.New("table").
Funcs(map[string]interface{}{
"escape": cqlHelpers.escape,
"tableColumnToCQL": cqlHelpers.tableColumnToCQL,
"tablePropertiesToCQL": cqlHelpers.tablePropertiesToCQL,
}).
Parse(`
CREATE TABLE {{ .KeyspaceName }}.{{ .Tm.Name }} (
{{ tableColumnToCQL .Tm }}
) WITH {{ tablePropertiesToCQL .Tm.ClusteringColumns .Tm.Options .Tm.Flags .Tm.Extensions }};
`))
func (km *KeyspaceMetadata) tableToCQL(w io.Writer, kn string, tm *TableMetadata) error {
if err := tableCQLTemplate.Execute(w, map[string]interface{}{
"Tm": tm,
"KeyspaceName": kn,
}); err != nil {
return err
}
return nil
}
var functionTemplate = template.Must(template.New("functions").
Funcs(map[string]interface{}{
"escape": cqlHelpers.escape,
"zip": cqlHelpers.zip,
"stripFrozen": cqlHelpers.stripFrozen,
}).
Parse(`
CREATE FUNCTION {{ .keyspaceName }}.{{ .fm.Name }} (
{{- range $i, $args := zip .fm.ArgumentNames .fm.ArgumentTypes }}
{{- if ne $i 0 }}, {{ end }}
{{- (index $args 0) }}
{{ stripFrozen (index $args 1) }}
{{- end -}})
{{ if .fm.CalledOnNullInput }}CALLED{{ else }}RETURNS NULL{{ end }} ON NULL INPUT
RETURNS {{ .fm.ReturnType }}
LANGUAGE {{ .fm.Language }}
AS $${{ .fm.Body }}$$;
`))
func (km *KeyspaceMetadata) functionToCQL(w io.Writer, keyspaceName string, fm *FunctionMetadata) error {
if err := functionTemplate.Execute(w, map[string]interface{}{
"fm": fm,
"keyspaceName": keyspaceName,
}); err != nil {
return err
}
return nil
}
var viewTemplate = template.Must(template.New("views").
Funcs(map[string]interface{}{
"zip": cqlHelpers.zip,
"partitionKeyString": cqlHelpers.partitionKeyString,
"tablePropertiesToCQL": cqlHelpers.tablePropertiesToCQL,
}).
Parse(`
CREATE MATERIALIZED VIEW {{ .vm.KeyspaceName }}.{{ .vm.ViewName }} AS
SELECT {{ if .vm.IncludeAllColumns }}*{{ else }}
{{- range $i, $col := .vm.OrderedColumns }}
{{- if ne $i 0 }}, {{ end }}
{{ $col }}
{{- end }}
{{- end }}
FROM {{ .vm.KeyspaceName }}.{{ .vm.BaseTableName }}
WHERE {{ .vm.WhereClause }}
PRIMARY KEY ({{ partitionKeyString .vm.PartitionKey .vm.ClusteringColumns }})
WITH {{ tablePropertiesToCQL .vm.ClusteringColumns .vm.Options .flags .vm.Extensions }};
`))
func (km *KeyspaceMetadata) viewToCQL(w io.Writer, vm *ViewMetadata) error {
if err := viewTemplate.Execute(w, map[string]interface{}{
"vm": vm,
"flags": []string{},
}); err != nil {
return err
}
return nil
}
var aggregatesTemplate = template.Must(template.New("aggregate").
Funcs(map[string]interface{}{
"stripFrozen": cqlHelpers.stripFrozen,
}).
Parse(`
CREATE AGGREGATE {{ .Keyspace }}.{{ .Name }}(
{{- range $i, $arg := .ArgumentTypes }}
{{- if ne $i 0 }}, {{ end }}
{{ stripFrozen $arg }}
{{- end -}})
SFUNC {{ .StateFunc.Name }}
STYPE {{ stripFrozen .StateType }}
{{- if ne .FinalFunc.Name "" }}
FINALFUNC {{ .FinalFunc.Name }}
{{- end -}}
{{- if ne .InitCond "" }}
INITCOND {{ .InitCond }}
{{- end -}}
;
`))
func (km *KeyspaceMetadata) aggregateToCQL(w io.Writer, am *AggregateMetadata) error {
if err := aggregatesTemplate.Execute(w, am); err != nil {
return err
}
return nil
}
var typeCQLTemplate = template.Must(template.New("types").
Funcs(map[string]interface{}{
"zip": cqlHelpers.zip,
}).
Parse(`
CREATE TYPE {{ .Keyspace }}.{{ .Name }} (
{{- range $i, $fields := zip .FieldNames .FieldTypes }} {{- if ne $i 0 }},{{ end }}
{{ index $fields 0 }} {{ index $fields 1 }}
{{- end }}
);
`))
func (km *KeyspaceMetadata) userTypeToCQL(w io.Writer, tm *TypeMetadata) error {
if err := typeCQLTemplate.Execute(w, tm); err != nil {
return err
}
return nil
}
func (km *KeyspaceMetadata) indexToCQL(w io.Writer, im *IndexMetadata) error {
// Scylla doesn't support any custom indexes
if im.Kind == IndexKindCustom {
return nil
}
options := im.Options
indexTarget := options["target"]
// secondary index
si := struct {
ClusteringKeys []string `json:"ck"`
PartitionKeys []string `json:"pk"`
}{}
if err := json.Unmarshal([]byte(indexTarget), &si); err == nil {
indexTarget = fmt.Sprintf("(%s), %s",
strings.Join(si.PartitionKeys, ","),
strings.Join(si.ClusteringKeys, ","),
)
}
_, err := fmt.Fprintf(w, "\nCREATE INDEX %s ON %s.%s (%s);\n",
im.Name,
im.KeyspaceName,
im.TableName,
indexTarget,
)
if err != nil {
return err
}
return nil
}
var keyspaceCQLTemplate = template.Must(template.New("keyspace").
Funcs(map[string]interface{}{
"escape": cqlHelpers.escape,
"fixStrategy": cqlHelpers.fixStrategy,
}).
Parse(`CREATE KEYSPACE {{ .Name }} WITH replication = {
'class': {{ escape ( fixStrategy .StrategyClass) }}
{{- range $key, $value := .StrategyOptions }},
{{ escape $key }}: {{ escape $value }}
{{- end }}
}{{ if not .DurableWrites }} AND durable_writes = 'false'{{ end }};
`))
func (km *KeyspaceMetadata) keyspaceToCQL(w io.Writer) error {
if err := keyspaceCQLTemplate.Execute(w, km); err != nil {
return err
}
return nil
}
func contains(in []string, v string) bool {
for _, e := range in {
if e == v {
return true
}
}
return false
}
type toCQLHelpers struct{}
var cqlHelpers = toCQLHelpers{}
func (h toCQLHelpers) zip(a []string, b []string) [][]string {
m := make([][]string, len(a))
for i := range a {
m[i] = []string{a[i], b[i]}
}
return m
}
func (h toCQLHelpers) escape(e interface{}) string {
switch v := e.(type) {
case int, float64:
return fmt.Sprint(v)
case bool:
if v {
return "true"
}
return "false"
case string:
return "'" + strings.ReplaceAll(v, "'", "''") + "'"
case []byte:
return string(v)
}
return ""
}
func (h toCQLHelpers) stripFrozen(v string) string {
return strings.TrimSuffix(strings.TrimPrefix(v, "frozen<"), ">")
}
func (h toCQLHelpers) fixStrategy(v string) string {
return strings.TrimPrefix(v, "org.apache.cassandra.locator.")
}
func (h toCQLHelpers) fixQuote(v string) string {
return strings.ReplaceAll(v, `"`, `'`)
}
func (h toCQLHelpers) tableOptionsToCQL(ops TableMetadataOptions) ([]string, error) {
opts := map[string]interface{}{
"bloom_filter_fp_chance": ops.BloomFilterFpChance,
"comment": ops.Comment,
"crc_check_chance": ops.CrcCheckChance,
"default_time_to_live": ops.DefaultTimeToLive,
"gc_grace_seconds": ops.GcGraceSeconds,
"max_index_interval": ops.MaxIndexInterval,
"memtable_flush_period_in_ms": ops.MemtableFlushPeriodInMs,
"min_index_interval": ops.MinIndexInterval,
"speculative_retry": ops.SpeculativeRetry,
}
var err error
opts["caching"], err = json.Marshal(ops.Caching)
if err != nil {
return nil, err
}
opts["compaction"], err = json.Marshal(ops.Compaction)
if err != nil {
return nil, err
}
opts["compression"], err = json.Marshal(ops.Compression)
if err != nil {
return nil, err
}
cdc, err := json.Marshal(ops.CDC)
if err != nil {
return nil, err
}
if string(cdc) != "null" {
opts["cdc"] = cdc
}
if ops.InMemory {
opts["in_memory"] = ops.InMemory
}
out := make([]string, 0, len(opts))
for key, opt := range opts {
out = append(out, fmt.Sprintf("%s = %s", key, h.fixQuote(h.escape(opt))))
}
sort.Strings(out)
return out, nil
}
func (h toCQLHelpers) tableExtensionsToCQL(extensions map[string]interface{}) ([]string, error) {
exts := map[string]interface{}{}
if blob, ok := extensions["scylla_encryption_options"]; ok {
encOpts := &scyllaEncryptionOptions{}
if err := encOpts.UnmarshalBinary(blob.([]byte)); err != nil {
return nil, err
}
var err error
exts["scylla_encryption_options"], err = json.Marshal(encOpts)
if err != nil {
return nil, err
}
}
out := make([]string, 0, len(exts))
for key, ext := range exts {
out = append(out, fmt.Sprintf("%s = %s", key, h.fixQuote(h.escape(ext))))
}
sort.Strings(out)
return out, nil
}
func (h toCQLHelpers) tablePropertiesToCQL(cks []*ColumnMetadata, opts TableMetadataOptions,
flags []string, extensions map[string]interface{}) (string, error) {
var sb strings.Builder
var properties []string
compactStorage := len(flags) > 0 && (contains(flags, TableFlagDense) ||
contains(flags, TableFlagSuper) ||
!contains(flags, TableFlagCompound))
if compactStorage {
properties = append(properties, "COMPACT STORAGE")
}
if len(cks) > 0 {
var inner []string
for _, col := range cks {
inner = append(inner, fmt.Sprintf("%s %s", col.Name, col.ClusteringOrder))
}
properties = append(properties, fmt.Sprintf("CLUSTERING ORDER BY (%s)", strings.Join(inner, ", ")))
}
options, err := h.tableOptionsToCQL(opts)
if err != nil {
return "", err
}
properties = append(properties, options...)
exts, err := h.tableExtensionsToCQL(extensions)
if err != nil {
return "", err
}
properties = append(properties, exts...)
sb.WriteString(strings.Join(properties, "\n AND "))
return sb.String(), nil
}
func (h toCQLHelpers) tableColumnToCQL(tm *TableMetadata) string {
var sb strings.Builder
var columns []string
for _, cn := range tm.OrderedColumns {
cm := tm.Columns[cn]
column := fmt.Sprintf("%s %s", cn, cm.Type)
if cm.Kind == ColumnStatic {
column += " static"
}
columns = append(columns, column)
}
if len(tm.PartitionKey) == 1 && len(tm.ClusteringColumns) == 0 && len(columns) > 0 {
columns[0] += " PRIMARY KEY"
}
sb.WriteString(strings.Join(columns, ",\n "))
if len(tm.PartitionKey) > 1 || len(tm.ClusteringColumns) > 0 {
sb.WriteString(",\n PRIMARY KEY (")
sb.WriteString(h.partitionKeyString(tm.PartitionKey, tm.ClusteringColumns))
sb.WriteRune(')')
}
return sb.String()
}
func (h toCQLHelpers) partitionKeyString(pks, cks []*ColumnMetadata) string {
var sb strings.Builder
if len(pks) > 1 {
sb.WriteRune('(')
for i, pk := range pks {
if i != 0 {
sb.WriteString(", ")
}
sb.WriteString(pk.Name)
}
sb.WriteRune(')')
} else {
sb.WriteString(pks[0].Name)
}
if len(cks) > 0 {
sb.WriteString(", ")
for i, ck := range cks {
if i != 0 {
sb.WriteString(", ")
}
sb.WriteString(ck.Name)
}
}
return sb.String()
}
type scyllaEncryptionOptions struct {
CipherAlgorithm string `json:"cipher_algorithm"`
SecretKeyStrength int `json:"secret_key_strength"`
KeyProvider string `json:"key_provider"`
SecretKeyFile string `json:"secret_key_file"`
}
// UnmarshalBinary deserializes blob into scyllaEncryptionOptions.
// Format:
// - 4 bytes - size of KV map
// Size times:
// - 4 bytes - length of key
// - len_of_key bytes - key
// - 4 bytes - length of value
// - len_of_value bytes - value
func (enc *scyllaEncryptionOptions) UnmarshalBinary(data []byte) error {
size := binary.LittleEndian.Uint32(data[0:4])
m := make(map[string]string, size)
off := uint32(4)
for i := uint32(0); i < size; i++ {
keyLen := binary.LittleEndian.Uint32(data[off : off+4])
off += 4
key := string(data[off : off+keyLen])
off += keyLen
valueLen := binary.LittleEndian.Uint32(data[off : off+4])
off += 4
value := string(data[off : off+valueLen])
off += valueLen
m[key] = value
}
enc.CipherAlgorithm = m["cipher_algorithm"]
enc.KeyProvider = m["key_provider"]
enc.SecretKeyFile = m["secret_key_file"]
if secretKeyStrength, ok := m["secret_key_strength"]; ok {
sks, err := strconv.Atoi(secretKeyStrength)
if err != nil {
return err
}
enc.SecretKeyStrength = sks
}
return nil
}

275
vendor/github.com/gocql/gocql/ring_describer.go generated vendored Normal file
View File

@@ -0,0 +1,275 @@
package gocql
import (
"context"
"errors"
"fmt"
"net"
"sync"
)
// Polls system.peers at a specific interval to find new hosts
type ringDescriber struct {
control controlConnection
cfg *ClusterConfig
logger StdLogger
mu sync.RWMutex
prevHosts []*HostInfo
prevPartitioner string
// hosts are the set of all hosts in the cassandra ring that we know of.
// key of map is host_id.
hosts map[string]*HostInfo
// hostIPToUUID maps host native address to host_id.
hostIPToUUID map[string]string
}
func (r *ringDescriber) setControlConn(c controlConnection) {
r.mu.Lock()
defer r.mu.Unlock()
r.control = c
}
// Ask the control node for the local host info
func (r *ringDescriber) getLocalHostInfo(conn ConnInterface) (*HostInfo, error) {
iter := conn.querySystem(context.TODO(), qrySystemLocal)
if iter == nil {
return nil, errNoControl
}
host, err := hostInfoFromIter(iter, nil, r.cfg.Port, r.cfg.translateAddressPort)
if err != nil {
return nil, fmt.Errorf("could not retrieve local host info: %w", err)
}
return host, nil
}
// Ask the control node for host info on all it's known peers
func (r *ringDescriber) getClusterPeerInfo(localHost *HostInfo, c ConnInterface) ([]*HostInfo, error) {
var iter *Iter
if c.getIsSchemaV2() {
iter = c.querySystem(context.TODO(), qrySystemPeersV2)
} else {
iter = c.querySystem(context.TODO(), qrySystemPeers)
}
if iter == nil {
return nil, errNoControl
}
rows, err := iter.SliceMap()
if err != nil {
// TODO(zariel): make typed error
return nil, fmt.Errorf("unable to fetch peer host info: %s", err)
}
return getPeersFromQuerySystemPeers(rows, r.cfg.Port, r.cfg.translateAddressPort, r.logger)
}
func getPeersFromQuerySystemPeers(querySystemPeerRows []map[string]interface{}, port int, translateAddressPort func(addr net.IP, port int) (net.IP, int), logger StdLogger) ([]*HostInfo, error) {
var peers []*HostInfo
for _, row := range querySystemPeerRows {
// extract all available info about the peer
host, err := hostInfoFromMap(row, &HostInfo{port: port}, translateAddressPort)
if err != nil {
return nil, err
} else if !isValidPeer(host) {
// If it's not a valid peer
logger.Printf("Found invalid peer '%s' "+
"Likely due to a gossip or snitch issue, this host will be ignored", host)
continue
} else if isZeroToken(host) {
continue
}
peers = append(peers, host)
}
return peers, nil
}
// Return true if the host is a valid peer
func isValidPeer(host *HostInfo) bool {
return !(len(host.RPCAddress()) == 0 ||
host.hostId == "" ||
host.dataCenter == "" ||
host.rack == "")
}
func isZeroToken(host *HostInfo) bool {
return len(host.tokens) == 0
}
// GetHostsFromSystem returns a list of hosts found via queries to system.local and system.peers
func (r *ringDescriber) GetHostsFromSystem() ([]*HostInfo, string, error) {
r.mu.Lock()
defer r.mu.Unlock()
if r.control == nil {
return r.prevHosts, r.prevPartitioner, errNoControl
}
ch := r.control.getConn()
localHost, err := r.getLocalHostInfo(ch.conn)
if err != nil {
return r.prevHosts, r.prevPartitioner, err
}
peerHosts, err := r.getClusterPeerInfo(localHost, ch.conn)
if err != nil {
return r.prevHosts, r.prevPartitioner, err
}
var hosts []*HostInfo
if !isZeroToken(localHost) {
hosts = []*HostInfo{localHost}
}
hosts = append(hosts, peerHosts...)
var partitioner string
if len(hosts) > 0 {
partitioner = hosts[0].Partitioner()
}
r.prevHosts = hosts
r.prevPartitioner = partitioner
return hosts, partitioner, nil
}
// Given an ip/port return HostInfo for the specified ip/port
func (r *ringDescriber) getHostInfo(hostID UUID) (*HostInfo, error) {
var host *HostInfo
for _, table := range []string{"system.peers", "system.local"} {
ch := r.control.getConn()
var iter *Iter
if ch.host.HostID() == hostID.String() {
host = ch.host
iter = nil
}
if table == "system.peers" {
if ch.conn.getIsSchemaV2() {
iter = ch.conn.querySystem(context.TODO(), qrySystemPeersV2)
} else {
iter = ch.conn.querySystem(context.TODO(), qrySystemPeers)
}
} else {
iter = ch.conn.query(context.TODO(), fmt.Sprintf("SELECT * FROM %s", table))
}
if iter != nil {
rows, err := iter.SliceMap()
if err != nil {
return nil, err
}
for _, row := range rows {
h, err := hostInfoFromMap(row, &HostInfo{port: r.cfg.Port}, r.cfg.translateAddressPort)
if err != nil {
return nil, err
}
if h.HostID() == hostID.String() {
host = h
break
}
}
}
}
if host == nil {
return nil, errors.New("unable to fetch host info: invalid control connection")
} else if host.invalidConnectAddr() {
return nil, fmt.Errorf("host ConnectAddress invalid ip=%v: %v", host.connectAddress, host)
}
return host, nil
}
func (r *ringDescriber) getHostByIP(ip string) (*HostInfo, bool) {
r.mu.RLock()
defer r.mu.RUnlock()
hi, ok := r.hostIPToUUID[ip]
return r.hosts[hi], ok
}
func (r *ringDescriber) getHost(hostID string) *HostInfo {
r.mu.RLock()
host := r.hosts[hostID]
r.mu.RUnlock()
return host
}
func (r *ringDescriber) getHostsList() []*HostInfo {
r.mu.RLock()
hosts := make([]*HostInfo, 0, len(r.hosts))
for _, host := range r.hosts {
hosts = append(hosts, host)
}
r.mu.RUnlock()
return hosts
}
func (r *ringDescriber) getHostsMap() map[string]*HostInfo {
r.mu.RLock()
hosts := make(map[string]*HostInfo, len(r.hosts))
for k, v := range r.hosts {
hosts[k] = v
}
r.mu.RUnlock()
return hosts
}
func (r *ringDescriber) addOrUpdate(host *HostInfo) *HostInfo {
if existingHost, ok := r.addHostIfMissing(host); ok {
existingHost.update(host)
host = existingHost
}
return host
}
func (r *ringDescriber) addHostIfMissing(host *HostInfo) (*HostInfo, bool) {
if host.invalidConnectAddr() {
panic(fmt.Sprintf("invalid host: %v", host))
}
hostID := host.HostID()
r.mu.Lock()
if r.hosts == nil {
r.hosts = make(map[string]*HostInfo)
}
if r.hostIPToUUID == nil {
r.hostIPToUUID = make(map[string]string)
}
existing, ok := r.hosts[hostID]
if !ok {
r.hosts[hostID] = host
r.hostIPToUUID[host.nodeToNodeAddress().String()] = hostID
existing = host
}
r.mu.Unlock()
return existing, ok
}
func (r *ringDescriber) removeHost(hostID string) bool {
r.mu.Lock()
if r.hosts == nil {
r.hosts = make(map[string]*HostInfo)
}
if r.hostIPToUUID == nil {
r.hostIPToUUID = make(map[string]string)
}
h, ok := r.hosts[hostID]
if ok {
delete(r.hostIPToUUID, h.nodeToNodeAddress().String())
}
delete(r.hosts, hostID)
r.mu.Unlock()
return ok
}

874
vendor/github.com/gocql/gocql/scylla.go generated vendored Normal file
View File

@@ -0,0 +1,874 @@
package gocql
import (
"context"
"crypto/tls"
"errors"
"fmt"
"math"
"net"
"strconv"
"strings"
"sync/atomic"
"syscall"
"time"
)
// scyllaSupported represents Scylla connection options as sent in SUPPORTED
// frame.
// FIXME: Should also follow `cqlProtocolExtension` interface.
type scyllaSupported struct {
shard int
nrShards int
msbIgnore uint64
partitioner string
shardingAlgorithm string
shardAwarePort uint16
shardAwarePortSSL uint16
lwtFlagMask int
}
// CQL Protocol extension interface for Scylla.
// Each extension is identified by a name and defines a way to serialize itself
// in STARTUP message payload.
type cqlProtocolExtension interface {
name() string
serialize() map[string]string
}
func findCQLProtoExtByName(exts []cqlProtocolExtension, name string) cqlProtocolExtension {
for i := range exts {
if exts[i].name() == name {
return exts[i]
}
}
return nil
}
// Top-level keys used for serialization/deserialization of CQL protocol
// extensions in SUPPORTED/STARTUP messages.
// Each key identifies a single extension.
const (
lwtAddMetadataMarkKey = "SCYLLA_LWT_ADD_METADATA_MARK"
rateLimitError = "SCYLLA_RATE_LIMIT_ERROR"
tabletsRoutingV1 = "TABLETS_ROUTING_V1"
)
// "tabletsRoutingV1" CQL Protocol Extension.
// This extension, if enabled (properly negotiated), allows Scylla server
// to send a tablet information in `custom_payload`.
//
// Implements cqlProtocolExtension interface.
type tabletsRoutingV1Ext struct {
}
var _ cqlProtocolExtension = &tabletsRoutingV1Ext{}
// Factory function to deserialize and create an `tabletsRoutingV1Ext` instance
// from SUPPORTED message payload.
func newTabletsRoutingV1Ext(supported map[string][]string) *tabletsRoutingV1Ext {
if _, found := supported[tabletsRoutingV1]; found {
return &tabletsRoutingV1Ext{}
}
return nil
}
func (ext *tabletsRoutingV1Ext) serialize() map[string]string {
return map[string]string{
tabletsRoutingV1: "",
}
}
func (ext *tabletsRoutingV1Ext) name() string {
return tabletsRoutingV1
}
// "Rate limit" CQL Protocol Extension.
// This extension, if enabled (properly negotiated), allows Scylla server
// to send a special kind of error.
//
// Implements cqlProtocolExtension interface.
type rateLimitExt struct {
rateLimitErrorCode int
}
var _ cqlProtocolExtension = &rateLimitExt{}
// Factory function to deserialize and create an `rateLimitExt` instance
// from SUPPORTED message payload.
func newRateLimitExt(supported map[string][]string) *rateLimitExt {
const rateLimitErrorCode = "ERROR_CODE"
if v, found := supported[rateLimitError]; found {
for i := range v {
splitVal := strings.Split(v[i], "=")
if splitVal[0] == rateLimitErrorCode {
var (
err error
errorCode int
)
if errorCode, err = strconv.Atoi(splitVal[1]); err != nil {
if gocqlDebug {
Logger.Printf("scylla: failed to parse %s value %v: %s", rateLimitErrorCode, splitVal[1], err)
return nil
}
}
return &rateLimitExt{
rateLimitErrorCode: errorCode,
}
}
}
}
return nil
}
func (ext *rateLimitExt) serialize() map[string]string {
return map[string]string{
rateLimitError: "",
}
}
func (ext *rateLimitExt) name() string {
return rateLimitError
}
// "LWT prepared statements metadata mark" CQL Protocol Extension.
// This extension, if enabled (properly negotiated), allows Scylla server
// to set a special bit in prepared statements metadata, which would indicate
// whether the statement at hand is LWT statement or not.
//
// This is further used to consistently choose primary replicas in a predefined
// order for these queries, which can reduce contention over hot keys and thus
// increase LWT performance.
//
// Implements cqlProtocolExtension interface.
type lwtAddMetadataMarkExt struct {
lwtOptMetaBitMask int
}
var _ cqlProtocolExtension = &lwtAddMetadataMarkExt{}
// Factory function to deserialize and create an `lwtAddMetadataMarkExt` instance
// from SUPPORTED message payload.
func newLwtAddMetaMarkExt(supported map[string][]string) *lwtAddMetadataMarkExt {
const lwtOptMetaBitMaskKey = "LWT_OPTIMIZATION_META_BIT_MASK"
if v, found := supported[lwtAddMetadataMarkKey]; found {
for i := range v {
splitVal := strings.Split(v[i], "=")
if splitVal[0] == lwtOptMetaBitMaskKey {
var (
err error
bitMask int
)
if bitMask, err = strconv.Atoi(splitVal[1]); err != nil {
if gocqlDebug {
Logger.Printf("scylla: failed to parse %s value %v: %s", lwtOptMetaBitMaskKey, splitVal[1], err)
return nil
}
}
return &lwtAddMetadataMarkExt{
lwtOptMetaBitMask: bitMask,
}
}
}
}
return nil
}
func (ext *lwtAddMetadataMarkExt) serialize() map[string]string {
return map[string]string{
lwtAddMetadataMarkKey: fmt.Sprintf("LWT_OPTIMIZATION_META_BIT_MASK=%d", ext.lwtOptMetaBitMask),
}
}
func (ext *lwtAddMetadataMarkExt) name() string {
return lwtAddMetadataMarkKey
}
func parseSupported(supported map[string][]string) scyllaSupported {
const (
scyllaShard = "SCYLLA_SHARD"
scyllaNrShards = "SCYLLA_NR_SHARDS"
scyllaPartitioner = "SCYLLA_PARTITIONER"
scyllaShardingAlgorithm = "SCYLLA_SHARDING_ALGORITHM"
scyllaShardingIgnoreMSB = "SCYLLA_SHARDING_IGNORE_MSB"
scyllaShardAwarePort = "SCYLLA_SHARD_AWARE_PORT"
scyllaShardAwarePortSSL = "SCYLLA_SHARD_AWARE_PORT_SSL"
)
var (
si scyllaSupported
err error
)
if s, ok := supported[scyllaShard]; ok {
if si.shard, err = strconv.Atoi(s[0]); err != nil {
if gocqlDebug {
Logger.Printf("scylla: failed to parse %s value %v: %s", scyllaShard, s, err)
}
}
}
if s, ok := supported[scyllaNrShards]; ok {
if si.nrShards, err = strconv.Atoi(s[0]); err != nil {
if gocqlDebug {
Logger.Printf("scylla: failed to parse %s value %v: %s", scyllaNrShards, s, err)
}
}
}
if s, ok := supported[scyllaShardingIgnoreMSB]; ok {
if si.msbIgnore, err = strconv.ParseUint(s[0], 10, 64); err != nil {
if gocqlDebug {
Logger.Printf("scylla: failed to parse %s value %v: %s", scyllaShardingIgnoreMSB, s, err)
}
}
}
if s, ok := supported[scyllaPartitioner]; ok {
si.partitioner = s[0]
}
if s, ok := supported[scyllaShardingAlgorithm]; ok {
si.shardingAlgorithm = s[0]
}
if s, ok := supported[scyllaShardAwarePort]; ok {
if shardAwarePort, err := strconv.ParseUint(s[0], 10, 16); err != nil {
if gocqlDebug {
Logger.Printf("scylla: failed to parse %s value %v: %s", scyllaShardAwarePort, s, err)
}
} else {
si.shardAwarePort = uint16(shardAwarePort)
}
}
if s, ok := supported[scyllaShardAwarePortSSL]; ok {
if shardAwarePortSSL, err := strconv.ParseUint(s[0], 10, 16); err != nil {
if gocqlDebug {
Logger.Printf("scylla: failed to parse %s value %v: %s", scyllaShardAwarePortSSL, s, err)
}
} else {
si.shardAwarePortSSL = uint16(shardAwarePortSSL)
}
}
if si.partitioner != "org.apache.cassandra.dht.Murmur3Partitioner" || si.shardingAlgorithm != "biased-token-round-robin" || si.nrShards == 0 || si.msbIgnore == 0 {
if gocqlDebug {
Logger.Printf("scylla: unsupported sharding configuration, partitioner=%s, algorithm=%s, no_shards=%d, msb_ignore=%d",
si.partitioner, si.shardingAlgorithm, si.nrShards, si.msbIgnore)
}
return scyllaSupported{}
}
return si
}
func parseCQLProtocolExtensions(supported map[string][]string) []cqlProtocolExtension {
exts := []cqlProtocolExtension{}
lwtExt := newLwtAddMetaMarkExt(supported)
if lwtExt != nil {
exts = append(exts, lwtExt)
}
rateLimitExt := newRateLimitExt(supported)
if rateLimitExt != nil {
exts = append(exts, rateLimitExt)
}
tabletsExt := newTabletsRoutingV1Ext(supported)
if tabletsExt != nil {
exts = append(exts, tabletsExt)
}
return exts
}
// isScyllaConn checks if conn is suitable for scyllaConnPicker.
func (conn *Conn) isScyllaConn() bool {
return conn.getScyllaSupported().nrShards != 0
}
// scyllaConnPicker is a specialised ConnPicker that selects connections based
// on token trying to get connection to a shard containing the given token.
// A list of excess connections is maintained to allow for lazy closing of
// connections to already opened shards. Keeping excess connections open helps
// reaching equilibrium faster since the likelihood of hitting the same shard
// decreases with the number of connections to the shard.
//
// scyllaConnPicker keeps track of the details about the shard-aware port.
// When used as a Dialer, it connects to the shard-aware port instead of the
// regular port (if the node supports it). For each subsequent connection
// it tries to make, the shard that it aims to connect to is chosen
// in a round-robin fashion.
type scyllaConnPicker struct {
address string
hostId string
shardAwareAddress string
conns []*Conn
excessConns []*Conn
nrConns int
nrShards int
msbIgnore uint64
pos uint64
lastAttemptedShard int
shardAwarePortDisabled bool
// Used to disable new connections to the shard-aware port temporarily
disableShardAwarePortUntil *atomic.Value
}
func newScyllaConnPicker(conn *Conn) *scyllaConnPicker {
addr := conn.Address()
hostId := conn.host.hostId
if conn.scyllaSupported.nrShards == 0 {
panic(fmt.Sprintf("scylla: %s not a sharded connection", addr))
}
if gocqlDebug {
Logger.Printf("scylla: %s new conn picker sharding options %+v", addr, conn.scyllaSupported)
}
var shardAwarePort uint16
if conn.session.connCfg.tlsConfig != nil {
shardAwarePort = conn.scyllaSupported.shardAwarePortSSL
} else {
shardAwarePort = conn.scyllaSupported.shardAwarePort
}
var shardAwareAddress string
if shardAwarePort != 0 {
tIP, tPort := conn.session.cfg.translateAddressPort(conn.host.UntranslatedConnectAddress(), int(shardAwarePort))
shardAwareAddress = net.JoinHostPort(tIP.String(), strconv.Itoa(tPort))
}
return &scyllaConnPicker{
address: addr,
hostId: hostId,
shardAwareAddress: shardAwareAddress,
nrShards: conn.scyllaSupported.nrShards,
msbIgnore: conn.scyllaSupported.msbIgnore,
lastAttemptedShard: 0,
shardAwarePortDisabled: conn.session.cfg.DisableShardAwarePort,
disableShardAwarePortUntil: new(atomic.Value),
}
}
func (p *scyllaConnPicker) Pick(t Token, qry ExecutableQuery) *Conn {
if len(p.conns) == 0 {
return nil
}
if t == nil {
return p.leastBusyConn()
}
mmt, ok := t.(int64Token)
// double check if that's murmur3 token
if !ok {
return nil
}
idx := -1
for _, conn := range p.conns {
if conn == nil {
continue
}
if qry != nil && conn.isTabletSupported() {
tablets := conn.session.getTablets()
// Search for tablets with Keyspace and Table from the Query
l, r := tablets.findTablets(qry.Keyspace(), qry.Table())
if l != -1 {
tablet := tablets.findTabletForToken(mmt, l, r)
for _, replica := range tablet.replicas {
if replica.hostId.String() == p.hostId {
idx = replica.shardId
}
}
}
}
break
}
if idx == -1 {
idx = p.shardOf(mmt)
}
if c := p.conns[idx]; c != nil {
// We have this shard's connection
// so let's give it to the caller.
// But only if it's not loaded too much and load is well distributed.
if qry != nil && qry.IsLWT() {
return c
}
return p.maybeReplaceWithLessBusyConnection(c)
}
return p.leastBusyConn()
}
func (p *scyllaConnPicker) maybeReplaceWithLessBusyConnection(c *Conn) *Conn {
if !isHeavyLoaded(c) {
return c
}
alternative := p.leastBusyConn()
if alternative == nil || alternative.AvailableStreams()*120 > c.AvailableStreams()*100 {
return c
} else {
return alternative
}
}
func isHeavyLoaded(c *Conn) bool {
return c.streams.NumStreams/2 > c.AvailableStreams()
}
func (p *scyllaConnPicker) leastBusyConn() *Conn {
var (
leastBusyConn *Conn
streamsAvailable int
)
idx := int(atomic.AddUint64(&p.pos, 1))
// find the conn which has the most available streams, this is racy
for i := range p.conns {
if conn := p.conns[(idx+i)%len(p.conns)]; conn != nil {
if streams := conn.AvailableStreams(); streams > streamsAvailable {
leastBusyConn = conn
streamsAvailable = streams
}
}
}
return leastBusyConn
}
func (p *scyllaConnPicker) shardOf(token int64Token) int {
shards := uint64(p.nrShards)
z := uint64(token+math.MinInt64) << p.msbIgnore
lo := z & 0xffffffff
hi := (z >> 32) & 0xffffffff
mul1 := lo * shards
mul2 := hi * shards
sum := (mul1 >> 32) + mul2
return int(sum >> 32)
}
func (p *scyllaConnPicker) Put(conn *Conn) {
var (
nrShards = conn.scyllaSupported.nrShards
shard = conn.scyllaSupported.shard
)
if nrShards == 0 {
panic(fmt.Sprintf("scylla: %s not a sharded connection", p.address))
}
if nrShards != len(p.conns) {
if nrShards != p.nrShards {
panic(fmt.Sprintf("scylla: %s invalid number of shards", p.address))
}
conns := p.conns
p.conns = make([]*Conn, nrShards, nrShards)
copy(p.conns, conns)
}
if c := p.conns[shard]; c != nil {
if conn.addr == p.shardAwareAddress {
// A connection made to the shard-aware port resulted in duplicate
// connection to the same shard being made. Because this is never
// intentional, it suggests that a NAT or AddressTranslator
// changes the source port along the way, therefore we can't trust
// the shard-aware port to return connection to the shard
// that we requested. Fall back to non-shard-aware port for some time.
Logger.Printf(
"scylla: %s connection to shard-aware address %s resulted in wrong shard being assigned; please check that you are not behind a NAT or AddressTranslater which changes source ports; falling back to non-shard-aware port for %v",
p.address,
p.shardAwareAddress,
scyllaShardAwarePortFallbackDuration,
)
until := time.Now().Add(scyllaShardAwarePortFallbackDuration)
p.disableShardAwarePortUntil.Store(until)
// Connections to shard-aware port do not influence how shards
// are chosen for the non-shard-aware port, therefore it can be
// closed immediately
closeConns(conn)
} else {
p.excessConns = append(p.excessConns, conn)
if gocqlDebug {
Logger.Printf("scylla: %s put shard %d excess connection total: %d missing: %d excess: %d", p.address, shard, p.nrConns, p.nrShards-p.nrConns, len(p.excessConns))
}
}
} else {
p.conns[shard] = conn
p.nrConns++
if gocqlDebug {
Logger.Printf("scylla: %s put shard %d connection total: %d missing: %d", p.address, shard, p.nrConns, p.nrShards-p.nrConns)
}
}
if p.shouldCloseExcessConns() {
p.closeExcessConns()
}
}
func (p *scyllaConnPicker) shouldCloseExcessConns() bool {
const maxExcessConnsFactor = 10
if p.nrConns >= p.nrShards {
return true
}
return len(p.excessConns) > maxExcessConnsFactor*p.nrShards
}
func (p *scyllaConnPicker) Remove(conn *Conn) {
shard := conn.scyllaSupported.shard
if conn.scyllaSupported.nrShards == 0 {
// It is possible for Remove to be called before the connection is added to the pool.
// Ignoring these connections here is safe.
if gocqlDebug {
Logger.Printf("scylla: %s has unknown sharding state, ignoring it", p.address)
}
return
}
if gocqlDebug {
Logger.Printf("scylla: %s remove shard %d connection", p.address, shard)
}
if p.conns[shard] != nil {
p.conns[shard] = nil
p.nrConns--
}
}
func (p *scyllaConnPicker) InFlight() int {
result := 0
for _, conn := range p.conns {
if conn != nil {
result = result + (conn.streams.InUse())
}
}
return result
}
func (p *scyllaConnPicker) Size() (int, int) {
return p.nrConns, p.nrShards - p.nrConns
}
func (p *scyllaConnPicker) Close() {
p.closeConns()
p.closeExcessConns()
}
func (p *scyllaConnPicker) closeConns() {
if len(p.conns) == 0 {
if gocqlDebug {
Logger.Printf("scylla: %s no connections to close", p.address)
}
return
}
conns := p.conns
p.conns = nil
p.nrConns = 0
if gocqlDebug {
Logger.Printf("scylla: %s closing %d connections", p.address, len(conns))
}
go closeConns(conns...)
}
func (p *scyllaConnPicker) closeExcessConns() {
if len(p.excessConns) == 0 {
if gocqlDebug {
Logger.Printf("scylla: %s no excess connections to close", p.address)
}
return
}
conns := p.excessConns
p.excessConns = nil
if gocqlDebug {
Logger.Printf("scylla: %s closing %d excess connections", p.address, len(conns))
}
go closeConns(conns...)
}
// Closing must be done outside of hostConnPool lock. If holding a lock
// a deadlock can occur when closing one of the connections returns error on close.
// See scylladb/gocql#53.
func closeConns(conns ...*Conn) {
for _, conn := range conns {
if conn != nil {
conn.Close()
}
}
}
// NextShard returns the shardID to connect to.
// nrShard specifies how many shards the host has.
// If nrShards is zero, the caller shouldn't use shard-aware port.
func (p *scyllaConnPicker) NextShard() (shardID, nrShards int) {
if p.shardAwarePortDisabled {
return 0, 0
}
disableUntil, _ := p.disableShardAwarePortUntil.Load().(time.Time)
if time.Now().Before(disableUntil) {
// There is suspicion that the shard-aware-port is not reachable
// or misconfigured, fall back to the non-shard-aware port
return 0, 0
}
// Find the shard without a connection
// It's important to start counting from 1 here because we want
// to consider the next shard after the previously attempted one
for i := 1; i <= p.nrShards; i++ {
shardID := (p.lastAttemptedShard + i) % p.nrShards
if p.conns == nil || p.conns[shardID] == nil {
p.lastAttemptedShard = shardID
return shardID, p.nrShards
}
}
// We did not find an unallocated shard
// We will dial the non-shard-aware port
return 0, 0
}
// ShardDialer is like HostDialer but is shard-aware.
// If the driver wants to connect to a specific shard, it will call DialShard,
// otherwise it will call DialHost.
type ShardDialer interface {
HostDialer
// DialShard establishes a connection to the specified shard ID out of nrShards.
// The returned connection must be directly usable for CQL protocol,
// specifically DialShard is responsible also for setting up the TLS session if needed.
DialShard(ctx context.Context, host *HostInfo, shardID, nrShards int) (*DialedHost, error)
}
// A dialer which dials a particular shard
type scyllaDialer struct {
dialer Dialer
logger StdLogger
tlsConfig *tls.Config
cfg *ClusterConfig
}
const scyllaShardAwarePortFallbackDuration time.Duration = 5 * time.Minute
func (sd *scyllaDialer) DialHost(ctx context.Context, host *HostInfo) (*DialedHost, error) {
ip := host.ConnectAddress()
port := host.Port()
if !validIpAddr(ip) {
return nil, fmt.Errorf("host missing connect ip address: %v", ip)
} else if port == 0 {
return nil, fmt.Errorf("host missing port: %v", port)
}
addr := host.HostnameAndPort()
conn, err := sd.dialer.DialContext(ctx, "tcp", addr)
if err != nil {
return nil, err
}
return WrapTLS(ctx, conn, addr, sd.tlsConfig)
}
func (sd *scyllaDialer) DialShard(ctx context.Context, host *HostInfo, shardID, nrShards int) (*DialedHost, error) {
ip := host.ConnectAddress()
port := host.Port()
if !validIpAddr(ip) {
return nil, fmt.Errorf("host missing connect ip address: %v", ip)
} else if port == 0 {
return nil, fmt.Errorf("host missing port: %v", port)
}
iter := newScyllaPortIterator(shardID, nrShards)
addr := host.HostnameAndPort()
var shardAwarePort uint16
if sd.tlsConfig != nil {
shardAwarePort = host.ScyllaShardAwarePortTLS()
} else {
shardAwarePort = host.ScyllaShardAwarePort()
}
var shardAwareAddress string
if shardAwarePort != 0 {
tIP, tPort := sd.cfg.translateAddressPort(host.UntranslatedConnectAddress(), int(shardAwarePort))
shardAwareAddress = net.JoinHostPort(tIP.String(), strconv.Itoa(tPort))
}
if gocqlDebug {
sd.logger.Printf("scylla: connecting to shard %d", shardID)
}
conn, err := sd.dialShardAware(ctx, addr, shardAwareAddress, iter)
if err != nil {
return nil, err
}
return WrapTLS(ctx, conn, addr, sd.tlsConfig)
}
func (sd *scyllaDialer) dialShardAware(ctx context.Context, addr, shardAwareAddr string, iter *scyllaPortIterator) (net.Conn, error) {
for {
port, ok := iter.Next()
if !ok {
// We exhausted ports to connect from. Try the non-shard-aware port.
return sd.dialer.DialContext(ctx, "tcp", addr)
}
ctxWithPort := context.WithValue(ctx, scyllaSourcePortCtx{}, port)
conn, err := sd.dialer.DialContext(ctxWithPort, "tcp", shardAwareAddr)
if isLocalAddrInUseErr(err) {
// This indicates that the source port is already in use
// We can immediately retry with another source port for this shard
continue
} else if err != nil {
conn, err := sd.dialer.DialContext(ctx, "tcp", addr)
if err == nil {
// We failed to connect to the shard-aware port, but succeeded
// in connecting to the non-shard-aware port. This might
// indicate that the shard-aware port is just not reachable,
// but we may also be unlucky and the node became reachable
// just after we tried the first connection.
// We can't avoid false positives here, so I'm putting it
// behind a debug flag.
if gocqlDebug {
sd.logger.Printf(
"scylla: %s couldn't connect to shard-aware address while the non-shard-aware address %s is available; this might be an issue with ",
addr,
shardAwareAddr,
)
}
}
return conn, err
}
return conn, err
}
}
// ErrScyllaSourcePortAlreadyInUse An error value which can returned from
// a custom dialer implementation to indicate that the requested source port
// to dial from is already in use
var ErrScyllaSourcePortAlreadyInUse = errors.New("scylla: source port is already in use")
func isLocalAddrInUseErr(err error) bool {
return errors.Is(err, syscall.EADDRINUSE) || errors.Is(err, ErrScyllaSourcePortAlreadyInUse)
}
// ScyllaShardAwareDialer wraps a net.Dialer, but uses a source port specified by gocql when connecting.
//
// Unlike in the case standard native transport ports, gocql can choose which shard will handle
// a new connection by connecting from a specific source port. If you are using your own net.Dialer
// in ClusterConfig, you can use ScyllaShardAwareDialer to "upgrade" it so that it connects
// from the source port chosen by gocql.
//
// Please note that ScyllaShardAwareDialer overwrites the LocalAddr field in order to choose
// the right source port for connection.
type ScyllaShardAwareDialer struct {
net.Dialer
}
func (d *ScyllaShardAwareDialer) DialContext(ctx context.Context, network, addr string) (conn net.Conn, err error) {
sourcePort := ScyllaGetSourcePort(ctx)
if sourcePort == 0 {
return d.Dialer.DialContext(ctx, network, addr)
}
dialerWithLocalAddr := d.Dialer
dialerWithLocalAddr.LocalAddr, err = net.ResolveTCPAddr(network, fmt.Sprintf(":%d", sourcePort))
if err != nil {
return nil, err
}
return dialerWithLocalAddr.DialContext(ctx, network, addr)
}
type scyllaPortIterator struct {
currentPort int
shardCount int
}
const (
scyllaPortBasedBalancingMin = 0x8000
scyllaPortBasedBalancingMax = 0xFFFF
)
func newScyllaPortIterator(shardID, shardCount int) *scyllaPortIterator {
if shardCount == 0 {
panic("shardCount cannot be 0")
}
// Find the smallest port p such that p >= min and p % shardCount == shardID
port := scyllaPortBasedBalancingMin - scyllaShardForSourcePort(scyllaPortBasedBalancingMin, shardCount) + shardID
if port < scyllaPortBasedBalancingMin {
port += shardCount
}
return &scyllaPortIterator{
currentPort: port,
shardCount: shardCount,
}
}
func (spi *scyllaPortIterator) Next() (uint16, bool) {
if spi == nil {
return 0, false
}
p := spi.currentPort
if p > scyllaPortBasedBalancingMax {
return 0, false
}
spi.currentPort += spi.shardCount
return uint16(p), true
}
func scyllaShardForSourcePort(sourcePort uint16, shardCount int) int {
return int(sourcePort) % shardCount
}
type scyllaSourcePortCtx struct{}
// ScyllaGetSourcePort returns the source port that should be used when connecting to a node.
//
// Unlike in the case standard native transport ports, gocql can choose which shard will handle
// a new connection at the shard-aware port by connecting from a specific source port. Therefore,
// if you are using a custom Dialer and your nodes expose shard-aware ports, your dialer should
// use the source port specified by gocql.
//
// If this function returns 0, then your dialer can use any source port.
//
// If you aren't using a custom dialer, gocql will use a default one which uses appropriate source port.
// If you are using net.Dialer, consider wrapping it in a gocql.ScyllaShardAwareDialer.
func ScyllaGetSourcePort(ctx context.Context) uint16 {
sourcePort, _ := ctx.Value(scyllaSourcePortCtx{}).(uint16)
return sourcePort
}
// Returns a partitioner specific to the table, or "nil"
// if the cluster-global partitioner should be used
func scyllaGetTablePartitioner(session *Session, keyspaceName, tableName string) (Partitioner, error) {
isCdc, err := scyllaIsCdcTable(session, keyspaceName, tableName)
if err != nil {
return nil, err
}
if isCdc {
return scyllaCDCPartitioner{}, nil
}
return nil, nil
}

93
vendor/github.com/gocql/gocql/scylla_cdc.go generated vendored Normal file
View File

@@ -0,0 +1,93 @@
package gocql
import (
"encoding/binary"
"math"
"strings"
)
// cdc partitioner
const (
scyllaCDCPartitionerName = "CDCPartitioner"
scyllaCDCPartitionerFullName = "com.scylladb.dht.CDCPartitioner"
scyllaCDCPartitionKeyLength = 16
scyllaCDCVersionMask = 0x0F
scyllaCDCMinSupportedVersion = 1
scyllaCDCMaxSupportedVersion = 1
scyllaCDCMinToken = int64Token(math.MinInt64)
scyllaCDCLogTableNameSuffix = "_scylla_cdc_log"
scyllaCDCExtensionName = "cdc"
)
type scyllaCDCPartitioner struct{}
var _ Partitioner = scyllaCDCPartitioner{}
func (p scyllaCDCPartitioner) Name() string {
return scyllaCDCPartitionerName
}
func (p scyllaCDCPartitioner) Hash(partitionKey []byte) Token {
if len(partitionKey) < 8 {
// The key is too short to extract any sensible token,
// so return the min token instead
if gocqlDebug {
Logger.Printf("scylla: cdc partition key too short: %d < 8", len(partitionKey))
}
return scyllaCDCMinToken
}
upperQword := binary.BigEndian.Uint64(partitionKey[0:])
if gocqlDebug {
// In debug mode, do some more checks
if len(partitionKey) != scyllaCDCPartitionKeyLength {
// The token has unrecognized format, but the first quadword
// should be the token value that we want
Logger.Printf("scylla: wrong size of cdc partition key: %d", len(partitionKey))
}
lowerQword := binary.BigEndian.Uint64(partitionKey[8:])
version := lowerQword & scyllaCDCVersionMask
if version < scyllaCDCMinSupportedVersion || version > scyllaCDCMaxSupportedVersion {
// We don't support this version yet,
// the token may be wrong
Logger.Printf(
"scylla: unsupported version: %d is not in range [%d, %d]",
version,
scyllaCDCMinSupportedVersion,
scyllaCDCMaxSupportedVersion,
)
}
}
return int64Token(upperQword)
}
func (p scyllaCDCPartitioner) ParseString(str string) Token {
return parseInt64Token(str)
}
func scyllaIsCdcTable(session *Session, keyspaceName, tableName string) (bool, error) {
if !strings.HasSuffix(tableName, scyllaCDCLogTableNameSuffix) {
// Not a CDC table, use the default partitioner
return false, nil
}
// Get the table metadata to see if it has the cdc partitioner set
keyspaceMeta, err := session.KeyspaceMetadata(keyspaceName)
if err != nil {
return false, err
}
tableMeta, ok := keyspaceMeta.Tables[tableName]
if !ok {
return false, ErrNoMetadata
}
return tableMeta.Options.Partitioner == scyllaCDCPartitionerFullName, nil
}

View File

@@ -0,0 +1,28 @@
package ascii
import (
"reflect"
)
func Marshal(value interface{}) ([]byte, error) {
switch v := value.(type) {
case nil:
return nil, nil
case string:
return EncString(v)
case *string:
return EncStringR(v)
case []byte:
return EncBytes(v)
case *[]byte:
return EncBytesR(v)
default:
// Custom types (type MyString string) can be serialized only via `reflect` package.
// Later, when generic-based serialization is introduced we can do that via generics.
rv := reflect.ValueOf(value)
if rv.Kind() != reflect.Ptr {
return EncReflect(rv)
}
return EncReflectR(rv)
}
}

View File

@@ -0,0 +1,61 @@
package ascii
import (
"fmt"
"reflect"
)
func EncString(v string) ([]byte, error) {
return encString(v), nil
}
func EncStringR(v *string) ([]byte, error) {
if v == nil {
return nil, nil
}
return encString(*v), nil
}
func EncBytes(v []byte) ([]byte, error) {
return v, nil
}
func EncBytesR(v *[]byte) ([]byte, error) {
if v == nil {
return nil, nil
}
return *v, nil
}
func EncReflect(v reflect.Value) ([]byte, error) {
switch v.Kind() {
case reflect.String:
return encString(v.String()), nil
case reflect.Slice:
if v.Type().Elem().Kind() != reflect.Uint8 {
return nil, fmt.Errorf("failed to marshal ascii: unsupported value type (%T)(%[1]v)", v.Interface())
}
return EncBytes(v.Bytes())
case reflect.Struct:
if v.Type().String() == "gocql.unsetColumn" {
return nil, nil
}
return nil, fmt.Errorf("failed to marshal ascii: unsupported value type (%T)(%[1]v)", v.Interface())
default:
return nil, fmt.Errorf("failed to marshal ascii: unsupported value type (%T)(%[1]v)", v.Interface())
}
}
func EncReflectR(v reflect.Value) ([]byte, error) {
if v.IsNil() {
return nil, nil
}
return EncReflect(v.Elem())
}
func encString(v string) []byte {
if v == "" {
return make([]byte, 0)
}
return []byte(v)
}

View File

@@ -0,0 +1,33 @@
package ascii
import (
"fmt"
"reflect"
)
func Unmarshal(data []byte, value interface{}) error {
switch v := value.(type) {
case nil:
return nil
case *string:
return DecString(data, v)
case **string:
return DecStringR(data, v)
case *[]byte:
return DecBytes(data, v)
case **[]byte:
return DecBytesR(data, v)
default:
// Custom types (type MyString string) can be deserialized only via `reflect` package.
// Later, when generic-based serialization is introduced we can do that via generics.
rv := reflect.ValueOf(value)
rt := rv.Type()
if rt.Kind() != reflect.Ptr {
return fmt.Errorf("failed to unmarshal ascii: unsupported value type (%T)(%[1]v)", v)
}
if rt.Elem().Kind() != reflect.Ptr {
return DecReflect(data, rv)
}
return DecReflectR(data, rv)
}
}

View File

@@ -0,0 +1,166 @@
package ascii
import (
"fmt"
"reflect"
)
func errInvalidData(p []byte) error {
for i := range p {
if p[i] > 127 {
return fmt.Errorf("failed to unmarshal ascii: invalid charester %s", string(p[i]))
}
}
return nil
}
func errNilReference(v interface{}) error {
return fmt.Errorf("failed to unmarshal ascii: can not unmarshal into nil reference(%T)(%[1]v)", v)
}
func DecString(p []byte, v *string) error {
if v == nil {
return errNilReference(v)
}
*v = decString(p)
return errInvalidData(p)
}
func DecStringR(p []byte, v **string) error {
if v == nil {
return errNilReference(v)
}
*v = decStringR(p)
return errInvalidData(p)
}
func DecBytes(p []byte, v *[]byte) error {
if v == nil {
return errNilReference(v)
}
*v = decBytes(p)
return errInvalidData(p)
}
func DecBytesR(p []byte, v **[]byte) error {
if v == nil {
return errNilReference(v)
}
*v = decBytesR(p)
return errInvalidData(p)
}
func DecReflect(p []byte, v reflect.Value) error {
if v.IsNil() {
return errNilReference(v)
}
switch v = v.Elem(); v.Kind() {
case reflect.String:
v.SetString(decString(p))
case reflect.Slice:
if v.Type().Elem().Kind() != reflect.Uint8 {
return fmt.Errorf("failed to marshal ascii: unsupported value type (%T)(%[1]v)", v.Interface())
}
v.SetBytes(decBytes(p))
default:
return fmt.Errorf("failed to unmarshal ascii: unsupported value type (%T)(%[1]v)", v.Interface())
}
return errInvalidData(p)
}
func DecReflectR(p []byte, v reflect.Value) error {
if v.IsNil() {
return errNilReference(v)
}
switch ev := v.Type().Elem().Elem(); ev.Kind() {
case reflect.String:
return decReflectStringR(p, v)
case reflect.Slice:
if ev.Elem().Kind() != reflect.Uint8 {
return fmt.Errorf("failed to marshal ascii: unsupported value type (%T)(%[1]v)", v.Interface())
}
return decReflectBytesR(p, v)
default:
return fmt.Errorf("failed to unmarshal ascii: unsupported value type (%T)(%[1]v)", v.Interface())
}
}
func decReflectStringR(p []byte, v reflect.Value) error {
if len(p) == 0 {
if p == nil {
v.Elem().Set(reflect.Zero(v.Elem().Type()))
} else {
v.Elem().Set(reflect.New(v.Type().Elem().Elem()))
}
return nil
}
val := reflect.New(v.Type().Elem().Elem())
val.Elem().SetString(string(p))
v.Elem().Set(val)
return errInvalidData(p)
}
func decReflectBytesR(p []byte, v reflect.Value) error {
if len(p) == 0 {
if p == nil {
v.Elem().Set(reflect.Zero(v.Elem().Type()))
} else {
val := reflect.New(v.Type().Elem().Elem())
val.Elem().SetBytes(make([]byte, 0))
v.Elem().Set(val)
}
return nil
}
tmp := make([]byte, len(p))
copy(tmp, p)
val := reflect.New(v.Type().Elem().Elem())
val.Elem().SetBytes(tmp)
v.Elem().Set(val)
return errInvalidData(p)
}
func decString(p []byte) string {
if len(p) == 0 {
return ""
}
return string(p)
}
func decStringR(p []byte) *string {
if len(p) == 0 {
if p == nil {
return nil
}
return new(string)
}
tmp := string(p)
return &tmp
}
func decBytes(p []byte) []byte {
if len(p) == 0 {
if p == nil {
return nil
}
return make([]byte, 0)
}
tmp := make([]byte, len(p))
copy(tmp, p)
return tmp
}
func decBytesR(p []byte) *[]byte {
if len(p) == 0 {
if p == nil {
return nil
}
tmp := make([]byte, 0)
return &tmp
}
tmp := make([]byte, len(p))
copy(tmp, p)
return &tmp
}

View File

@@ -0,0 +1,74 @@
package bigint
import (
"math/big"
"reflect"
)
func Marshal(value interface{}) ([]byte, error) {
switch v := value.(type) {
case nil:
return nil, nil
case int8:
return EncInt8(v)
case int16:
return EncInt16(v)
case int32:
return EncInt32(v)
case int64:
return EncInt64(v)
case int:
return EncInt(v)
case uint8:
return EncUint8(v)
case uint16:
return EncUint16(v)
case uint32:
return EncUint32(v)
case uint64:
return EncUint64(v)
case uint:
return EncUint(v)
case big.Int:
return EncBigInt(v)
case string:
return EncString(v)
case *int8:
return EncInt8R(v)
case *int16:
return EncInt16R(v)
case *int32:
return EncInt32R(v)
case *int64:
return EncInt64R(v)
case *int:
return EncIntR(v)
case *uint8:
return EncUint8R(v)
case *uint16:
return EncUint16R(v)
case *uint32:
return EncUint32R(v)
case *uint64:
return EncUint64R(v)
case *uint:
return EncUintR(v)
case *big.Int:
return EncBigIntR(v)
case *string:
return EncStringR(v)
default:
// Custom types (type MyInt int) can be serialized only via `reflect` package.
// Later, when generic-based serialization is introduced we can do that via generics.
rv := reflect.TypeOf(value)
if rv.Kind() != reflect.Ptr {
return EncReflect(reflect.ValueOf(v))
}
return EncReflectR(reflect.ValueOf(v))
}
}

View File

@@ -0,0 +1,206 @@
package bigint
import (
"fmt"
"math"
"math/big"
"reflect"
"strconv"
)
var (
maxBigInt = big.NewInt(math.MaxInt64)
minBigInt = big.NewInt(math.MinInt64)
)
func EncInt8(v int8) ([]byte, error) {
if v < 0 {
return []byte{255, 255, 255, 255, 255, 255, 255, byte(v)}, nil
}
return []byte{0, 0, 0, 0, 0, 0, 0, byte(v)}, nil
}
func EncInt8R(v *int8) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncInt8(*v)
}
func EncInt16(v int16) ([]byte, error) {
if v < 0 {
return []byte{255, 255, 255, 255, 255, 255, byte(v >> 8), byte(v)}, nil
}
return []byte{0, 0, 0, 0, 0, 0, byte(v >> 8), byte(v)}, nil
}
func EncInt16R(v *int16) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncInt16(*v)
}
func EncInt32(v int32) ([]byte, error) {
if v < 0 {
return []byte{255, 255, 255, 255, byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}, nil
}
return []byte{0, 0, 0, 0, byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}, nil
}
func EncInt32R(v *int32) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncInt32(*v)
}
func EncInt64(v int64) ([]byte, error) {
return encInt64(v), nil
}
func EncInt64R(v *int64) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncInt64(*v)
}
func EncInt(v int) ([]byte, error) {
return []byte{byte(v >> 56), byte(v >> 48), byte(v >> 40), byte(v >> 32), byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}, nil
}
func EncIntR(v *int) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncInt(*v)
}
func EncUint8(v uint8) ([]byte, error) {
return []byte{0, 0, 0, 0, 0, 0, 0, v}, nil
}
func EncUint8R(v *uint8) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncUint8(*v)
}
func EncUint16(v uint16) ([]byte, error) {
return []byte{0, 0, 0, 0, 0, 0, byte(v >> 8), byte(v)}, nil
}
func EncUint16R(v *uint16) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncUint16(*v)
}
func EncUint32(v uint32) ([]byte, error) {
return []byte{0, 0, 0, 0, byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}, nil
}
func EncUint32R(v *uint32) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncUint32(*v)
}
func EncUint64(v uint64) ([]byte, error) {
return []byte{byte(v >> 56), byte(v >> 48), byte(v >> 40), byte(v >> 32), byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}, nil
}
func EncUint64R(v *uint64) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncUint64(*v)
}
func EncUint(v uint) ([]byte, error) {
return []byte{byte(v >> 56), byte(v >> 48), byte(v >> 40), byte(v >> 32), byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}, nil
}
func EncUintR(v *uint) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncUint(*v)
}
func EncBigInt(v big.Int) ([]byte, error) {
if v.Cmp(maxBigInt) == 1 || v.Cmp(minBigInt) == -1 {
return nil, fmt.Errorf("failed to marshal bigint: value (%T)(%s) out of range", v, v.String())
}
return encInt64(v.Int64()), nil
}
func EncBigIntR(v *big.Int) ([]byte, error) {
if v == nil {
return nil, nil
}
if v.Cmp(maxBigInt) == 1 || v.Cmp(minBigInt) == -1 {
return nil, fmt.Errorf("failed to marshal bigint: value (%T)(%s) out of range", v, v.String())
}
return encInt64(v.Int64()), nil
}
func EncString(v string) ([]byte, error) {
if v == "" {
return nil, nil
}
n, err := strconv.ParseInt(v, 10, 64)
if err != nil {
return nil, fmt.Errorf("failed to marshal bigint: can not marshal (%T)(%[1]v) %s", v, err)
}
return encInt64(n), nil
}
func EncStringR(v *string) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncString(*v)
}
func EncReflect(v reflect.Value) ([]byte, error) {
switch v.Kind() {
case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8:
return EncInt64(v.Int())
case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8:
return EncUint64(v.Uint())
case reflect.String:
val := v.String()
if val == "" {
return nil, nil
}
n, err := strconv.ParseInt(val, 10, 64)
if err != nil {
return nil, fmt.Errorf("failed to marshal bigint: can not marshal (%T)(%[1]v) %s", v.Interface(), err)
}
return encInt64(n), nil
case reflect.Struct:
if v.Type().String() == "gocql.unsetColumn" {
return nil, nil
}
return nil, fmt.Errorf("failed to marshal bigint: unsupported value type (%T)(%[1]v)", v.Interface())
default:
return nil, fmt.Errorf("failed to marshal bigint: unsupported value type (%T)(%[1]v)", v.Interface())
}
}
func EncReflectR(v reflect.Value) ([]byte, error) {
if v.IsNil() {
return nil, nil
}
return EncReflect(v.Elem())
}
func encInt64(v int64) []byte {
return []byte{byte(v >> 56), byte(v >> 48), byte(v >> 40), byte(v >> 32), byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}
}

View File

@@ -0,0 +1,81 @@
package bigint
import (
"fmt"
"math/big"
"reflect"
)
func Unmarshal(data []byte, value interface{}) error {
switch v := value.(type) {
case nil:
return nil
case *int8:
return DecInt8(data, v)
case *int16:
return DecInt16(data, v)
case *int32:
return DecInt32(data, v)
case *int64:
return DecInt64(data, v)
case *int:
return DecInt(data, v)
case *uint8:
return DecUint8(data, v)
case *uint16:
return DecUint16(data, v)
case *uint32:
return DecUint32(data, v)
case *uint64:
return DecUint64(data, v)
case *uint:
return DecUint(data, v)
case *big.Int:
return DecBigInt(data, v)
case *string:
return DecString(data, v)
case **int8:
return DecInt8R(data, v)
case **int16:
return DecInt16R(data, v)
case **int32:
return DecInt32R(data, v)
case **int64:
return DecInt64R(data, v)
case **int:
return DecIntR(data, v)
case **uint8:
return DecUint8R(data, v)
case **uint16:
return DecUint16R(data, v)
case **uint32:
return DecUint32R(data, v)
case **uint64:
return DecUint64R(data, v)
case **uint:
return DecUintR(data, v)
case **big.Int:
return DecBigIntR(data, v)
case **string:
return DecStringR(data, v)
default:
// Custom types (type MyInt int) can be deserialized only via `reflect` package.
// Later, when generic-based serialization is introduced we can do that via generics.
rv := reflect.ValueOf(value)
rt := rv.Type()
if rt.Kind() != reflect.Ptr {
return fmt.Errorf("failed to unmarshal bigint: unsupported value type (%T)(%[1]v)", value)
}
if rt.Elem().Kind() != reflect.Ptr {
return DecReflect(data, rv)
}
return DecReflectR(data, rv)
}
}

Some files were not shown because too many files have changed in this diff Show More