implement Scylla database

This commit is contained in:
2025-05-17 21:45:18 -04:00
parent 252b49ae6a
commit f8a550883d
199 changed files with 71243 additions and 424 deletions

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"net/http"
"git.dubyatp.xyz/chat-api-server/db"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/go-chi/docgen"
@@ -14,6 +15,10 @@ import (
var routes = flag.Bool("routes", false, "Generate API route documentation")
func Start() {
db.InitScyllaDB()
defer db.CloseScyllaDB()
flag.Parse()
r := chi.NewRouter()

331
api/db.go
View File

@@ -3,275 +3,168 @@ package api
import (
"errors"
"fmt"
"time"
"git.dubyatp.xyz/chat-api-server/fake_db"
"git.dubyatp.xyz/chat-api-server/db"
"github.com/gocql/gocql"
)
func dbGetUser(id string) (*User, error) {
data := fake_db.ExecDB("users")
if data == nil {
return nil, errors.New("failed to load users database")
query := `SELECT id, name, password FROM users WHERE id = ?`
var user User
err := db.Session.Query(query, id).Scan(&user.ID, &user.Name, &user.Password)
if err == gocql.ErrNotFound {
return nil, errors.New("User not found")
} else if err != nil {
return nil, fmt.Errorf("failed to query user: %v", err)
}
users := data["users"].([]interface{})
for _, u := range users {
user := u.(map[string]interface{})
if user["ID"].(string) == id {
return &User{
ID: user["ID"].(string),
Name: user["Name"].(string),
Password: user["Password"].(string),
}, nil
}
}
return nil, errors.New("User not found")
return &user, nil
}
func dbGetUserByName(username string) (*User, error) {
data := fake_db.ExecDB("users")
if data == nil {
return nil, errors.New("failed to load users database")
query := `SELECT id, name, password FROM users WHERE name = ? ALLOW FILTERING`
var user User
err := db.Session.Query(query, username).Scan(&user.ID, &user.Name, &user.Password)
if err == gocql.ErrNotFound {
return nil, errors.New("User not found")
} else if err != nil {
return nil, fmt.Errorf("failed to query user: %v", err)
}
users := data["users"].([]interface{})
for _, u := range users {
user := u.(map[string]interface{})
if user["Name"].(string) == username {
return &User{
ID: user["ID"].(string),
Name: user["Name"].(string),
Password: user["Password"].(string),
}, nil
}
}
return nil, errors.New("User not found")
return &user, nil
}
func dbGetAllUsers() ([]*User, error) {
data := fake_db.ExecDB("users")
if data == nil {
return nil, errors.New("failed to load users database")
query := `SELECT id, name, password FROM users`
iter := db.Session.Query(query).Iter()
defer iter.Close()
var users []*User
for {
user := &User{}
if !iter.Scan(&user.ID, &user.Name, &user.Password) {
break
}
users = append(users, user)
}
users := data["users"].([]interface{})
var result []*User
for _, u := range users {
user := u.(map[string]interface{})
result = append(result, &User{
ID: user["ID"].(string),
Name: user["Name"].(string),
})
if err := iter.Close(); err != nil {
return nil, fmt.Errorf("failed to iterate users: %v", err)
}
if len(result) == 0 {
if len(users) == 0 {
return nil, errors.New("no users found")
}
return result, nil
return users, nil
}
func dbGetMessage(id string) (*Message, error) {
data := fake_db.ExecDB("messages")
if data == nil {
return nil, errors.New("failed to load messages database")
query := `SELECT id, body, edited, timestamp, userid FROM messages WHERE id = ?`
var message Message
err := db.Session.Query(query, id).Scan(
&message.ID,
&message.Body,
&message.Edited,
&message.Timestamp,
&message.UserID)
if err == gocql.ErrNotFound {
return nil, errors.New("Message not found")
} else if err != nil {
return nil, fmt.Errorf("failed to query message: %v", err)
}
messages := data["messages"].([]interface{})
for _, m := range messages {
message := m.(map[string]interface{})
if message["ID"].(string) == id {
timestamp, err := time.Parse(time.RFC3339, message["Timestamp"].(string))
if err != nil {
return nil, fmt.Errorf("failed to parse timestamp: %v", err)
}
editedStr, ok := message["Edited"].(string)
var edited time.Time
if ok && editedStr != "" {
var err error
edited, err = time.Parse(time.RFC3339, editedStr)
if err != nil {
return nil, fmt.Errorf("failed to parse edited timestamp: %v", err)
}
}
return &Message{
ID: message["ID"].(string),
UserID: message["UserID"].(string),
Body: message["Body"].(string),
Timestamp: timestamp,
Edited: edited,
}, nil
}
}
return nil, errors.New("Message not found")
return &message, nil
}
func dbGetAllMessages() ([]*Message, error) {
data := fake_db.ExecDB("messages")
//println(data)
if data == nil {
return nil, errors.New("failed to load messages database")
query := `SELECT id, body, edited, timestamp, userid FROM messages`
iter := db.Session.Query(query).Iter()
defer iter.Close()
var messages []*Message
for {
message := &Message{}
if !iter.Scan(
&message.ID,
&message.Body,
&message.Edited,
&message.Timestamp,
&message.UserID) {
break
}
messages = append(messages, message)
}
messages := data["messages"].([]interface{})
var result []*Message
for _, m := range messages {
message := m.(map[string]interface{})
timestamp, err := time.Parse(time.RFC3339, message["Timestamp"].(string))
if err != nil {
return nil, fmt.Errorf("failed to parse timestamp: %v", err)
}
editedStr, ok := message["Edited"].(string)
var edited time.Time
if ok && editedStr != "" {
var err error
edited, err = time.Parse(time.RFC3339, editedStr)
if err != nil {
return nil, fmt.Errorf("failed to parse edited timestamp: %v", err)
}
}
result = append(result, &Message{
ID: message["ID"].(string),
UserID: message["UserID"].(string),
Body: message["Body"].(string),
Timestamp: timestamp,
Edited: edited,
})
if err := iter.Close(); err != nil {
return nil, fmt.Errorf("failed to iterate messages: %v", err)
}
if len(result) == 0 {
if len(messages) == 0 {
return nil, errors.New("no messages found")
}
return result, nil
return messages, nil
}
func dbAddUser(user *User) error {
currentData := fake_db.ExecDB("users")
if currentData == nil {
return fmt.Errorf("error reading users database")
query := `INSERT INTO users (id, name, password) VALUES (?, ?, ?)`
err := db.Session.Query(query, user.ID, user.Name, user.Password).Exec()
if err != nil {
return fmt.Errorf("failed to add user: %v", err)
}
users, ok := currentData["users"].([]interface{})
if !ok {
return fmt.Errorf("users data is in an unexpected format")
}
dbUser := map[string]interface{}{
"ID": user.ID,
"Name": user.Name,
"Password": user.Password,
}
users = append(users, dbUser)
return fake_db.WriteDB("users", users)
return nil
}
func dbAddMessage(message *Message) error {
currentData := fake_db.ExecDB("messages")
if currentData == nil {
return fmt.Errorf("error reading messages database")
query := `INSERT INTO messages (id, body, edited, timestamp, userid)
VALUES (?, ?, ?, ?, ?)`
err := db.Session.Query(query,
message.ID,
message.Body,
nil,
message.Timestamp,
message.UserID).Exec()
if err != nil {
return fmt.Errorf("failed to add message: %v", err)
}
messages, ok := currentData["messages"].([]interface{})
if !ok {
return fmt.Errorf("messages data is in an unexpected format")
}
var edited interface{}
if message.Edited.IsZero() {
edited = nil // Set to nil if Edited is the zero value
} else {
edited = message.Edited.Format(time.RFC3339)
}
dbMessage := map[string]interface{}{
"ID": message.ID,
"UserID": message.UserID, // JSON numbers are float64
"Body": message.Body,
"Timestamp": message.Timestamp.Format(time.RFC3339),
"Edited": edited,
}
messages = append(messages, dbMessage)
return fake_db.WriteDB("messages", messages)
return nil
}
func dbUpdateMessage(updatedMessage *Message) error {
currentData := fake_db.ExecDB("messages")
if currentData == nil {
return fmt.Errorf("error reading messages database")
var edited interface{}
if updatedMessage.Edited.IsZero() {
edited = nil
} else {
edited = updatedMessage.Edited
}
messages, ok := currentData["messages"].([]interface{})
if !ok {
return fmt.Errorf("messages data is in an unexpected format")
query := `UPDATE messages
SET body = ?, edited = ?, timestamp = ?
WHERE ID = ?`
err := db.Session.Query(query,
updatedMessage.Body,
edited,
updatedMessage.Timestamp,
updatedMessage.ID).Exec()
if err != nil {
return fmt.Errorf("failed to update message: %v", err)
}
var updatedMessages []interface{}
found := false
return nil
for _, m := range messages {
message, ok := m.(map[string]interface{})
if !ok {
continue
}
if messageID, ok := message["ID"].(string); ok && messageID == updatedMessage.ID {
found = true
var edited interface{}
if updatedMessage.Edited.IsZero() {
edited = nil // Set to nil if Edited is the zero value
} else {
edited = updatedMessage.Edited.Format(time.RFC3339)
}
message = map[string]interface{}{
"ID": updatedMessage.ID,
"UserID": updatedMessage.UserID,
"Body": updatedMessage.Body,
"Timestamp": updatedMessage.Timestamp.Format(time.RFC3339),
"Edited": edited,
}
}
updatedMessages = append(updatedMessages, message)
}
if !found {
return fmt.Errorf("message with ID %s not found", updatedMessage.ID)
}
return fake_db.WriteDB("messages", updatedMessages)
}
func dbDeleteMessage(id string) error {
currentData := fake_db.ExecDB("messages")
if currentData == nil {
return fmt.Errorf("error reading messages database")
query := `DELETE FROM messages WHERE ID = ?`
err := db.Session.Query(query, id).Exec()
if err != nil {
return fmt.Errorf("failed to delete message: %v", err)
}
messages, ok := currentData["messages"].([]interface{})
if !ok {
return fmt.Errorf("messages data is in an unexpected format")
}
var updatedMessages []interface{}
found := false
for _, m := range messages {
message, ok := m.(map[string]interface{})
if !ok {
continue
}
if messageID, ok := message["ID"].(string); ok && messageID == id {
found = true
continue
}
updatedMessages = append(updatedMessages, message)
}
if !found {
return fmt.Errorf("message with ID %s not found", id)
}
return fake_db.WriteDB("messages", updatedMessages)
return nil
}

View File

@@ -65,7 +65,8 @@ func EditMessage(w http.ResponseWriter, r *http.Request) {
}
message.Body = body
message.Edited = time.Now()
editedTime := time.Now()
message.Edited = &editedTime
err = dbUpdateMessage(message)
if err != nil {
@@ -85,7 +86,7 @@ func DeleteMessage(w http.ResponseWriter, r *http.Request) {
render.Render(w, r, ErrNotFound)
return
}
dbDeleteMessage(message.ID)
dbDeleteMessage(message.ID.String())
if err := render.Render(w, r, NewMessageResponse(message)); err != nil {
render.Render(w, r, ErrRender(err))
return
@@ -104,8 +105,8 @@ func ListMessages(w http.ResponseWriter, r *http.Request) {
}
}
func newMessageID() string {
return "msg_" + uuid.New().String()
func newMessageID() uuid.UUID {
return uuid.New()
}
func NewMessage(w http.ResponseWriter, r *http.Request) {
@@ -135,7 +136,6 @@ func NewMessage(w http.ResponseWriter, r *http.Request) {
UserID: user.ID,
Body: body,
Timestamp: time.Now(),
Edited: time.Time{},
}
err = dbAddMessage(&msg)
@@ -150,11 +150,11 @@ func NewMessage(w http.ResponseWriter, r *http.Request) {
type messageKey struct{}
type Message struct {
ID string `json:"id"`
UserID string `json:"user_id"`
Body string `json:"body"`
Timestamp time.Time `json:"timestamp"`
Edited time.Time `json:"edited"`
ID uuid.UUID `json:"id"`
UserID uuid.UUID `json:"user_id"`
Body string `json:"body"`
Timestamp time.Time `json:"timestamp"`
Edited *time.Time `json:"edited"`
}
type MessageRequest struct {
@@ -175,8 +175,8 @@ type MessageResponse struct {
func (m MessageResponse) MarshalJSON() ([]byte, error) {
type OrderedMessageResponse struct {
ID string `json:"id"`
UserID string `json:"user_id"`
ID uuid.UUID `json:"id"`
UserID uuid.UUID `json:"user_id"`
Body string `json:"body"`
Timestamp string `json:"timestamp"`
Edited *string `json:"edited,omitempty"` // Use a pointer to allow null values
@@ -185,7 +185,7 @@ func (m MessageResponse) MarshalJSON() ([]byte, error) {
}
var edited *string
if !m.Message.Edited.IsZero() { // Check if Edited is not the zero value
if m.Message.Edited != nil { // Check if Edited is not the zero value
editedStr := m.Message.Edited.Format(time.RFC3339)
edited = &editedStr
}

View File

@@ -10,7 +10,7 @@ func NewMessageResponse(message *Message) *MessageResponse {
resp := &MessageResponse{Message: message}
if resp.User == nil {
if user, _ := dbGetUser(resp.UserID); user != nil {
if user, _ := dbGetUser(resp.UserID.String()); user != nil {
resp.User = NewUserPayloadResponse(user)
}
}

View File

@@ -89,8 +89,8 @@ func ListUsers(w http.ResponseWriter, r *http.Request) {
}
}
func newUserID() string {
return "user_" + uuid.New().String()
func newUserID() uuid.UUID {
return uuid.New()
}
func NewUser(w http.ResponseWriter, r *http.Request) {
@@ -130,9 +130,9 @@ func NewUser(w http.ResponseWriter, r *http.Request) {
type userKey struct{}
type User struct {
ID string `json:"id"`
Name string `json:"name"`
Password string `json:"-"`
ID uuid.UUID `json:"id"`
Name string `json:"name"`
Password string `json:"-"`
}
type UserPayload struct {

28
db/scylla.go Normal file
View File

@@ -0,0 +1,28 @@
package db
import (
"log"
"github.com/gocql/gocql"
)
var Session *gocql.Session
func InitScyllaDB() {
cluster := gocql.NewCluster("127.0.0.1") // Replace with your ScyllaDB cluster IPs
cluster.Keyspace = "chatservice" // Replace with your keyspace
cluster.Consistency = gocql.Quorum
session, err := cluster.CreateSession()
if err != nil {
log.Fatalf("Failed to connect to ScyllaDB: %v", err)
}
Session = session
log.Println("Connected to ScyllaDB")
}
func CloseScyllaDB() {
if Session != nil {
Session.Close()
}
}

View File

@@ -1,73 +0,0 @@
package fake_db
import (
"encoding/json"
"fmt"
"io"
"os"
)
func ExecDB(db_name string) map[string]interface{} {
var result map[string]interface{}
if db_name == "users" {
users_db, err := os.Open("./test_data/users.json")
if err != nil {
fmt.Println(err)
return nil
}
fmt.Println("Successfully opened Users DB")
defer users_db.Close()
byteValue, _ := io.ReadAll(users_db)
var users []interface{}
json.Unmarshal(byteValue, &users)
result = map[string]interface{}{"users": users}
} else if db_name == "messages" {
messages_db, err := os.Open("./test_data/messages.json")
if err != nil {
fmt.Println(err)
return nil
}
fmt.Println("Successfully opened Messages DB")
defer messages_db.Close()
byteValue, _ := io.ReadAll(messages_db)
var messages []interface{}
json.Unmarshal(byteValue, &messages)
result = map[string]interface{}{"messages": messages}
} else {
fmt.Println("Invalid DB name")
return nil
}
return result
}
func WriteDB(db_name string, data interface{}) error {
var filePath string
switch db_name {
case "users":
filePath = "./test_data/users.json"
case "messages":
filePath = "./test_data/messages.json"
default:
return fmt.Errorf("invalid database name: %s", db_name)
}
jsonData, err := json.MarshalIndent(data, "", " ")
if err != nil {
return fmt.Errorf("error marshaling data to JSON: %v", err)
}
err = os.WriteFile(filePath, jsonData, 0644)
if err != nil {
return fmt.Errorf("error writing to file: %v", err)
}
fmt.Printf("Successfully wrote to %s DB\n", db_name)
return nil
}

View File

@@ -1,7 +1,9 @@
{
description = "Unnamed Chat Server API";
inputs.nixpkgs.url = "nixpkgs/nixos-unstable";
inputs = {
nixpkgs.url = "nixpkgs/nixos-unstable";
};
outputs = { self, nixpkgs }:
let

9
go.mod
View File

@@ -8,10 +8,19 @@ require (
github.com/go-chi/chi/v5 v5.2.0
github.com/go-chi/docgen v1.3.0
github.com/go-chi/render v1.0.3
github.com/gocql/gocql v1.7.0
github.com/google/uuid v1.6.0
)
require (
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed // indirect
github.com/klauspost/compress v1.17.9 // indirect
gopkg.in/inf.v0 v0.9.1 // indirect
)
require (
github.com/ajg/form v1.5.1 // indirect
golang.org/x/crypto v0.36.0
)
replace github.com/gocql/gocql => github.com/scylladb/gocql v1.14.5

35
go.sum
View File

@@ -1,5 +1,9 @@
github.com/ajg/form v1.5.1 h1:t9c7v8JUKu/XxOGBU0yjNpaMloxGEJhUkqFRq0ibGeU=
github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY=
github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932/go.mod h1:NOuUCSz6Q9T7+igc/hlvDOUdtWKryOrtFyIVABv/p7k=
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-chi/chi/v5 v5.0.1/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8=
github.com/go-chi/chi/v5 v5.2.0 h1:Aj1EtB0qR2Rdo2dG4O94RIU35w2lvQSj6BRA4+qwFL0=
github.com/go-chi/chi/v5 v5.2.0/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8=
@@ -8,9 +12,40 @@ github.com/go-chi/docgen v1.3.0/go.mod h1:G9W0G551cs2BFMSn/cnGwX+JBHEloAgo17MBhy
github.com/go-chi/render v1.0.1/go.mod h1:pq4Rr7HbnsdaeHagklXub+p6Wd16Af5l9koip1OvJns=
github.com/go-chi/render v1.0.3 h1:AsXqd2a1/INaIfUSKq3G5uA8weYx20FOsM7uSoCyyt4=
github.com/go-chi/render v1.0.3/go.mod h1:/gr3hVkmYR0YlEy3LxCuVRFzEu9Ruok+gFqbIofjao0=
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed h1:5upAirOpQc1Q53c0bnx2ufif5kANL7bfZWcc6VJWJd8=
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed/go.mod h1:tMWxXQ9wFIaZeTI9F+hmhFiGpFmhOHzyShyFUhRm0H4=
github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA=
github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/scylladb/gocql v1.14.5 h1:lyJKf0m/Vate+8MGiVeRhQNpLVVsL21gvp89zEZdltI=
github.com/scylladb/gocql v1.14.5/go.mod h1:1efi3H0Gr72WCR0W+i+d63FmwmJhDL/zfAC0gMJHVlM=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
golang.org/x/net v0.0.0-20220526153639-5463443f8c37/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc=
gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
sigs.k8s.io/yaml v1.3.0/go.mod h1:GeOyir5tyXNByN85N/dRIT9es5UQNerPYEKK56eTBm8=

View File

@@ -1,100 +0,0 @@
[
{
"Body": "hello",
"Edited": null,
"ID": "1",
"Timestamp": "2024-12-25T05:00:40Z",
"UserID": "user_8d7cd2ed-0aa2-4810-a172-42dd58563a54"
},
{
"Body": "world",
"Edited": null,
"ID": "2",
"Timestamp": "2024-12-25T05:00:43Z",
"UserID": "user_63dac6ad-f255-4af8-a057-4b064a982a84"
},
{
"Body": "abababa",
"Edited": null,
"ID": "3",
"Timestamp": "2024-12-25T05:01:20Z",
"UserID": "user_8d7cd2ed-0aa2-4810-a172-42dd58563a54"
},
{
"Body": "bitch",
"Edited": null,
"ID": "4",
"Timestamp": "2024-12-25T05:05:55Z",
"UserID": "user_63dac6ad-f255-4af8-a057-4b064a982a84"
},
{
"Body": "NIBBA",
"Edited": null,
"ID": "5",
"Timestamp": "2025-03-24T14:48:28.249221047-04:00",
"UserID": "user_8d7cd2ed-0aa2-4810-a172-42dd58563a54"
},
{
"Body": "nibby",
"Edited": null,
"ID": "6",
"Timestamp": "2025-03-24T14:49:03.246929039-04:00",
"UserID": "user_8d7cd2ed-0aa2-4810-a172-42dd58563a54"
},
{
"Body": "aaaaababananana",
"Edited": null,
"ID": "msg_60f70a47-3be2-4315-869a-d6f151ec262a",
"Timestamp": "2025-03-24T15:01:07.14371835-04:00",
"UserID": "user_8d7cd2ed-0aa2-4810-a172-42dd58563a54"
},
{
"Body": "ababa abbott",
"Edited": null,
"ID": "msg_94cbc26d-9098-4fa9-bd21-794516c2263d",
"Timestamp": "2025-03-24T20:34:57.198849367-04:00",
"UserID": "user_8d7cd2ed-0aa2-4810-a172-42dd58563a54"
},
{
"Body": "AAAAAA",
"Edited": null,
"ID": "msg_ca8483db-e823-45c4-882c-fe0930610ba9",
"Timestamp": "2025-03-24T21:17:04.350827576-04:00",
"UserID": "user_8d7cd2ed-0aa2-4810-a172-42dd58563a54"
},
{
"Body": "i am a femboiiiii",
"Edited": null,
"ID": "msg_fcdbb48a-4ea5-4fb3-b925-3a15eb7c291c",
"Timestamp": "2025-03-24T21:27:48.565290147-04:00",
"UserID": "user_63dac6ad-f255-4af8-a057-4b064a982a84"
},
{
"Body": "i love soap",
"Edited": "2025-03-27T14:49:14-04:00",
"ID": "msg_59851eb1-2e63-46c1-b496-55566c414e33",
"Timestamp": "2025-03-27T14:40:26-04:00",
"UserID": "user_8d7cd2ed-0aa2-4810-a172-42dd58563a54"
},
{
"Body": "I'd just like to interject for a moment. What you're referring to as Linux, is in fact, GNU/Linux, or as I've recently taken to calling it, GNU plus Linux. Linux is not an operating system unto itself, but rather another free component of a fully functioning GNU system made useful by the GNU corelibs, shell utilities and vital system components comprising a full OS as defined by POSIX. Many computer users run a modified version of the GNU system every day, without realizing it. Through a peculiar turn of events, the version of GNU which is widely used today is often called “Linux,” and many of its users are not aware that it is basically the GNU system, developed by the GNU Project. There really is a Linux, and these people are using it, but it is just a part of the system they use.\n\nLinux is the kernel: the program in the system that allocates the machine's resources to the other programs that you run. The kernel is an essential part of an operating system, but useless by itself; it can only function in the context of a complete operating system. Linux is normally used in combination with the GNU operating system: the whole system is basically GNU with Linux added, or GNU/Linux. All the so-called “Linux” distributions are really distributions of GNU/Linux.",
"Edited": "2025-03-27T20:35:33-04:00",
"ID": "msg_d77f8e0f-5c23-4c10-984f-b07559e7c5ed",
"Timestamp": "2025-03-27T18:56:27-04:00",
"UserID": "user_8d7cd2ed-0aa2-4810-a172-42dd58563a54"
},
{
"Body": "oh \n\n\nok",
"Edited": null,
"ID": "msg_8d0d8e24-2c1d-4337-afdb-06d1a121e486",
"Timestamp": "2025-03-27T18:57:52-04:00",
"UserID": "user_63dac6ad-f255-4af8-a057-4b064a982a84"
},
{
"Body": "we shall ATTACK at the edge of propaganda",
"Edited": null,
"ID": "msg_dc55edfd-e0f7-4923-b686-df90ad4bb108",
"Timestamp": "2025-03-27T19:00:17-04:00",
"UserID": "user_63dac6ad-f255-4af8-a057-4b064a982a84"
}
]

View File

@@ -1,12 +0,0 @@
[
{
"ID": "user_8d7cd2ed-0aa2-4810-a172-42dd58563a54",
"Name": "duby",
"Password": "$2a$10$fYKgHJRgR6hJl9VAAu4HPeeyTbDP3UCxiAxZMMKDL8A0ya0Sdg.pq"
},
{
"ID": "user_63dac6ad-f255-4af8-a057-4b064a982a84",
"Name": "astolfo",
"Password": "$2a$10$ryzbb6l/hkZH6wwtdLdbYew3R1ug4O3tdHi4581WQHui8JKSPFqSu"
}
]

5
vendor/github.com/gocql/gocql/.gitignore generated vendored Normal file
View File

@@ -0,0 +1,5 @@
gocql-fuzz
fuzz-corpus
fuzz-work
gocql.test
.idea

148
vendor/github.com/gocql/gocql/AUTHORS generated vendored Normal file
View File

@@ -0,0 +1,148 @@
# This source file refers to The gocql Authors for copyright purposes.
Christoph Hack <christoph@tux21b.org>
Jonathan Rudenberg <jonathan@titanous.com>
Thorsten von Eicken <tve@rightscale.com>
Matt Robenolt <mattr@disqus.com>
Phillip Couto <phillip.couto@stemstudios.com>
Niklas Korz <korz.niklask@gmail.com>
Nimi Wariboko Jr <nimi@channelmeter.com>
Ghais Issa <ghais.issa@gmail.com>
Sasha Klizhentas <klizhentas@gmail.com>
Konstantin Cherkasov <k.cherkasoff@gmail.com>
Ben Hood <0x6e6562@gmail.com>
Pete Hopkins <phopkins@gmail.com>
Chris Bannister <c.bannister@gmail.com>
Maxim Bublis <b@codemonkey.ru>
Alex Zorin <git@zor.io>
Kasper Middelboe Petersen <me@phant.dk>
Harpreet Sawhney <harpreet.sawhney@gmail.com>
Charlie Andrews <charlieandrews.cwa@gmail.com>
Stanislavs Koikovs <stanislavs.koikovs@gmail.com>
Dan Forest <bonjour@dan.tf>
Miguel Serrano <miguelvps@gmail.com>
Stefan Radomski <gibheer@zero-knowledge.org>
Josh Wright <jshwright@gmail.com>
Jacob Rhoden <jacob.rhoden@gmail.com>
Ben Frye <benfrye@gmail.com>
Fred McCann <fred@sharpnoodles.com>
Dan Simmons <dan@simmons.io>
Muir Manders <muir@retailnext.net>
Sankar P <sankar.curiosity@gmail.com>
Julien Da Silva <julien.dasilva@gmail.com>
Dan Kennedy <daniel@firstcs.co.uk>
Nick Dhupia<nick.dhupia@gmail.com>
Yasuharu Goto <matope.ono@gmail.com>
Jeremy Schlatter <jeremy.schlatter@gmail.com>
Matthias Kadenbach <matthias.kadenbach@gmail.com>
Dean Elbaz <elbaz.dean@gmail.com>
Mike Berman <evencode@gmail.com>
Dmitriy Fedorenko <c0va23@gmail.com>
Zach Marcantel <zmarcantel@gmail.com>
James Maloney <jamessagan@gmail.com>
Ashwin Purohit <purohit@gmail.com>
Dan Kinder <dkinder.is.me@gmail.com>
Oliver Beattie <oliver@obeattie.com>
Justin Corpron <jncorpron@gmail.com>
Miles Delahunty <miles.delahunty@gmail.com>
Zach Badgett <zach.badgett@gmail.com>
Maciek Sakrejda <maciek@heroku.com>
Jeff Mitchell <jeffrey.mitchell@gmail.com>
Baptiste Fontaine <b@ptistefontaine.fr>
Matt Heath <matt@mattheath.com>
Jamie Cuthill <jamie.cuthill@gmail.com>
Adrian Casajus <adriancasajus@gmail.com>
John Weldon <johnweldon4@gmail.com>
Adrien Bustany <adrien@bustany.org>
Andrey Smirnov <smirnov.andrey@gmail.com>
Adam Weiner <adamsweiner@gmail.com>
Daniel Cannon <daniel@danielcannon.co.uk>
Johnny Bergström <johnny@joonix.se>
Adriano Orioli <orioli.adriano@gmail.com>
Claudiu Raveica <claudiu.raveica@gmail.com>
Artem Chernyshev <artem.0xD2@gmail.com>
Ference Fu <fym201@msn.com>
LOVOO <opensource@lovoo.com>
nikandfor <nikandfor@gmail.com>
Anthony Woods <awoods@raintank.io>
Alexander Inozemtsev <alexander.inozemtsev@gmail.com>
Rob McColl <rob@robmccoll.com>; <rmccoll@ionicsecurity.com>
Viktor Tönköl <viktor.toenkoel@motionlogic.de>
Ian Lozinski <ian.lozinski@gmail.com>
Michael Highstead <highstead@gmail.com>
Sarah Brown <esbie.is@gmail.com>
Caleb Doxsey <caleb@datadoghq.com>
Frederic Hemery <frederic.hemery@datadoghq.com>
Pekka Enberg <penberg@scylladb.com>
Mark M <m.mim95@gmail.com>
Bartosz Burclaf <burclaf@gmail.com>
Marcus King <marcusking01@gmail.com>
Andrew de Andrade <andrew@deandrade.com.br>
Robert Nix <robert@nicerobot.org>
Nathan Youngman <git@nathany.com>
Charles Law <charles.law@gmail.com>; <claw@conduce.com>
Nathan Davies <nathanjamesdavies@gmail.com>
Bo Blanton <bo.blanton@gmail.com>
Vincent Rischmann <me@vrischmann.me>
Jesse Claven <jesse.claven@gmail.com>
Derrick Wippler <thrawn01@gmail.com>
Leigh McCulloch <leigh@leighmcculloch.com>
Ron Kuris <swcafe@gmail.com>
Raphael Gavache <raphael.gavache@gmail.com>
Yasser Abdolmaleki <yasser@yasser.ca>
Krishnanand Thommandra <devtkrishna@gmail.com>
Blake Atkinson <me@blakeatkinson.com>
Dharmendra Parsaila <d4dharmu@gmail.com>
Nayef Ghattas <nayef.ghattas@datadoghq.com>
Michał Matczuk <mmatczuk@gmail.com>
Ben Krebsbach <ben.krebsbach@gmail.com>
Vivian Mathews <vivian.mathews.3@gmail.com>
Sascha Steinbiss <satta@debian.org>
Seth Rosenblum <seth.t.rosenblum@gmail.com>
Javier Zunzunegui <javier.zunzunegui.b@gmail.com>
Luke Hines <lukehines@protonmail.com>
Zhixin Wen <john.wenzhixin@hotmail.com>
Chang Liu <changliu.it@gmail.com>
Ingo Oeser <nightlyone@gmail.com>
Luke Hines <lukehines@protonmail.com>
Jacob Greenleaf <jacob@jacobgreenleaf.com>
Alex Lourie <alex@instaclustr.com>; <djay.il@gmail.com>
Marco Cadetg <cadetg@gmail.com>
Karl Matthias <karl@matthias.org>
Thomas Meson <zllak@hycik.org>
Martin Sucha <martin.sucha@kiwi.com>; <git@mm.ms47.eu>
Pavel Buchinchik <p.buchinchik@gmail.com>
Rintaro Okamura <rintaro.okamura@gmail.com>
Ivan Boyarkin <ivan.boyarkin@kiwi.com>; <mr.vanboy@gmail.com>
Yura Sokolov <y.sokolov@joom.com>; <funny.falcon@gmail.com>
Jorge Bay <jorgebg@apache.org>
Dmitriy Kozlov <hummerd@mail.ru>
Alexey Romanovsky <alexus1024+gocql@gmail.com>
Jaume Marhuenda Beltran <jaumemarhuenda@gmail.com>
Piotr Dulikowski <piodul@scylladb.com>
Árni Dagur <arni@dagur.eu>
Tushar Das <tushar.das5@gmail.com>
Maxim Vladimirskiy <horkhe@gmail.com>
Bogdan-Ciprian Rusu <bogdanciprian.rusu@crowdstrike.com>
Yuto Doi <yutodoi.seattle@gmail.com>
Krishna Vadali <tejavadali@gmail.com>
Jens-W. Schicke-Uffmann <drahflow@gmx.de>
Ondrej Polakovič <ondrej.polakovic@kiwi.com>
Sergei Karetnikov <sergei.karetnikov@gmail.com>
Stefan Miklosovic <smiklosovic@apache.org>
Adam Burk <amburk@gmail.com>
Valerii Ponomarov <kiparis.kh@gmail.com>
Neal Turett <neal.turett@datadoghq.com>
Doug Schaapveld <djschaap@gmail.com>
Steven Seidman <steven.seidman@datadoghq.com>
Wojciech Przytuła <wojciech.przytula@scylladb.com>
João Reis <joao.reis@datastax.com>
Lauro Ramos Venancio <lauro.venancio@incognia.com>
Dmitry Kropachev <dmitry.kropachev@gmail.com>
Oliver Boyle <pleasedontspamme4321+gocql@gmail.com>
Jackson Fleming <jackson.fleming@instaclustr.com>
Sylwia Szunejko <sylwia.szunejko@scylladb.com>
Karol Baryła <karol.baryla@scylladb.com>
Marcin Mazurek <marcinek.mazurek@gmail.com>
Moguchev Leonid Alekseevich <lmoguchev@ozon.ru>
Julien Lefevre <julien.lefevr@gmail.com>

78
vendor/github.com/gocql/gocql/CONTRIBUTING.md generated vendored Normal file
View File

@@ -0,0 +1,78 @@
# Contributing to gocql
**TL;DR** - this manifesto sets out the bare minimum requirements for submitting a patch to gocql.
This guide outlines the process of landing patches in gocql and the general approach to maintaining the code base.
## Background
The goal of the gocql project is to provide a stable and robust CQL driver for Go. gocql is a community driven project that is coordinated by a small team of core developers.
## Minimum Requirement Checklist
The following is a check list of requirements that need to be satisfied in order for us to merge your patch:
* You should raise a pull request to gocql/gocql on Github
* The pull request has a title that clearly summarizes the purpose of the patch
* The motivation behind the patch is clearly defined in the pull request summary
* Your name and email have been added to the `AUTHORS` file (for copyright purposes)
* The patch will merge cleanly
* The test coverage does not fall below the critical threshold (currently 64%)
* The merge commit passes the regression test suite on Travis
* `go fmt` has been applied to the submitted code
* Notable changes (i.e. new features or changed behavior, bugfixes) are appropriately documented in CHANGELOG.md, functional changes also in godoc
If there are any requirements that can't be reasonably satisfied, please state this either on the pull request or as part of discussion on the mailing list. Where appropriate, the core team may apply discretion and make an exception to these requirements.
## Beyond The Checklist
In addition to stating the hard requirements, there are a bunch of things that we consider when assessing changes to the library. These soft requirements are helpful pointers of how to get a patch landed quicker and with less fuss.
### General QA Approach
The gocql team needs to consider the ongoing maintainability of the library at all times. Patches that look like they will introduce maintenance issues for the team will not be accepted.
Your patch will get merged quicker if you have decent test cases that provide test coverage for the new behavior you wish to introduce.
Unit tests are good, integration tests are even better. An example of a unit test is `marshal_test.go` - this tests the serialization code in isolation. `cassandra_test.go` is an integration test suite that is executed against every version of Cassandra that gocql supports as part of the CI process on Travis.
That said, the point of writing tests is to provide a safety net to catch regressions, so there is no need to go overboard with tests. Remember that the more tests you write, the more code we will have to maintain. So there's a balance to strike there.
### When It's Too Difficult To Automate Testing
There are legitimate examples of where it is infeasible to write a regression test for a change. Never fear, we will still consider the patch and quite possibly accept the change without a test. The gocql team takes a pragmatic approach to testing. At the end of the day, you could be addressing an issue that is too difficult to reproduce in a test suite, but still occurs in a real production app. In this case, your production app is the test case, and we will have to trust that your change is good.
Examples of pull requests that have been accepted without tests include:
* https://github.com/gocql/gocql/pull/181 - this patch would otherwise require a multi-node cluster to be booted as part of the CI build
* https://github.com/gocql/gocql/pull/179 - this bug can only be reproduced under heavy load in certain circumstances
### Sign Off Procedure
Generally speaking, a pull request can get merged by any one of the core gocql team. If your change is minor, chances are that one team member will just go ahead and merge it there and then. As stated earlier, suitable test coverage will increase the likelihood that a single reviewer will assess and merge your change. If your change has no test coverage, or looks like it may have wider implications for the health and stability of the library, the reviewer may elect to refer the change to another team member to achieve consensus before proceeding. Therefore, the tighter and cleaner your patch is, the quicker it will go through the review process.
### Supported Features
gocql is a low level wire driver for Cassandra CQL. By and large, we would like to keep the functional scope of the library as narrow as possible. We think that gocql should be tight and focused, and we will be naturally skeptical of things that could just as easily be implemented in a higher layer. Inevitably you will come across something that could be implemented in a higher layer, save for a minor change to the core API. In this instance, please strike up a conversation with the gocql team. Chances are we will understand what you are trying to achieve and will try to accommodate this in a maintainable way.
### Longer Term Evolution
There are some long term plans for gocql that have to be taken into account when assessing changes. That said, gocql is ultimately a community driven project and we don't have a massive development budget, so sometimes the long term view might need to be de-prioritized ahead of short term changes.
## Officially Supported Server Versions
Currently, the officially supported versions of the Cassandra server include:
* 1.2.18
* 2.0.9
Chances are that gocql will work with many other versions. If you would like us to support a particular version of Cassandra, please start a conversation about what version you'd like us to consider. We are more likely to accept a new version if you help out by extending the regression suite to cover the new version to be supported.
## The Core Dev Team
The core development team includes:
* tux21b
* phillipCouto
* Zariel
* 0x6e6562

27
vendor/github.com/gocql/gocql/LICENSE generated vendored Normal file
View File

@@ -0,0 +1,27 @@
Copyright (c) 2016, The Gocql authors
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

5
vendor/github.com/gocql/gocql/Makefile generated vendored Normal file
View File

@@ -0,0 +1,5 @@
# Makefile to run the Docker cleanup script
clean-old-temporary-docker-images:
@echo "Running Docker Hub image cleanup script..."
python ci/clean-old-temporary-docker-images.py

242
vendor/github.com/gocql/gocql/README.md generated vendored Normal file
View File

@@ -0,0 +1,242 @@
<div align="center">
![Build Passing](https://github.com/scylladb/gocql/workflows/Build/badge.svg)
[![Read the Fork Driver Docs](https://img.shields.io/badge/Read_the_Docs-pkg_go-blue)](https://pkg.go.dev/github.com/scylladb/gocql#section-documentation)
[![Protocol Specs](https://img.shields.io/badge/Protocol_Specs-ScyllaDB_Docs-blue)](https://github.com/scylladb/scylladb/blob/master/docs/dev/protocol-extensions.md)
</div>
<h1 align="center">
Scylla Shard-Aware Fork of [apache/cassandra-gocql-driver](https://github.com/apache/cassandra-gocql-driver)
</h1>
<img src="./.github/assets/logo.svg" width="200" align="left" />
This is a fork of [apache/cassandra-gocql-driver](https://github.com/apache/cassandra-gocql-driver) package that we created at Scylla.
It contains extensions to tokenAwareHostPolicy supported by the Scylla 2.3 and onwards.
It allows driver to select a connection to a particular shard on a host based on the token.
This eliminates passing data between shards and significantly reduces latency.
There are open pull requests to merge the functionality to the upstream project:
* [gocql/gocql#1210](https://github.com/gocql/gocql/pull/1210)
* [gocql/gocql#1211](https://github.com/gocql/gocql/pull/1211).
It also provides support for shard aware ports, a faster way to connect to all shards, details available in [blogpost](https://www.scylladb.com/2021/04/27/connect-faster-to-scylla-with-a-shard-aware-port/).
---
### Table of Contents
- [1. Sunsetting Model](#1-sunsetting-model)
- [2. Installation](#2-installation)
- [3. Quick Start](#3-quick-start)
- [4. Data Types](#4-data-types)
- [5. Configuration](#5-configuration)
- [5.1 Shard-aware port](#51-shard-aware-port)
- [5.2 Iterator](#52-iterator)
- [6. Contributing](#6-contributing)
## 1. Sunsetting Model
> [!WARNING]
> In general, the gocql team will focus on supporting the current and previous versions of Go. gocql may still work with older versions of Go, but official support for these versions will have been sunset.
## 2. Installation
This is a drop-in replacement to gocql, it reuses the `github.com/gocql/gocql` import path.
Add the following line to your project `go.mod` file.
```mod
replace github.com/gocql/gocql => github.com/scylladb/gocql latest
```
and run
```sh
go mod tidy
```
to evaluate `latest` to a concrete tag.
Your project now uses the Scylla driver fork, make sure you are using the `TokenAwareHostPolicy` to enable the shard-awareness, continue reading for details.
## 3. Quick Start
Spawn a ScyllaDB Instance using Docker Run command:
```sh
docker run --name node1 --network your-network -p "9042:9042" -d scylladb/scylla:6.1.2 \
--overprovisioned 1 \
--smp 1
```
Then, create a new connection using ScyllaDB GoCQL following the example below:
```go
package main
import (
"fmt"
"github.com/gocql/gocql"
)
func main() {
var cluster = gocql.NewCluster("localhost:9042")
var session, err = cluster.CreateSession()
if err != nil {
panic("Failed to connect to cluster")
}
defer session.Close()
var query = session.Query("SELECT * FROM system.clients")
if rows, err := query.Iter().SliceMap(); err == nil {
for _, row := range rows {
fmt.Printf("%v\n", row)
}
} else {
panic("Query error: " + err.Error())
}
}
```
## 4. Data Types
Here's an list of all ScyllaDB Types reflected in the GoCQL environment:
| ScyllaDB Type | Go Type |
| ---------------- | ------------------ |
| `ascii` | `string` |
| `bigint` | `int64` |
| `blob` | `[]byte` |
| `boolean` | `bool` |
| `date` | `time.Time` |
| `decimal` | `inf.Dec` |
| `double` | `float64` |
| `duration` | `gocql.Duration` |
| `float` | `float32` |
| `uuid` | `gocql.UUID` |
| `int` | `int32` |
| `inet` | `string` |
| `list<int>` | `[]int32` |
| `map<int, text>` | `map[int32]string` |
| `set<int>` | `[]int32` |
| `smallint` | `int16` |
| `text` | `string` |
| `time` | `time.Duration` |
| `timestamp` | `time.Time` |
| `timeuuid` | `gocql.UUID` |
| `tinyint` | `int8` |
| `varchar` | `string` |
| `varint` | `int64` |
## 5. Configuration
In order to make shard-awareness work, token aware host selection policy has to be enabled.
Please make sure that the gocql configuration has `PoolConfig.HostSelectionPolicy` properly set like in the example below.
__When working with a Scylla cluster, `PoolConfig.NumConns` option has no effect - the driver opens one connection for each shard and completely ignores this option.__
```go
c := gocql.NewCluster(hosts...)
// Enable token aware host selection policy, if using multi-dc cluster set a local DC.
fallback := gocql.RoundRobinHostPolicy()
if localDC != "" {
fallback = gocql.DCAwareRoundRobinPolicy(localDC)
}
c.PoolConfig.HostSelectionPolicy = gocql.TokenAwareHostPolicy(fallback)
// If using multi-dc cluster use the "local" consistency levels.
if localDC != "" {
c.Consistency = gocql.LocalQuorum
}
// When working with a Scylla cluster the driver always opens one connection per shard, so `NumConns` is ignored.
// c.NumConns = 4
```
### 5.1 Shard-aware port
This version of gocql supports a more robust method of establishing connection for each shard by using _shard aware port_ for native transport.
It greatly reduces time and the number of connections needed to establish a connection per shard in some cases - ex. when many clients connect at once, or when there are non-shard-aware clients connected to the same cluster.
If you are using a custom Dialer and if your nodes expose the shard-aware port, it is highly recommended to update it so that it uses a specific source port when connecting.
* If you are using a custom `net.Dialer`, you can make your dialer honor the source port by wrapping it in a `gocql.ScyllaShardAwareDialer`:
```go
oldDialer := net.Dialer{...}
clusterConfig.Dialer := &gocql.ScyllaShardAwareDialer{oldDialer}
```
* If you are using a custom type implementing `gocql.Dialer`, you can get the source port by using the `gocql.ScyllaGetSourcePort` function.
An example:
```go
func (d *myDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
sourcePort := gocql.ScyllaGetSourcePort(ctx)
localAddr, err := net.ResolveTCPAddr(network, fmt.Sprintf(":%d", sourcePort))
if err != nil {
return nil, err
}
d := &net.Dialer{LocalAddr: localAddr}
return d.DialContext(ctx, network, addr)
}
```
The source port might be already bound by another connection on your system.
In such case, you should return an appropriate error so that the driver can retry with a different port suitable for the shard it tries to connect to.
* If you are using `net.Dialer.DialContext`, this function will return an error in case the source port is unavailable, and you can just return that error from your custom `Dialer`.
* Otherwise, if you detect that the source port is unavailable, you can either return `gocql.ErrScyllaSourcePortAlreadyInUse` or `syscall.EADDRINUSE`.
For this feature to work correctly, you need to make sure the following conditions are met:
* Your cluster nodes are configured to listen on the shard-aware port (`native_shard_aware_transport_port` option),
* Your cluster nodes are not behind a NAT which changes source ports,
* If you have a custom Dialer, it connects from the correct source port (see the guide above).
The feature is designed to gracefully fall back to the using the non-shard-aware port when it detects that some of the above conditions are not met.
The driver will print a warning about misconfigured address translation if it detects it.
Issues with shard-aware port not being reachable are not reported in non-debug mode, because there is no way to detect it without false positives.
If you suspect that this feature is causing you problems, you can completely disable it by setting the `ClusterConfig.DisableShardAwarePort` flag to true.
### 5.2 Iterator
Paging is a way to parse large result sets in smaller chunks.
The driver provides an iterator to simplify this process.
Use `Query.Iter()` to obtain iterator:
```go
iter := session.Query("SELECT id, value FROM my_table WHERE id > 100 AND id < 10000").Iter()
var results []int
var id, value int
for !iter.Scan(&id, &value) {
if id%2 == 0 {
results = append(results, value)
}
}
if err := iter.Close(); err != nil {
// handle error
}
```
In case of range and `ALLOW FILTERING` queries server can send empty responses for some pages.
That is why you should never consider empty response as the end of the result set.
Always check `iter.Scan()` result to know if there are more results, or `Iter.LastPage()` to know if the last page was reached.
## 6. Contributing
If you have any interest to be contributing in this GoCQL Fork, please read the [CONTRIBUTING.md](CONTRIBUTING.md) before initialize any Issue or Pull Request.

26
vendor/github.com/gocql/gocql/address_translators.go generated vendored Normal file
View File

@@ -0,0 +1,26 @@
package gocql
import "net"
// AddressTranslator provides a way to translate node addresses (and ports) that are
// discovered or received as a node event. This can be useful in an ec2 environment,
// for instance, to translate public IPs to private IPs.
type AddressTranslator interface {
// Translate will translate the provided address and/or port to another
// address and/or port. If no translation is possible, Translate will return the
// address and port provided to it.
Translate(addr net.IP, port int) (net.IP, int)
}
type AddressTranslatorFunc func(addr net.IP, port int) (net.IP, int)
func (fn AddressTranslatorFunc) Translate(addr net.IP, port int) (net.IP, int) {
return fn(addr, port)
}
// IdentityTranslator will do nothing but return what it was provided. It is essentially a no-op.
func IdentityTranslator() AddressTranslator {
return AddressTranslatorFunc(func(addr net.IP, port int) (net.IP, int) {
return addr, port
})
}

541
vendor/github.com/gocql/gocql/cluster.go generated vendored Normal file
View File

@@ -0,0 +1,541 @@
// Copyright (c) 2012 The gocql Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gocql
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io/ioutil"
"net"
"sync/atomic"
"time"
)
const defaultDriverName = "ScyllaDB GoCQL Driver"
// PoolConfig configures the connection pool used by the driver, it defaults to
// using a round-robin host selection policy and a round-robin connection selection
// policy for each host.
type PoolConfig struct {
// HostSelectionPolicy sets the policy for selecting which host to use for a
// given query (default: RoundRobinHostPolicy())
// It is not supported to use a single HostSelectionPolicy in multiple sessions
// (even if you close the old session before using in a new session).
HostSelectionPolicy HostSelectionPolicy
}
func (p PoolConfig) buildPool(session *Session) *policyConnPool {
return newPolicyConnPool(session)
}
// ClusterConfig is a struct to configure the default cluster implementation
// of gocql. It has a variety of attributes that can be used to modify the
// behavior to fit the most common use cases. Applications that require a
// different setup must implement their own cluster.
type ClusterConfig struct {
// addresses for the initial connections. It is recommended to use the value set in
// the Cassandra config for broadcast_address or listen_address, an IP address not
// a domain name. This is because events from Cassandra will use the configured IP
// address, which is used to index connected hosts. If the domain name specified
// resolves to more than 1 IP address then the driver may connect multiple times to
// the same host, and will not mark the node being down or up from events.
Hosts []string
// CQL version (default: 3.0.0)
CQLVersion string
// ProtoVersion sets the version of the native protocol to use, this will
// enable features in the driver for specific protocol versions, generally this
// should be set to a known version (2,3,4) for the cluster being connected to.
//
// If it is 0 or unset (the default) then the driver will attempt to discover the
// highest supported protocol for the cluster. In clusters with nodes of different
// versions the protocol selected is not defined (ie, it can be any of the supported in the cluster)
ProtoVersion int
// Timeout limits the time spent on the client side while executing a query.
// Specifically, query or batch execution will return an error if the client does not receive a response
// from the server within the Timeout period.
// Timeout is also used to configure the read timeout on the underlying network connection.
// Client Timeout should always be higher than the request timeouts configured on the server,
// so that retries don't overload the server.
// Timeout has a default value of 11 seconds, which is higher than default server timeout for most query types.
// Timeout is not applied to requests during initial connection setup, see ConnectTimeout.
Timeout time.Duration
// ConnectTimeout limits the time spent during connection setup.
// During initial connection setup, internal queries, AUTH requests will return an error if the client
// does not receive a response within the ConnectTimeout period.
// ConnectTimeout is applied to the connection setup queries independently.
// ConnectTimeout also limits the duration of dialing a new TCP connection
// in case there is no Dialer nor HostDialer configured.
// ConnectTimeout has a default value of 11 seconds.
ConnectTimeout time.Duration
// WriteTimeout limits the time the driver waits to write a request to a network connection.
// WriteTimeout should be lower than or equal to Timeout.
// WriteTimeout defaults to the value of Timeout.
WriteTimeout time.Duration
// Port used when dialing.
// Default: 9042
Port int
// Initial keyspace. Optional.
Keyspace string
// The size of the connection pool for each host.
// The pool filling runs in separate gourutine during the session initialization phase.
// gocql will always try to get 1 connection on each host pool
// during session initialization AND it will attempt
// to fill each pool afterward asynchronously if NumConns > 1.
// Notice: There is no guarantee that pool filling will be finished in the initialization phase.
// Also, it describes a maximum number of connections at the same time.
// Default: 2
NumConns int
// Maximum number of inflight requests allowed per connection.
// Default: 32768 for CQL v3 and newer
// Default: 128 for older CQL versions
MaxRequestsPerConn int
// Default consistency level.
// Default: Quorum
Consistency Consistency
// Compression algorithm.
// Default: nil
Compressor Compressor
// Default: nil
Authenticator Authenticator
WarningsHandlerBuilder WarningHandlerBuilder
// An Authenticator factory. Can be used to create alternative authenticators.
// Default: nil
AuthProvider func(h *HostInfo) (Authenticator, error)
// Default retry policy to use for queries.
// Default: no retries.
RetryPolicy RetryPolicy
// ConvictionPolicy decides whether to mark host as down based on the error and host info.
// Default: SimpleConvictionPolicy
ConvictionPolicy ConvictionPolicy
// Default reconnection policy to use for reconnecting before trying to mark host as down.
ReconnectionPolicy ReconnectionPolicy
// A reconnection policy to use for reconnecting when connecting to the cluster first time.
InitialReconnectionPolicy ReconnectionPolicy
// The keepalive period to use, enabled if > 0 (default: 15 seconds)
// SocketKeepalive is used to set up the default dialer and is ignored if Dialer or HostDialer is provided.
SocketKeepalive time.Duration
// Maximum cache size for prepared statements globally for gocql.
// Default: 1000
MaxPreparedStmts int
// Maximum cache size for query info about statements for each session.
// Default: 1000
MaxRoutingKeyInfo int
// Default page size to use for created sessions.
// Default: 5000
PageSize int
// Consistency for the serial part of queries, values can be either SERIAL or LOCAL_SERIAL.
// Default: unset
SerialConsistency SerialConsistency
// SslOpts configures TLS use when HostDialer is not set.
// SslOpts is ignored if HostDialer is set.
SslOpts *SslOptions
actualSslOpts atomic.Value
// Sends a client side timestamp for all requests which overrides the timestamp at which it arrives at the server.
// Default: true, only enabled for protocol 3 and above.
DefaultTimestamp bool
// The name of the driver that is going to be reported to the server.
// Default: "ScyllaDB GoLang Driver"
DriverName string
// The version of the driver that is going to be reported to the server.
// Defaulted to current library version
DriverVersion string
// PoolConfig configures the underlying connection pool, allowing the
// configuration of host selection and connection selection policies.
PoolConfig PoolConfig
// If not zero, gocql attempt to reconnect known DOWN nodes in every ReconnectInterval.
ReconnectInterval time.Duration
// The maximum amount of time to wait for schema agreement in a cluster after
// receiving a schema change frame. (default: 60s)
MaxWaitSchemaAgreement time.Duration
// HostFilter will filter all incoming events for host, any which don't pass
// the filter will be ignored. If set will take precedence over any options set
// via Discovery
HostFilter HostFilter
// AddressTranslator will translate addresses found on peer discovery and/or
// node change events.
AddressTranslator AddressTranslator
// If IgnorePeerAddr is true and the address in system.peers does not match
// the supplied host by either initial hosts or discovered via events then the
// host will be replaced with the supplied address.
//
// For example if an event comes in with host=10.0.0.1 but when looking up that
// address in system.local or system.peers returns 127.0.0.1, the peer will be
// set to 10.0.0.1 which is what will be used to connect to.
IgnorePeerAddr bool
// If DisableInitialHostLookup then the driver will not attempt to get host info
// from the system.peers table, this will mean that the driver will connect to
// hosts supplied and will not attempt to lookup the hosts information, this will
// mean that data_centre, rack and token information will not be available and as
// such host filtering and token aware query routing will not be available.
DisableInitialHostLookup bool
// Configure events the driver will register for
Events struct {
// disable registering for status events (node up/down)
DisableNodeStatusEvents bool
// disable registering for topology events (node added/removed/moved)
DisableTopologyEvents bool
// disable registering for schema events (keyspace/table/function removed/created/updated)
DisableSchemaEvents bool
}
// DisableSkipMetadata will override the internal result metadata cache so that the driver does not
// send skip_metadata for queries, this means that the result will always contain
// the metadata to parse the rows and will not reuse the metadata from the prepared
// statement.
//
// See https://issues.apache.org/jira/browse/CASSANDRA-10786
// See https://github.com/scylladb/scylladb/issues/20860
//
// Default: true
DisableSkipMetadata bool
// QueryObserver will set the provided query observer on all queries created from this session.
// Use it to collect metrics / stats from queries by providing an implementation of QueryObserver.
QueryObserver QueryObserver
// BatchObserver will set the provided batch observer on all queries created from this session.
// Use it to collect metrics / stats from batch queries by providing an implementation of BatchObserver.
BatchObserver BatchObserver
// ConnectObserver will set the provided connect observer on all queries
// created from this session.
ConnectObserver ConnectObserver
// FrameHeaderObserver will set the provided frame header observer on all frames' headers created from this session.
// Use it to collect metrics / stats from frames by providing an implementation of FrameHeaderObserver.
FrameHeaderObserver FrameHeaderObserver
// StreamObserver will be notified of stream state changes.
// This can be used to track in-flight protocol requests and responses.
StreamObserver StreamObserver
// Default idempotence for queries
DefaultIdempotence bool
// The time to wait for frames before flushing the frames connection to Cassandra.
// Can help reduce syscall overhead by making less calls to write. Set to 0 to
// disable.
//
// (default: 200 microseconds)
WriteCoalesceWaitTime time.Duration
// Dialer will be used to establish all connections created for this Cluster.
// If not provided, a default dialer configured with ConnectTimeout will be used.
// Dialer is ignored if HostDialer is provided.
Dialer Dialer
// HostDialer will be used to establish all connections for this Cluster.
// Unlike Dialer, HostDialer is responsible for setting up the entire connection, including the TLS session.
// To support shard-aware port, HostDialer should implement ShardDialer.
// If not provided, Dialer will be used instead.
HostDialer HostDialer
// DisableShardAwarePort will prevent the driver from connecting to Scylla's shard-aware port,
// even if there are nodes in the cluster that support it.
//
// It is generally recommended to leave this option turned off because gocql can use
// the shard-aware port to make the process of establishing more robust.
// However, if you have a cluster with nodes which expose shard-aware port
// but the port is unreachable due to network configuration issues, you can use
// this option to work around the issue. Set it to true only if you neither can fix
// your network nor disable shard-aware port on your nodes.
DisableShardAwarePort bool
// Logger for this ClusterConfig.
// If not specified, defaults to the global gocql.Logger.
Logger StdLogger
// The timeout for the requests to the schema tables. (default: 60s)
MetadataSchemaRequestTimeout time.Duration
// internal config for testing
disableControlConn bool
disableInit bool
}
type Dialer interface {
DialContext(ctx context.Context, network, addr string) (net.Conn, error)
}
// NewCluster generates a new config for the default cluster implementation.
//
// The supplied hosts are used to initially connect to the cluster then the rest of
// the ring will be automatically discovered. It is recommended to use the value set in
// the Cassandra config for broadcast_address or listen_address, an IP address not
// a domain name. This is because events from Cassandra will use the configured IP
// address, which is used to index connected hosts. If the domain name specified
// resolves to more than 1 IP address then the driver may connect multiple times to
// the same host, and will not mark the node being down or up from events.
func NewCluster(hosts ...string) *ClusterConfig {
cfg := &ClusterConfig{
Hosts: hosts,
CQLVersion: "3.0.0",
Timeout: 11 * time.Second,
ConnectTimeout: 11 * time.Second,
Port: 9042,
NumConns: 2,
Consistency: Quorum,
MaxPreparedStmts: defaultMaxPreparedStmts,
MaxRoutingKeyInfo: 1000,
PageSize: 5000,
DefaultTimestamp: true,
DriverName: defaultDriverName,
DriverVersion: defaultDriverVersion,
MaxWaitSchemaAgreement: 60 * time.Second,
ReconnectInterval: 60 * time.Second,
ConvictionPolicy: &SimpleConvictionPolicy{},
ReconnectionPolicy: &ConstantReconnectionPolicy{MaxRetries: 3, Interval: 1 * time.Second},
InitialReconnectionPolicy: &NoReconnectionPolicy{},
SocketKeepalive: 15 * time.Second,
WriteCoalesceWaitTime: 200 * time.Microsecond,
MetadataSchemaRequestTimeout: 60 * time.Second,
DisableSkipMetadata: true,
WarningsHandlerBuilder: DefaultWarningHandlerBuilder,
}
return cfg
}
func (cfg *ClusterConfig) logger() StdLogger {
if cfg.Logger == nil {
return Logger
}
return cfg.Logger
}
// CreateSession initializes the cluster based on this config and returns a
// session object that can be used to interact with the database.
func (cfg *ClusterConfig) CreateSession() (*Session, error) {
return NewSession(*cfg)
}
func (cfg *ClusterConfig) CreateSessionNonBlocking() (*Session, error) {
return NewSessionNonBlocking(*cfg)
}
// translateAddressPort is a helper method that will use the given AddressTranslator
// if defined, to translate the given address and port into a possibly new address
// and port, If no AddressTranslator or if an error occurs, the given address and
// port will be returned.
func (cfg *ClusterConfig) translateAddressPort(addr net.IP, port int) (net.IP, int) {
if cfg.AddressTranslator == nil || len(addr) == 0 {
return addr, port
}
newAddr, newPort := cfg.AddressTranslator.Translate(addr, port)
if gocqlDebug {
cfg.logger().Printf("gocql: translating address '%v:%d' to '%v:%d'", addr, port, newAddr, newPort)
}
return newAddr, newPort
}
func (cfg *ClusterConfig) filterHost(host *HostInfo) bool {
return !(cfg.HostFilter == nil || cfg.HostFilter.Accept(host))
}
func (cfg *ClusterConfig) ValidateAndInitSSL() error {
if cfg.SslOpts == nil {
return nil
}
actualTLSConfig, err := setupTLSConfig(cfg.SslOpts)
if err != nil {
return fmt.Errorf("failed to initialize ssl configuration: %s", err.Error())
}
cfg.actualSslOpts.Store(actualTLSConfig)
return nil
}
func (cfg *ClusterConfig) getActualTLSConfig() *tls.Config {
val, ok := cfg.actualSslOpts.Load().(*tls.Config)
if !ok {
return nil
}
return val.Clone()
}
func (cfg *ClusterConfig) Validate() error {
if len(cfg.Hosts) == 0 {
return ErrNoHosts
}
if cfg.Authenticator != nil && cfg.AuthProvider != nil {
return errors.New("Can't use both Authenticator and AuthProvider in cluster config.")
}
if cfg.InitialReconnectionPolicy == nil {
return errors.New("InitialReconnectionPolicy is nil")
}
if cfg.InitialReconnectionPolicy.GetMaxRetries() <= 0 {
return errors.New("InitialReconnectionPolicy.GetMaxRetries returns negative number")
}
if cfg.ReconnectionPolicy == nil {
return errors.New("ReconnectionPolicy is nil")
}
if cfg.InitialReconnectionPolicy.GetMaxRetries() <= 0 {
return errors.New("ReconnectionPolicy.GetMaxRetries returns negative number")
}
if cfg.PageSize < 0 {
return errors.New("PageSize should be positive number or zero")
}
if cfg.MaxRoutingKeyInfo < 0 {
return errors.New("MaxRoutingKeyInfo should be positive number or zero")
}
if cfg.MaxPreparedStmts < 0 {
return errors.New("MaxPreparedStmts should be positive number or zero")
}
if cfg.SocketKeepalive < 0 {
return errors.New("SocketKeepalive should be positive time.Duration or zero")
}
if cfg.MaxRequestsPerConn < 0 {
return errors.New("MaxRequestsPerConn should be positive number or zero")
}
if cfg.NumConns < 0 {
return errors.New("NumConns should be positive non-zero number or zero")
}
if cfg.Port <= 0 || cfg.Port > 65535 {
return errors.New("Port should be a valid port number: a number between 1 and 65535")
}
if cfg.WriteTimeout < 0 {
return errors.New("WriteTimeout should be positive time.Duration or zero")
}
if cfg.Timeout < 0 {
return errors.New("Timeout should be positive time.Duration or zero")
}
if cfg.ConnectTimeout < 0 {
return errors.New("ConnectTimeout should be positive time.Duration or zero")
}
if cfg.MetadataSchemaRequestTimeout < 0 {
return errors.New("MetadataSchemaRequestTimeout should be positive time.Duration or zero")
}
if cfg.WriteCoalesceWaitTime < 0 {
return errors.New("WriteCoalesceWaitTime should be positive time.Duration or zero")
}
if cfg.ReconnectInterval < 0 {
return errors.New("ReconnectInterval should be positive time.Duration or zero")
}
if cfg.MaxWaitSchemaAgreement < 0 {
return errors.New("MaxWaitSchemaAgreement should be positive time.Duration or zero")
}
if cfg.ProtoVersion < 0 {
return errors.New("ProtoVersion should be positive number or zero")
}
if !cfg.DisableSkipMetadata {
Logger.Println("warning: enabling skipping metadata can lead to unpredictible results when executing query and altering columns involved in the query.")
}
return cfg.ValidateAndInitSSL()
}
var (
ErrNoHosts = errors.New("no hosts provided")
ErrNoConnectionsStarted = errors.New("no connections were made when creating the session")
ErrHostQueryFailed = errors.New("unable to populate Hosts")
)
func setupTLSConfig(sslOpts *SslOptions) (*tls.Config, error) {
// Config.InsecureSkipVerify | EnableHostVerification | Result
// Config is nil | true | verify host
// Config is nil | false | do not verify host
// false | false | verify host
// true | false | do not verify host
// false | true | verify host
// true | true | verify host
var tlsConfig *tls.Config
if sslOpts.Config == nil {
tlsConfig = &tls.Config{
InsecureSkipVerify: !sslOpts.EnableHostVerification,
}
} else {
// use clone to avoid race.
tlsConfig = sslOpts.Config.Clone()
}
if tlsConfig.InsecureSkipVerify && sslOpts.EnableHostVerification {
tlsConfig.InsecureSkipVerify = false
}
// ca cert is optional
if sslOpts.CaPath != "" {
if tlsConfig.RootCAs == nil {
tlsConfig.RootCAs = x509.NewCertPool()
}
pem, err := ioutil.ReadFile(sslOpts.CaPath)
if err != nil {
return nil, fmt.Errorf("unable to open CA certs: %v", err)
}
if !tlsConfig.RootCAs.AppendCertsFromPEM(pem) {
return nil, errors.New("failed parsing or CA certs")
}
}
if sslOpts.CertPath != "" || sslOpts.KeyPath != "" {
mycert, err := tls.LoadX509KeyPair(sslOpts.CertPath, sslOpts.KeyPath)
if err != nil {
return nil, fmt.Errorf("unable to load X509 key pair: %v", err)
}
tlsConfig.Certificates = append(tlsConfig.Certificates, mycert)
}
return tlsConfig, nil
}

29
vendor/github.com/gocql/gocql/compressor.go generated vendored Normal file
View File

@@ -0,0 +1,29 @@
package gocql
import (
"github.com/klauspost/compress/s2"
)
type Compressor interface {
Name() string
Encode(data []byte) ([]byte, error)
Decode(data []byte) ([]byte, error)
}
// SnappyCompressor implements the Compressor interface and can be used to
// compress incoming and outgoing frames. It uses S2 compression algorithm
// that is compatible with snappy and aims for high throughput, which is why
// it features concurrent compression for bigger payloads.
type SnappyCompressor struct{}
func (s SnappyCompressor) Name() string {
return "snappy"
}
func (s SnappyCompressor) Encode(data []byte) ([]byte, error) {
return s2.EncodeSnappy(nil, data), nil
}
func (s SnappyCompressor) Decode(data []byte) ([]byte, error) {
return s2.Decode(nil, data)
}

1964
vendor/github.com/gocql/gocql/conn.go generated vendored Normal file

File diff suppressed because it is too large Load Diff

562
vendor/github.com/gocql/gocql/connectionpool.go generated vendored Normal file
View File

@@ -0,0 +1,562 @@
// Copyright (c) 2012 The gocql Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gocql
import (
"fmt"
"math/rand"
"net"
"sync"
"time"
"github.com/gocql/gocql/debounce"
)
// interface to implement to receive the host information
type SetHosts interface {
SetHosts(hosts []*HostInfo)
}
// interface to implement to receive the partitioner value
type SetPartitioner interface {
SetPartitioner(partitioner string)
}
// interface to implement to receive the tablets value
type SetTablets interface {
SetTablets(tablets TabletInfoList)
}
type policyConnPool struct {
session *Session
port int
numConns int
keyspace string
mu sync.RWMutex
hostConnPools map[string]*hostConnPool
}
func connConfig(cfg *ClusterConfig) (*ConnConfig, error) {
hostDialer := cfg.HostDialer
if hostDialer == nil {
dialer := cfg.Dialer
if dialer == nil {
d := net.Dialer{
Timeout: cfg.ConnectTimeout,
}
if cfg.SocketKeepalive > 0 {
d.KeepAlive = cfg.SocketKeepalive
}
dialer = &ScyllaShardAwareDialer{d}
}
hostDialer = &scyllaDialer{
dialer: dialer,
logger: cfg.logger(),
tlsConfig: cfg.getActualTLSConfig(),
cfg: cfg,
}
}
return &ConnConfig{
ProtoVersion: cfg.ProtoVersion,
CQLVersion: cfg.CQLVersion,
Timeout: cfg.Timeout,
WriteTimeout: cfg.WriteTimeout,
ConnectTimeout: cfg.ConnectTimeout,
Dialer: cfg.Dialer,
HostDialer: hostDialer,
Compressor: cfg.Compressor,
Authenticator: cfg.Authenticator,
AuthProvider: cfg.AuthProvider,
Keepalive: cfg.SocketKeepalive,
Logger: cfg.logger(),
tlsConfig: cfg.getActualTLSConfig(),
}, nil
}
func newPolicyConnPool(session *Session) *policyConnPool {
// create the pool
pool := &policyConnPool{
session: session,
port: session.cfg.Port,
numConns: session.cfg.NumConns,
keyspace: session.cfg.Keyspace,
hostConnPools: map[string]*hostConnPool{},
}
return pool
}
func (p *policyConnPool) SetHosts(hosts []*HostInfo) {
p.mu.Lock()
defer p.mu.Unlock()
toRemove := make(map[string]struct{})
for hostID := range p.hostConnPools {
toRemove[hostID] = struct{}{}
}
pools := make(chan *hostConnPool)
createCount := 0
for _, host := range hosts {
if !host.IsUp() {
// don't create a connection pool for a down host
continue
}
hostID := host.HostID()
if _, exists := p.hostConnPools[hostID]; exists {
// still have this host, so don't remove it
delete(toRemove, hostID)
continue
}
createCount++
go func(host *HostInfo) {
// create a connection pool for the host
pools <- newHostConnPool(
p.session,
host,
p.port,
p.numConns,
p.keyspace,
)
}(host)
}
// add created pools
for createCount > 0 {
pool := <-pools
createCount--
if pool.Size() > 0 {
// add pool only if there a connections available
p.hostConnPools[pool.host.HostID()] = pool
}
}
for addr := range toRemove {
pool := p.hostConnPools[addr]
delete(p.hostConnPools, addr)
go pool.Close()
}
}
func (p *policyConnPool) InFlight() int {
p.mu.RLock()
count := 0
for _, pool := range p.hostConnPools {
count += pool.InFlight()
}
p.mu.RUnlock()
return count
}
func (p *policyConnPool) Size() int {
p.mu.RLock()
count := 0
for _, pool := range p.hostConnPools {
count += pool.Size()
}
p.mu.RUnlock()
return count
}
func (p *policyConnPool) getPool(host *HostInfo) (pool *hostConnPool, ok bool) {
hostID := host.HostID()
p.mu.RLock()
pool, ok = p.hostConnPools[hostID]
p.mu.RUnlock()
return
}
func (p *policyConnPool) Close() {
p.mu.Lock()
defer p.mu.Unlock()
// close the pools
for addr, pool := range p.hostConnPools {
delete(p.hostConnPools, addr)
pool.Close()
}
}
func (p *policyConnPool) addHost(host *HostInfo) {
hostID := host.HostID()
p.mu.Lock()
pool, ok := p.hostConnPools[hostID]
if !ok {
pool = newHostConnPool(
p.session,
host,
host.Port(), // TODO: if port == 0 use pool.port?
p.numConns,
p.keyspace,
)
p.hostConnPools[hostID] = pool
}
p.mu.Unlock()
pool.fill_debounce()
}
func (p *policyConnPool) removeHost(hostID string) {
p.mu.Lock()
pool, ok := p.hostConnPools[hostID]
if !ok {
p.mu.Unlock()
return
}
delete(p.hostConnPools, hostID)
p.mu.Unlock()
go pool.Close()
}
// hostConnPool is a connection pool for a single host.
// Connection selection is based on a provided ConnSelectionPolicy
type hostConnPool struct {
session *Session
host *HostInfo
size int
keyspace string
// protection for connPicker, closed, filling
mu sync.RWMutex
connPicker ConnPicker
closed bool
filling bool
debouncer *debounce.SimpleDebouncer
logger StdLogger
}
func (h *hostConnPool) String() string {
h.mu.RLock()
defer h.mu.RUnlock()
size, _ := h.connPicker.Size()
return fmt.Sprintf("[filling=%v closed=%v conns=%v size=%v host=%v]",
h.filling, h.closed, size, h.size, h.host)
}
func newHostConnPool(session *Session, host *HostInfo, port, size int,
keyspace string) *hostConnPool {
pool := &hostConnPool{
session: session,
host: host,
size: size,
keyspace: keyspace,
connPicker: nopConnPicker{},
filling: false,
closed: false,
logger: session.logger,
debouncer: debounce.NewSimpleDebouncer(),
}
// the pool is not filled or connected
return pool
}
// Pick a connection from this connection pool for the given query.
func (pool *hostConnPool) Pick(token Token, qry ExecutableQuery) *Conn {
pool.mu.RLock()
defer pool.mu.RUnlock()
if pool.closed {
return nil
}
size, missing := pool.connPicker.Size()
if missing > 0 {
// try to fill the pool
go pool.fill_debounce()
if size == 0 {
return nil
}
}
return pool.connPicker.Pick(token, qry)
}
// Size returns the number of connections currently active in the pool
func (pool *hostConnPool) Size() int {
pool.mu.RLock()
defer pool.mu.RUnlock()
size, _ := pool.connPicker.Size()
return size
}
// Size returns the number of connections currently active in the pool
func (pool *hostConnPool) InFlight() int {
pool.mu.RLock()
defer pool.mu.RUnlock()
size := pool.connPicker.InFlight()
return size
}
// Close the connection pool
func (pool *hostConnPool) Close() {
pool.mu.Lock()
defer pool.mu.Unlock()
if !pool.closed {
pool.connPicker.Close()
}
pool.closed = true
}
// Fill the connection pool
func (pool *hostConnPool) fill() {
pool.mu.RLock()
// avoid filling a closed pool, or concurrent filling
if pool.closed || pool.filling {
pool.mu.RUnlock()
return
}
// determine the filling work to be done
startCount, fillCount := pool.connPicker.Size()
// avoid filling a full (or overfull) pool
if fillCount <= 0 {
pool.mu.RUnlock()
return
}
// switch from read to write lock
pool.mu.RUnlock()
pool.mu.Lock()
startCount, fillCount = pool.connPicker.Size()
if pool.closed || pool.filling || fillCount <= 0 {
// looks like another goroutine already beat this
// goroutine to the filling
pool.mu.Unlock()
return
}
// ok fill the pool
pool.filling = true
// allow others to access the pool while filling
pool.mu.Unlock()
// only this goroutine should make calls to fill/empty the pool at this
// point until after this routine or its subordinates calls
// fillingStopped
// fill only the first connection synchronously
if startCount == 0 {
err := pool.connect()
pool.logConnectErr(err)
if err != nil {
// probably unreachable host
pool.fillingStopped(err)
return
}
// notify the session that this node is connected
go pool.session.handleNodeConnected(pool.host)
// filled one, let's reload it to see if it has changed
pool.mu.RLock()
_, fillCount = pool.connPicker.Size()
pool.mu.RUnlock()
}
// fill the rest of the pool asynchronously
go func() {
err := pool.connectMany(fillCount)
// mark the end of filling
pool.fillingStopped(err)
if err == nil && startCount > 0 {
// notify the session that this node is connected again
go pool.session.handleNodeConnected(pool.host)
}
}()
}
func (pool *hostConnPool) fill_debounce() {
pool.debouncer.Debounce(pool.fill)
}
func (pool *hostConnPool) logConnectErr(err error) {
if opErr, ok := err.(*net.OpError); ok && (opErr.Op == "dial" || opErr.Op == "read") {
// connection refused
// these are typical during a node outage so avoid log spam.
if gocqlDebug {
pool.logger.Printf("unable to dial %q: %v\n", pool.host, err)
}
} else if err != nil {
// unexpected error
pool.logger.Printf("error: failed to connect to %q due to error: %v", pool.host, err)
}
}
// transition back to a not-filling state.
func (pool *hostConnPool) fillingStopped(err error) {
if err != nil {
if gocqlDebug {
pool.logger.Printf("gocql: filling stopped %q: %v\n", pool.host.ConnectAddress(), err)
}
// wait for some time to avoid back-to-back filling
// this provides some time between failed attempts
// to fill the pool for the host to recover
time.Sleep(time.Duration(rand.Int31n(100)+31) * time.Millisecond)
}
pool.mu.Lock()
pool.filling = false
count, _ := pool.connPicker.Size()
host := pool.host
port := pool.host.Port()
pool.mu.Unlock()
// if we errored and the size is now zero, make sure the host is marked as down
// see https://github.com/gocql/gocql/issues/1614
if gocqlDebug {
pool.logger.Printf("gocql: conns of pool after stopped %q: %v\n", host.ConnectAddress(), count)
}
if err != nil && count == 0 {
if pool.session.cfg.ConvictionPolicy.AddFailure(err, host) {
pool.session.handleNodeDown(host.ConnectAddress(), port)
}
}
}
// connectMany creates new connections concurrent.
func (pool *hostConnPool) connectMany(count int) error {
if count == 0 {
return nil
}
var (
wg sync.WaitGroup
mu sync.Mutex
connectErr error
)
wg.Add(count)
for i := 0; i < count; i++ {
go func() {
defer wg.Done()
err := pool.connect()
pool.logConnectErr(err)
if err != nil {
mu.Lock()
connectErr = err
mu.Unlock()
}
}()
}
// wait for all connections are done
wg.Wait()
return connectErr
}
// create a new connection to the host and add it to the pool
func (pool *hostConnPool) connect() (err error) {
pool.mu.Lock()
shardID, nrShards := pool.connPicker.NextShard()
pool.mu.Unlock()
// TODO: provide a more robust connection retry mechanism, we should also
// be able to detect hosts that come up by trying to connect to downed ones.
// try to connect
var conn *Conn
reconnectionPolicy := pool.session.cfg.ReconnectionPolicy
for i := 0; i < reconnectionPolicy.GetMaxRetries(); i++ {
conn, err = pool.session.connectShard(pool.session.ctx, pool.host, pool, shardID, nrShards)
if err == nil {
break
}
if opErr, isOpErr := err.(*net.OpError); isOpErr {
// if the error is not a temporary error (ex: network unreachable) don't
// retry
if !opErr.Temporary() {
break
}
}
if gocqlDebug {
pool.logger.Printf("gocql: connection failed %q: %v, reconnecting with %T\n",
pool.host.ConnectAddress(), err, reconnectionPolicy)
}
time.Sleep(reconnectionPolicy.GetInterval(i))
}
if err != nil {
return err
}
if pool.keyspace != "" {
// set the keyspace
if err = conn.UseKeyspace(pool.keyspace); err != nil {
conn.Close()
return err
}
}
// add the Conn to the pool
pool.mu.Lock()
defer pool.mu.Unlock()
if pool.closed {
conn.Close()
return nil
}
// lazily initialize the connPicker when we know the required type
pool.initConnPicker(conn)
pool.connPicker.Put(conn)
return nil
}
func (pool *hostConnPool) initConnPicker(conn *Conn) {
if _, ok := pool.connPicker.(nopConnPicker); !ok {
return
}
if conn.isScyllaConn() {
pool.connPicker = newScyllaConnPicker(conn)
return
}
pool.connPicker = newDefaultConnPicker(pool.size)
}
// handle any error from a Conn
func (pool *hostConnPool) HandleError(conn *Conn, err error, closed bool) {
if !closed {
// still an open connection, so continue using it
return
}
// TODO: track the number of errors per host and detect when a host is dead,
// then also have something which can detect when a host comes back.
pool.mu.Lock()
defer pool.mu.Unlock()
if pool.closed {
// pool closed
return
}
if gocqlDebug {
pool.logger.Printf("gocql: pool connection error %q: %v\n", conn.addr, err)
}
pool.connPicker.Remove(conn)
go pool.fill_debounce()
}

140
vendor/github.com/gocql/gocql/connpicker.go generated vendored Normal file
View File

@@ -0,0 +1,140 @@
package gocql
import (
"fmt"
"sync"
"sync/atomic"
)
type ConnPicker interface {
Pick(Token, ExecutableQuery) *Conn
Put(*Conn)
Remove(conn *Conn)
InFlight() int
Size() (int, int)
Close()
// NextShard returns the shardID to connect to.
// nrShard specifies how many shards the host has.
// If nrShards is zero, the caller shouldn't use shard-aware port.
NextShard() (shardID, nrShards int)
}
type defaultConnPicker struct {
conns []*Conn
pos uint32
size int
mu sync.RWMutex
}
func newDefaultConnPicker(size int) *defaultConnPicker {
if size <= 0 {
panic(fmt.Sprintf("invalid pool size %d", size))
}
return &defaultConnPicker{
size: size,
}
}
func (p *defaultConnPicker) Remove(conn *Conn) {
p.mu.Lock()
defer p.mu.Unlock()
for i, candidate := range p.conns {
if candidate == conn {
last := len(p.conns) - 1
p.conns[i], p.conns = p.conns[last], p.conns[:last]
break
}
}
}
func (p *defaultConnPicker) Close() {
p.mu.Lock()
defer p.mu.Unlock()
conns := p.conns
p.conns = nil
for _, conn := range conns {
if conn != nil {
conn.Close()
}
}
}
func (p *defaultConnPicker) InFlight() int {
size := len(p.conns)
return size
}
func (p *defaultConnPicker) Size() (int, int) {
size := len(p.conns)
return size, p.size - size
}
func (p *defaultConnPicker) Pick(Token, ExecutableQuery) *Conn {
pos := int(atomic.AddUint32(&p.pos, 1) - 1)
size := len(p.conns)
var (
leastBusyConn *Conn
streamsAvailable int
)
// find the conn which has the most available streams, this is racy
for i := 0; i < size; i++ {
conn := p.conns[(pos+i)%size]
if conn == nil {
continue
}
if streams := conn.AvailableStreams(); streams > streamsAvailable {
leastBusyConn = conn
streamsAvailable = streams
}
}
return leastBusyConn
}
func (p *defaultConnPicker) Put(conn *Conn) {
p.mu.Lock()
p.conns = append(p.conns, conn)
p.mu.Unlock()
}
func (*defaultConnPicker) NextShard() (shardID, nrShards int) {
return 0, 0
}
// nopConnPicker is a no-operation implementation of ConnPicker, it's used when
// hostConnPool is created to allow deferring creation of the actual ConnPicker
// to the point where we have first connection.
type nopConnPicker struct{}
func (nopConnPicker) Pick(Token, ExecutableQuery) *Conn {
return nil
}
func (nopConnPicker) Put(*Conn) {
}
func (nopConnPicker) Remove(conn *Conn) {
}
func (nopConnPicker) InFlight() int {
return 0
}
func (nopConnPicker) Size() (int, int) {
// Return 1 to make hostConnPool to try to establish a connection.
// When first connection is established hostConnPool replaces nopConnPicker
// with a different ConnPicker implementation.
return 0, 1
}
func (nopConnPicker) Close() {
}
func (nopConnPicker) NextShard() (shardID, nrShards int) {
return 0, 0
}

517
vendor/github.com/gocql/gocql/control.go generated vendored Normal file
View File

@@ -0,0 +1,517 @@
package gocql
import (
"context"
crand "crypto/rand"
"errors"
"fmt"
"math/rand"
"net"
"os"
"regexp"
"strconv"
"sync"
"sync/atomic"
"time"
)
var (
randr *rand.Rand
mutRandr sync.Mutex
)
func init() {
b := make([]byte, 4)
if _, err := crand.Read(b); err != nil {
panic(fmt.Sprintf("unable to seed random number generator: %v", err))
}
randr = rand.New(rand.NewSource(int64(readInt(b))))
}
const (
controlConnStarting = 0
controlConnStarted = 1
controlConnClosing = -1
)
type controlConnection interface {
getConn() *connHost
awaitSchemaAgreement() error
query(statement string, values ...interface{}) (iter *Iter)
discoverProtocol(hosts []*HostInfo) (int, error)
connect(hosts []*HostInfo) error
close()
getSession() *Session
}
// Ensure that the atomic variable is aligned to a 64bit boundary
// so that atomic operations can be applied on 32bit architectures.
type controlConn struct {
state int32
reconnecting int32
session *Session
conn atomic.Value
retry RetryPolicy
quit chan struct{}
}
func (c *controlConn) getSession() *Session {
return c.session
}
func createControlConn(session *Session) *controlConn {
control := &controlConn{
session: session,
quit: make(chan struct{}),
retry: &SimpleRetryPolicy{NumRetries: 3},
}
control.conn.Store((*connHost)(nil))
return control
}
func (c *controlConn) heartBeat() {
if !atomic.CompareAndSwapInt32(&c.state, controlConnStarting, controlConnStarted) {
return
}
sleepTime := 1 * time.Second
timer := time.NewTimer(sleepTime)
defer timer.Stop()
for {
timer.Reset(sleepTime)
select {
case <-c.quit:
return
case <-timer.C:
}
resp, err := c.writeFrame(&writeOptionsFrame{})
if err != nil {
goto reconn
}
switch resp.(type) {
case *supportedFrame:
// Everything ok
sleepTime = 30 * time.Second
continue
case error:
goto reconn
default:
panic(fmt.Sprintf("gocql: unknown frame in response to options: %T", resp))
}
reconn:
// try to connect a bit faster
sleepTime = 1 * time.Second
c.reconnect()
continue
}
}
var hostLookupPreferV4 = os.Getenv("GOCQL_HOST_LOOKUP_PREFER_V4") == "true"
func hostInfo(addr string, defaultPort int) ([]*HostInfo, error) {
var port int
host, portStr, err := net.SplitHostPort(addr)
if err != nil {
host = addr
port = defaultPort
} else {
port, err = strconv.Atoi(portStr)
if err != nil {
return nil, err
}
}
var hosts []*HostInfo
// Check if host is a literal IP address
if ip := net.ParseIP(host); ip != nil {
hosts = append(hosts, &HostInfo{hostname: host, connectAddress: ip, port: port})
return hosts, nil
}
// Look up host in DNS
ips, err := LookupIP(host)
if err != nil {
return nil, err
} else if len(ips) == 0 {
return nil, fmt.Errorf("no IP's returned from DNS lookup for %q", addr)
}
// Filter to v4 addresses if any present
if hostLookupPreferV4 {
var preferredIPs []net.IP
for _, v := range ips {
if v4 := v.To4(); v4 != nil {
preferredIPs = append(preferredIPs, v4)
}
}
if len(preferredIPs) != 0 {
ips = preferredIPs
}
}
for _, ip := range ips {
hosts = append(hosts, &HostInfo{hostname: host, connectAddress: ip, port: port})
}
return hosts, nil
}
func shuffleHosts(hosts []*HostInfo) []*HostInfo {
shuffled := make([]*HostInfo, len(hosts))
copy(shuffled, hosts)
mutRandr.Lock()
randr.Shuffle(len(hosts), func(i, j int) {
shuffled[i], shuffled[j] = shuffled[j], shuffled[i]
})
mutRandr.Unlock()
return shuffled
}
// this is going to be version dependant and a nightmare to maintain :(
var protocolSupportRe = regexp.MustCompile(`the lowest supported version is \d+ and the greatest is (\d+)$`)
func parseProtocolFromError(err error) int {
// I really wish this had the actual info in the error frame...
matches := protocolSupportRe.FindAllStringSubmatch(err.Error(), -1)
if len(matches) != 1 || len(matches[0]) != 2 {
if verr, ok := err.(*protocolError); ok {
return int(verr.frame.Header().version.version())
}
return 0
}
max, err := strconv.Atoi(matches[0][1])
if err != nil {
return 0
}
return max
}
func (c *controlConn) discoverProtocol(hosts []*HostInfo) (int, error) {
hosts = shuffleHosts(hosts)
connCfg := *c.session.connCfg
connCfg.ProtoVersion = 4 // TODO: define maxProtocol
handler := connErrorHandlerFn(func(c *Conn, err error, closed bool) {
// we should never get here, but if we do it means we connected to a
// host successfully which means our attempted protocol version worked
if !closed {
c.Close()
}
})
var err error
for _, host := range hosts {
var conn *Conn
conn, err = c.session.dial(c.session.ctx, host, &connCfg, handler)
if conn != nil {
conn.Close()
}
if err == nil {
return connCfg.ProtoVersion, nil
}
if proto := parseProtocolFromError(err); proto > 0 {
return proto, nil
}
}
return 0, err
}
func (c *controlConn) connect(hosts []*HostInfo) error {
if len(hosts) == 0 {
return errors.New("control: no endpoints specified")
}
// shuffle endpoints so not all drivers will connect to the same initial
// node.
hosts = shuffleHosts(hosts)
cfg := *c.session.connCfg
cfg.disableCoalesce = true
var conn *Conn
var err error
for _, host := range hosts {
conn, err = c.session.dial(c.session.ctx, host, &cfg, c)
if err != nil {
c.session.logger.Printf("gocql: unable to dial control conn %v:%v: %v\n", host.ConnectAddress(), host.Port(), err)
continue
}
err = c.setupConn(conn)
if err == nil {
break
}
c.session.logger.Printf("gocql: unable setup control conn %v:%v: %v\n", host.ConnectAddress(), host.Port(), err)
conn.Close()
conn = nil
}
if conn == nil {
return fmt.Errorf("unable to connect to initial hosts: %v", err)
}
// we could fetch the initial ring here and update initial host data. So that
// when we return from here we have a ring topology ready to go.
go c.heartBeat()
return nil
}
type connHost struct {
conn ConnInterface
host *HostInfo
}
func (c *controlConn) setupConn(conn *Conn) error {
// we need up-to-date host info for the filterHost call below
iter := conn.querySystem(context.TODO(), qrySystemLocal)
defaultPort := 9042
if tcpAddr, ok := conn.conn.RemoteAddr().(*net.TCPAddr); ok {
defaultPort = tcpAddr.Port
}
host, err := hostInfoFromIter(iter, conn.host.connectAddress, defaultPort, c.session.cfg.translateAddressPort)
if err != nil {
return err
}
host = c.session.hostSource.addOrUpdate(host)
if c.session.cfg.filterHost(host) {
return fmt.Errorf("host was filtered: %v", host.ConnectAddress())
}
if err := c.registerEvents(conn); err != nil {
return fmt.Errorf("register events: %v", err)
}
ch := &connHost{
conn: conn,
host: host,
}
c.conn.Store(ch)
if c.session.initialized() {
// We connected to control conn, so add the connect the host in pool as well.
// Notify session we can start trying to connect to the node.
// We can't start the fill before the session is initialized, otherwise the fill would interfere
// with the fill called by Session.init. Session.init needs to wait for its fill to finish and that
// would return immediately if we started the fill here.
// TODO(martin-sucha): Trigger pool refill for all hosts, like in reconnectDownedHosts?
go c.session.startPoolFill(host)
}
return nil
}
func (c *controlConn) registerEvents(conn *Conn) error {
var events []string
if !c.session.cfg.Events.DisableTopologyEvents {
events = append(events, "TOPOLOGY_CHANGE")
}
if !c.session.cfg.Events.DisableNodeStatusEvents {
events = append(events, "STATUS_CHANGE")
}
if !c.session.cfg.Events.DisableSchemaEvents {
events = append(events, "SCHEMA_CHANGE")
}
if len(events) == 0 {
return nil
}
framer, err := conn.exec(context.Background(),
&writeRegisterFrame{
events: events,
}, nil)
if err != nil {
return err
}
frame, err := framer.parseFrame()
if err != nil {
return err
} else if _, ok := frame.(*readyFrame); !ok {
return fmt.Errorf("unexpected frame in response to register: got %T: %v\n", frame, frame)
}
return nil
}
func (c *controlConn) reconnect() {
if atomic.LoadInt32(&c.state) == controlConnClosing {
return
}
if !atomic.CompareAndSwapInt32(&c.reconnecting, 0, 1) {
return
}
defer atomic.StoreInt32(&c.reconnecting, 0)
conn, err := c.attemptReconnect()
if conn == nil {
c.session.logger.Printf("gocql: unable to reconnect control connection: %v\n", err)
return
}
err = c.session.refreshRingNow()
if err != nil {
c.session.logger.Printf("gocql: unable to refresh ring: %v\n", err)
}
err = c.session.metadataDescriber.refreshAllSchema()
if err != nil {
c.session.logger.Printf("gocql: unable to refresh the schema: %v\n", err)
}
}
func (c *controlConn) attemptReconnect() (*Conn, error) {
hosts := c.session.hostSource.getHostsList()
hosts = shuffleHosts(hosts)
// keep the old behavior of connecting to the old host first by moving it to
// the front of the slice
ch := c.getConn()
if ch != nil {
for i := range hosts {
if hosts[i].Equal(ch.host) {
hosts[0], hosts[i] = hosts[i], hosts[0]
break
}
}
ch.conn.Close()
}
conn, err := c.attemptReconnectToAnyOfHosts(hosts)
if conn != nil {
return conn, err
}
c.session.logger.Printf("gocql: unable to connect to any ring node: %v\n", err)
c.session.logger.Printf("gocql: control falling back to initial contact points.\n")
// Fallback to initial contact points, as it may be the case that all known initialHosts
// changed their IPs while keeping the same hostname(s).
initialHosts, resolvErr := addrsToHosts(c.session.cfg.Hosts, c.session.cfg.Port, c.session.logger)
if resolvErr != nil {
return nil, fmt.Errorf("resolve contact points' hostnames: %v", resolvErr)
}
return c.attemptReconnectToAnyOfHosts(initialHosts)
}
func (c *controlConn) attemptReconnectToAnyOfHosts(hosts []*HostInfo) (*Conn, error) {
var conn *Conn
var err error
for _, host := range hosts {
conn, err = c.session.connect(c.session.ctx, host, c)
if err != nil {
if c.session.cfg.ConvictionPolicy.AddFailure(err, host) {
c.session.handleNodeDown(host.ConnectAddress(), host.Port())
}
c.session.logger.Printf("gocql: unable to dial control conn %v:%v: %v\n", host.ConnectAddress(), host.Port(), err)
continue
}
err = c.setupConn(conn)
if err == nil {
break
}
c.session.logger.Printf("gocql: unable setup control conn %v:%v: %v\n", host.ConnectAddress(), host.Port(), err)
conn.Close()
conn = nil
}
return conn, err
}
func (c *controlConn) HandleError(conn *Conn, err error, closed bool) {
if !closed {
return
}
oldConn := c.getConn()
// If connection has long gone, and not been attempted for awhile,
// it's possible to have oldConn as nil here (#1297).
if oldConn != nil && oldConn.conn != conn {
return
}
c.reconnect()
}
func (c *controlConn) getConn() *connHost {
return c.conn.Load().(*connHost)
}
func (c *controlConn) writeFrame(w frameBuilder) (frame, error) {
ch := c.getConn()
if ch == nil {
return nil, errNoControl
}
framer, err := ch.conn.exec(context.Background(), w, nil)
if err != nil {
return nil, err
}
return framer.parseFrame()
}
// query will return nil if the connection is closed or nil
func (c *controlConn) query(statement string, values ...interface{}) (iter *Iter) {
q := c.session.Query(statement, values...).Consistency(One).RoutingKey([]byte{}).Trace(nil)
for {
ch := c.getConn()
q.conn = ch.conn.(*Conn)
iter = ch.conn.executeQuery(context.TODO(), q)
if gocqlDebug && iter.err != nil {
c.session.logger.Printf("control: error executing %q: %v\n", statement, iter.err)
}
q.AddAttempts(1, c.getConn().host)
if iter.err == nil || !c.retry.Attempt(q) {
break
}
}
return
}
func (c *controlConn) awaitSchemaAgreement() error {
ch := c.getConn()
return (&Iter{err: ch.conn.awaitSchemaAgreement(context.TODO())}).err
}
func (c *controlConn) close() {
if atomic.CompareAndSwapInt32(&c.state, controlConnStarted, controlConnClosing) {
c.quit <- struct{}{}
}
ch := c.getConn()
if ch != nil {
ch.conn.Close()
}
}
var errNoControl = errors.New("gocql: no control connection available")

11
vendor/github.com/gocql/gocql/cqltypes.go generated vendored Normal file
View File

@@ -0,0 +1,11 @@
// Copyright (c) 2012 The gocql Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gocql
type Duration struct {
Months int32
Days int32
Nanoseconds int64
}

View File

@@ -0,0 +1,164 @@
package debounce
import (
"sync"
"time"
)
const (
RingRefreshDebounceTime = 1 * time.Second
)
// debounces requests to call a refresh function (currently used for ring refresh). It also supports triggering a refresh immediately.
type RefreshDebouncer struct {
mu sync.Mutex
stopped bool
broadcaster *errorBroadcaster
interval time.Duration
timer *time.Timer
refreshNowCh chan struct{}
quit chan struct{}
refreshFn func() error
}
func NewRefreshDebouncer(interval time.Duration, refreshFn func() error) *RefreshDebouncer {
d := &RefreshDebouncer{
stopped: false,
broadcaster: nil,
refreshNowCh: make(chan struct{}, 1),
quit: make(chan struct{}),
interval: interval,
timer: time.NewTimer(interval),
refreshFn: refreshFn,
}
d.timer.Stop()
go d.flusher()
return d
}
// debounces a request to call the refresh function
func (d *RefreshDebouncer) Debounce() {
d.mu.Lock()
defer d.mu.Unlock()
if d.stopped {
return
}
d.timer.Reset(d.interval)
}
// requests an immediate refresh which will cancel pending refresh requests
func (d *RefreshDebouncer) RefreshNow() <-chan error {
d.mu.Lock()
defer d.mu.Unlock()
if d.broadcaster == nil {
d.broadcaster = newErrorBroadcaster()
select {
case d.refreshNowCh <- struct{}{}:
default:
// already a refresh pending
}
}
return d.broadcaster.newListener()
}
func (d *RefreshDebouncer) flusher() {
for {
select {
case <-d.refreshNowCh:
case <-d.timer.C:
case <-d.quit:
}
d.mu.Lock()
if d.stopped {
if d.broadcaster != nil {
d.broadcaster.stop()
d.broadcaster = nil
}
d.timer.Stop()
d.mu.Unlock()
return
}
// make sure both request channels are cleared before we refresh
select {
case <-d.refreshNowCh:
default:
}
d.timer.Stop()
select {
case <-d.timer.C:
default:
}
curBroadcaster := d.broadcaster
d.broadcaster = nil
d.mu.Unlock()
err := d.refreshFn()
if curBroadcaster != nil {
curBroadcaster.broadcast(err)
}
}
}
func (d *RefreshDebouncer) Stop() {
d.mu.Lock()
if d.stopped {
d.mu.Unlock()
return
}
d.stopped = true
d.mu.Unlock()
d.quit <- struct{}{} // sync with flusher
close(d.quit)
}
// broadcasts an error to multiple channels (listeners)
type errorBroadcaster struct {
listeners []chan<- error
mu sync.Mutex
}
func newErrorBroadcaster() *errorBroadcaster {
return &errorBroadcaster{
listeners: nil,
mu: sync.Mutex{},
}
}
func (b *errorBroadcaster) newListener() <-chan error {
ch := make(chan error, 1)
b.mu.Lock()
defer b.mu.Unlock()
b.listeners = append(b.listeners, ch)
return ch
}
func (b *errorBroadcaster) broadcast(err error) {
b.mu.Lock()
defer b.mu.Unlock()
curListeners := b.listeners
if len(curListeners) > 0 {
b.listeners = nil
} else {
return
}
for _, listener := range curListeners {
listener <- err
close(listener)
}
}
func (b *errorBroadcaster) stop() {
b.mu.Lock()
defer b.mu.Unlock()
if len(b.listeners) == 0 {
return
}
for _, listener := range b.listeners {
close(listener)
}
b.listeners = nil
}

View File

@@ -0,0 +1,34 @@
package debounce
import (
"sync"
"sync/atomic"
)
// SimpleDebouncer is are tool for queuing immutable functions calls. It provides:
// 1. Blocking simultaneous calls
// 2. If there is no running call and no waiting call, then the current call go through
// 3. If there is running call and no waiting call, then the current call go waiting
// 4. If there is running call and waiting call, then the current call are voided
type SimpleDebouncer struct {
m sync.Mutex
count atomic.Int32
}
// NewSimpleDebouncer creates a new SimpleDebouncer.
func NewSimpleDebouncer() *SimpleDebouncer {
return &SimpleDebouncer{}
}
// Debounce attempts to execute the function if the logic of the SimpleDebouncer allows it.
func (d *SimpleDebouncer) Debounce(fn func()) bool {
if d.count.Add(1) > 2 {
d.count.Add(-1)
return false
}
d.m.Lock()
fn()
d.count.Add(-1)
d.m.Unlock()
return true
}

6
vendor/github.com/gocql/gocql/debug_off.go generated vendored Normal file
View File

@@ -0,0 +1,6 @@
//go:build !gocql_debug
// +build !gocql_debug
package gocql
const gocqlDebug = false

6
vendor/github.com/gocql/gocql/debug_on.go generated vendored Normal file
View File

@@ -0,0 +1,6 @@
//go:build gocql_debug
// +build gocql_debug
package gocql
const gocqlDebug = true

91
vendor/github.com/gocql/gocql/dial.go generated vendored Normal file
View File

@@ -0,0 +1,91 @@
package gocql
import (
"context"
"crypto/tls"
"fmt"
"net"
"strings"
)
// HostDialer allows customizing connection to cluster nodes.
type HostDialer interface {
// DialHost establishes a connection to the host.
// The returned connection must be directly usable for CQL protocol,
// specifically DialHost is responsible also for setting up the TLS session if needed.
// DialHost should disable write coalescing if the returned net.Conn does not support writev.
// As of Go 1.18, only plain TCP connections support writev, TLS sessions should disable coalescing.
// You can use WrapTLS helper function if you don't need to override the TLS setup.
DialHost(ctx context.Context, host *HostInfo) (*DialedHost, error)
}
// DialedHost contains information about established connection to a host.
type DialedHost struct {
// Conn used to communicate with the server.
Conn net.Conn
// DisableCoalesce disables write coalescing for the Conn.
// If true, the effect is the same as if WriteCoalesceWaitTime was configured to 0.
DisableCoalesce bool
}
// defaultHostDialer dials host in a default way.
type defaultHostDialer struct {
dialer Dialer
tlsConfig *tls.Config
}
func (hd *defaultHostDialer) DialHost(ctx context.Context, host *HostInfo) (*DialedHost, error) {
ip := host.ConnectAddress()
port := host.Port()
if !validIpAddr(ip) {
return nil, fmt.Errorf("host missing connect ip address: %v", ip)
} else if port == 0 {
return nil, fmt.Errorf("host missing port: %v", port)
}
connAddr := host.ConnectAddressAndPort()
conn, err := hd.dialer.DialContext(ctx, "tcp", connAddr)
if err != nil {
return nil, err
}
addr := host.HostnameAndPort()
return WrapTLS(ctx, conn, addr, hd.tlsConfig)
}
func tlsConfigForAddr(tlsConfig *tls.Config, addr string) *tls.Config {
// the TLS config is safe to be reused by connections but it must not
// be modified after being used.
if !tlsConfig.InsecureSkipVerify && tlsConfig.ServerName == "" {
colonPos := strings.LastIndex(addr, ":")
if colonPos == -1 {
colonPos = len(addr)
}
hostname := addr[:colonPos]
// clone config to avoid modifying the shared one.
tlsConfig = tlsConfig.Clone()
tlsConfig.ServerName = hostname
}
return tlsConfig
}
// WrapTLS optionally wraps a net.Conn connected to addr with the given tlsConfig.
// If the tlsConfig is nil, conn is not wrapped into a TLS session, so is insecure.
// If the tlsConfig does not have server name set, it is updated based on the default gocql rules.
func WrapTLS(ctx context.Context, conn net.Conn, addr string, tlsConfig *tls.Config) (*DialedHost, error) {
if tlsConfig != nil {
tlsConfig := tlsConfigForAddr(tlsConfig, addr)
tconn := tls.Client(conn, tlsConfig)
if err := tconn.HandshakeContext(ctx); err != nil {
conn.Close()
return nil, err
}
conn = tconn
}
return &DialedHost{
Conn: conn,
DisableCoalesce: tlsConfig != nil, // write coalescing can't use writev when the connection is wrapped.
}, nil
}

375
vendor/github.com/gocql/gocql/doc.go generated vendored Normal file
View File

@@ -0,0 +1,375 @@
// Copyright (c) 2012-2015 The gocql Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package gocql implements a fast and robust Cassandra driver for the
// Go programming language.
//
// # Connecting to the cluster
//
// Pass a list of initial node IP addresses to NewCluster to create a new cluster configuration:
//
// cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3")
//
// Port can be specified as part of the address, the above is equivalent to:
//
// cluster := gocql.NewCluster("192.168.1.1:9042", "192.168.1.2:9042", "192.168.1.3:9042")
//
// It is recommended to use the value set in the Cassandra config for broadcast_address or listen_address,
// an IP address not a domain name. This is because events from Cassandra will use the configured IP
// address, which is used to index connected hosts. If the domain name specified resolves to more than 1 IP address
// then the driver may connect multiple times to the same host, and will not mark the node being down or up from events.
//
// Then you can customize more options (see ClusterConfig):
//
// cluster.Keyspace = "example"
// cluster.Consistency = gocql.Quorum
// cluster.ProtoVersion = 4
//
// The driver tries to automatically detect the protocol version to use if not set, but you might want to set the
// protocol version explicitly, as it's not defined which version will be used in certain situations (for example
// during upgrade of the cluster when some of the nodes support different set of protocol versions than other nodes).
//
// The driver advertises the module name and version in the STARTUP message, so servers are able to detect the version.
// If you use replace directive in go.mod, the driver will send information about the replacement module instead.
//
// When ready, create a session from the configuration. Don't forget to Close the session once you are done with it:
//
// session, err := cluster.CreateSession()
// if err != nil {
// return err
// }
// defer session.Close()
//
// # Authentication
//
// CQL protocol uses a SASL-based authentication mechanism and so consists of an exchange of server challenges and
// client response pairs. The details of the exchanged messages depend on the authenticator used.
//
// To use authentication, set ClusterConfig.Authenticator or ClusterConfig.AuthProvider.
//
// PasswordAuthenticator is provided to use for username/password authentication:
//
// cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3")
// cluster.Authenticator = gocql.PasswordAuthenticator{
// Username: "user",
// Password: "password"
// }
// session, err := cluster.CreateSession()
// if err != nil {
// return err
// }
// defer session.Close()
//
// By default, PasswordAuthenticator will attempt to authenticate regardless of what implementation the server returns
// in its AUTHENTICATE message as its authenticator, (e.g. org.apache.cassandra.auth.PasswordAuthenticator). If you
// wish to restrict this you may use PasswordAuthenticator.AllowedAuthenticators:
//
// cluster.Authenticator = gocql.PasswordAuthenticator {
// Username: "user",
// Password: "password"
// AllowedAuthenticators: []string{"org.apache.cassandra.auth.PasswordAuthenticator"},
// }
//
// # Transport layer security
//
// It is possible to secure traffic between the client and server with TLS.
//
// To use TLS, set the ClusterConfig.SslOpts field. SslOptions embeds *tls.Config so you can set that directly.
// There are also helpers to load keys/certificates from files.
//
// Warning: Due to historical reasons, the SslOptions is insecure by default, so you need to set EnableHostVerification
// to true if no Config is set. Most users should set SslOptions.Config to a *tls.Config.
// SslOptions and Config.InsecureSkipVerify interact as follows:
//
// Config.InsecureSkipVerify | EnableHostVerification | Result
// Config is nil | false | do not verify host
// Config is nil | true | verify host
// false | false | verify host
// true | false | do not verify host
// false | true | verify host
// true | true | verify host
//
// For example:
//
// cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3")
// cluster.SslOpts = &gocql.SslOptions{
// EnableHostVerification: true,
// }
// session, err := cluster.CreateSession()
// if err != nil {
// return err
// }
// defer session.Close()
//
// # Data-center awareness and query routing
//
// To route queries to local DC first, use DCAwareRoundRobinPolicy. For example, if the datacenter you
// want to primarily connect is called dc1 (as configured in the database):
//
// cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3")
// cluster.PoolConfig.HostSelectionPolicy = gocql.DCAwareRoundRobinPolicy("dc1")
//
// The driver can route queries to nodes that hold data replicas based on partition key (preferring local DC).
//
// cluster := gocql.NewCluster("192.168.1.1", "192.168.1.2", "192.168.1.3")
// cluster.PoolConfig.HostSelectionPolicy = gocql.TokenAwareHostPolicy(gocql.DCAwareRoundRobinPolicy("dc1"))
//
// Note that TokenAwareHostPolicy can take options such as gocql.ShuffleReplicas and gocql.NonLocalReplicasFallback.
//
// We recommend running with a token aware host policy in production for maximum performance.
//
// The driver can only use token-aware routing for queries where all partition key columns are query parameters.
// For example, instead of
//
// session.Query("select value from mytable where pk1 = 'abc' AND pk2 = ?", "def")
//
// use
//
// session.Query("select value from mytable where pk1 = ? AND pk2 = ?", "abc", "def")
//
// # Rack-level awareness
//
// The DCAwareRoundRobinPolicy can be replaced with RackAwareRoundRobinPolicy, which takes two parameters, datacenter and rack.
//
// Instead of dividing hosts with two tiers (local datacenter and remote datacenters) it divides hosts into three
// (the local rack, the rest of the local datacenter, and everything else).
//
// RackAwareRoundRobinPolicy can be combined with TokenAwareHostPolicy in the same way as DCAwareRoundRobinPolicy.
//
// # Executing queries
//
// Create queries with Session.Query. Query values must not be reused between different executions and must not be
// modified after starting execution of the query.
//
// To execute a query without reading results, use Query.Exec:
//
// err := session.Query(`INSERT INTO tweet (timeline, id, text) VALUES (?, ?, ?)`,
// "me", gocql.TimeUUID(), "hello world").WithContext(ctx).Exec()
//
// Single row can be read by calling Query.Scan:
//
// err := session.Query(`SELECT id, text FROM tweet WHERE timeline = ? LIMIT 1`,
// "me").WithContext(ctx).Consistency(gocql.One).Scan(&id, &text)
//
// Multiple rows can be read using Iter.Scanner:
//
// scanner := session.Query(`SELECT id, text FROM tweet WHERE timeline = ?`,
// "me").WithContext(ctx).Iter().Scanner()
// for scanner.Next() {
// var (
// id gocql.UUID
// text string
// )
// err = scanner.Scan(&id, &text)
// if err != nil {
// log.Fatal(err)
// }
// fmt.Println("Tweet:", id, text)
// }
// // scanner.Err() closes the iterator, so scanner nor iter should be used afterwards.
// if err := scanner.Err(); err != nil {
// log.Fatal(err)
// }
//
// See Example for complete example.
//
// # Prepared statements
//
// The driver automatically prepares DML queries (SELECT/INSERT/UPDATE/DELETE/BATCH statements) and maintains a cache
// of prepared statements.
// CQL protocol does not support preparing other query types.
//
// When using CQL protocol >= 4, it is possible to use gocql.UnsetValue as the bound value of a column.
// This will cause the database to ignore writing the column.
// The main advantage is the ability to keep the same prepared statement even when you don't
// want to update some fields, where before you needed to make another prepared statement.
//
// # Executing multiple queries concurrently
//
// Session is safe to use from multiple goroutines, so to execute multiple concurrent queries, just execute them
// from several worker goroutines. Gocql provides synchronously-looking API (as recommended for Go APIs) and the queries
// are executed asynchronously at the protocol level.
//
// results := make(chan error, 2)
// go func() {
// results <- session.Query(`INSERT INTO tweet (timeline, id, text) VALUES (?, ?, ?)`,
// "me", gocql.TimeUUID(), "hello world 1").Exec()
// }()
// go func() {
// results <- session.Query(`INSERT INTO tweet (timeline, id, text) VALUES (?, ?, ?)`,
// "me", gocql.TimeUUID(), "hello world 2").Exec()
// }()
//
// # Nulls
//
// Null values are are unmarshalled as zero value of the type. If you need to distinguish for example between text
// column being null and empty string, you can unmarshal into *string variable instead of string.
//
// var text *string
// err := scanner.Scan(&text)
// if err != nil {
// // handle error
// }
// if text != nil {
// // not null
// }
// else {
// // null
// }
//
// See Example_nulls for full example.
//
// # Reusing slices
//
// The driver reuses backing memory of slices when unmarshalling. This is an optimization so that a buffer does not
// need to be allocated for every processed row. However, you need to be careful when storing the slices to other
// memory structures.
//
// scanner := session.Query(`SELECT myints FROM table WHERE pk = ?`, "key").WithContext(ctx).Iter().Scanner()
// var myInts []int
// for scanner.Next() {
// // This scan reuses backing store of myInts for each row.
// err = scanner.Scan(&myInts)
// if err != nil {
// log.Fatal(err)
// }
// }
//
// When you want to save the data for later use, pass a new slice every time. A common pattern is to declare the
// slice variable within the scanner loop:
//
// scanner := session.Query(`SELECT myints FROM table WHERE pk = ?`, "key").WithContext(ctx).Iter().Scanner()
// for scanner.Next() {
// var myInts []int
// // This scan always gets pointer to fresh myInts slice, so does not reuse memory.
// err = scanner.Scan(&myInts)
// if err != nil {
// log.Fatal(err)
// }
// }
//
// # Paging
//
// The driver supports paging of results with automatic prefetch, see ClusterConfig.PageSize, Session.SetPrefetch,
// Query.PageSize, and Query.Prefetch.
//
// It is also possible to control the paging manually with Query.PageState (this disables automatic prefetch).
// Manual paging is useful if you want to store the page state externally, for example in a URL to allow users
// browse pages in a result. You might want to sign/encrypt the paging state when exposing it externally since
// it contains data from primary keys.
//
// Paging state is specific to the CQL protocol version and the exact query used. It is meant as opaque state that
// should not be modified. If you send paging state from different query or protocol version, then the behaviour
// is not defined (you might get unexpected results or an error from the server). For example, do not send paging state
// returned by node using protocol version 3 to a node using protocol version 4. Also, when using protocol version 4,
// paging state between Cassandra 2.2 and 3.0 is incompatible (https://issues.apache.org/jira/browse/CASSANDRA-10880).
//
// The driver does not check whether the paging state is from the same protocol version/statement.
// You might want to validate yourself as this could be a problem if you store paging state externally.
// For example, if you store paging state in a URL, the URLs might become broken when you upgrade your cluster.
//
// Call Query.PageState(nil) to fetch just the first page of the query results. Pass the page state returned by
// Iter.PageState to Query.PageState of a subsequent query to get the next page. If the length of slice returned
// by Iter.PageState is zero, there are no more pages available (or an error occurred).
//
// Using too low values of PageSize will negatively affect performance, a value below 100 is probably too low.
// While Cassandra returns exactly PageSize items (except for last page) in a page currently, the protocol authors
// explicitly reserved the right to return smaller or larger amount of items in a page for performance reasons, so don't
// rely on the page having the exact count of items.
//
// See Example_paging for an example of manual paging.
//
// # Dynamic list of columns
//
// There are certain situations when you don't know the list of columns in advance, mainly when the query is supplied
// by the user. Iter.Columns, Iter.RowData, Iter.MapScan and Iter.SliceMap can be used to handle this case.
//
// See Example_dynamicColumns.
//
// # Batches
//
// The CQL protocol supports sending batches of DML statements (INSERT/UPDATE/DELETE) and so does gocql.
// Use Session.Batch to create a new batch and then fill-in details of individual queries.
// Then execute the batch with Session.ExecuteBatch.
//
// Logged batches ensure atomicity, either all or none of the operations in the batch will succeed, but they have
// overhead to ensure this property.
// Unlogged batches don't have the overhead of logged batches, but don't guarantee atomicity.
// Updates of counters are handled specially by Cassandra so batches of counter updates have to use CounterBatch type.
// A counter batch can only contain statements to update counters.
//
// For unlogged batches it is recommended to send only single-partition batches (i.e. all statements in the batch should
// involve only a single partition).
// Multi-partition batch needs to be split by the coordinator node and re-sent to
// correct nodes.
// With single-partition batches you can send the batch directly to the node for the partition without incurring the
// additional network hop.
//
// It is also possible to pass entire BEGIN BATCH .. APPLY BATCH statement to Query.Exec.
// There are differences how those are executed.
// BEGIN BATCH statement passed to Query.Exec is prepared as a whole in a single statement.
// Session.ExecuteBatch prepares individual statements in the batch.
// If you have variable-length batches using the same statement, using Session.ExecuteBatch is more efficient.
//
// See Example_batch for an example.
//
// # Lightweight transactions
//
// Query.ScanCAS or Query.MapScanCAS can be used to execute a single-statement lightweight transaction (an
// INSERT/UPDATE .. IF statement) and reading its result. See example for Query.MapScanCAS.
//
// Multiple-statement lightweight transactions can be executed as a logged batch that contains at least one conditional
// statement. All the conditions must return true for the batch to be applied. You can use Session.ExecuteBatchCAS and
// Session.MapExecuteBatchCAS when executing the batch to learn about the result of the LWT. See example for
// Session.MapExecuteBatchCAS.
//
// # Retries and speculative execution
//
// Queries can be marked as idempotent. Marking the query as idempotent tells the driver that the query can be executed
// multiple times without affecting its result. Non-idempotent queries are not eligible for retrying nor speculative
// execution.
//
// Idempotent queries are retried in case of errors based on the configured RetryPolicy.
// If the query is LWT and the configured RetryPolicy additionally implements LWTRetryPolicy
// interface, then the policy will be cast to LWTRetryPolicy and used this way.
//
// Queries can be retried even before they fail by setting a SpeculativeExecutionPolicy. The policy can
// cause the driver to retry on a different node if the query is taking longer than a specified delay even before the
// driver receives an error or timeout from the server. When a query is speculatively executed, the original execution
// is still executing. The two parallel executions of the query race to return a result, the first received result will
// be returned.
//
// # User-defined types
//
// UDTs can be mapped (un)marshaled from/to map[string]interface{} a Go struct (or a type implementing
// UDTUnmarshaler, UDTMarshaler, Unmarshaler or Marshaler interfaces).
//
// For structs, cql tag can be used to specify the CQL field name to be mapped to a struct field:
//
// type MyUDT struct {
// FieldA int32 `cql:"a"`
// FieldB string `cql:"b"`
// }
//
// See Example_userDefinedTypesMap, Example_userDefinedTypesStruct, ExampleUDTMarshaler, ExampleUDTUnmarshaler.
//
// # Metrics and tracing
//
// It is possible to provide observer implementations that could be used to gather metrics:
//
// - QueryObserver for monitoring individual queries.
// - BatchObserver for monitoring batch queries.
// - ConnectObserver for monitoring new connections from the driver to the database.
// - FrameHeaderObserver for monitoring individual protocol frames.
//
// CQL protocol also supports tracing of queries. When enabled, the database will write information about
// internal events that happened during execution of the query. You can use Query.Trace to request tracing and receive
// the session ID that the database used to store the trace information in system_traces.sessions and
// system_traces.events tables. NewTraceWriter returns an implementation of Tracer that writes the events to a writer.
// Gathering trace information might be essential for debugging and optimizing queries, but writing traces has overhead,
// so this feature should not be used on production systems with very high load unless you know what you are doing.
// There is also a new implementation of Tracer - TracerEnhanced, that is intended to be more reliable and convinient to use.
// It has a funcionality to check if trace is ready to be extracted and only actually gets it if requested which makes
// the impact on a performance smaller.
package gocql // import "github.com/gocql/gocql"

90
vendor/github.com/gocql/gocql/docker-compose.yml generated vendored Normal file
View File

@@ -0,0 +1,90 @@
version: "3.7"
services:
node_1:
image: ${SCYLLA_IMAGE}
privileged: true
command: |
--smp 2
--memory 768M
--seeds 192.168.100.11
--overprovisioned 1
--experimental-features udf
--enable-user-defined-functions true
networks:
public:
ipv4_address: 192.168.100.11
volumes:
- /tmp/scylla:/var/lib/scylla/
- type: bind
source: ./testdata/config/scylla.yaml
target: /etc/scylla/scylla.yaml
- type: bind
source: ./testdata/pki/ca.crt
target: /etc/scylla/ca.crt
- type: bind
source: ./testdata/pki/cassandra.crt
target: /etc/scylla/db.crt
- type: bind
source: ./testdata/pki/cassandra.key
target: /etc/scylla/db.key
healthcheck:
test: [ "CMD", "cqlsh", "-e", "select * from system.local" ]
interval: 5s
timeout: 5s
retries: 18
node_2:
image: ${SCYLLA_IMAGE}
command: |
--smp 2
--memory 1G
--seeds 192.168.100.12
networks:
public:
ipv4_address: 192.168.100.12
healthcheck:
test: [ "CMD", "cqlsh", "192.168.100.12", "-e", "select * from system.local" ]
interval: 5s
timeout: 5s
retries: 18
node_3:
image: ${SCYLLA_IMAGE}
command: |
--smp 2
--memory 1G
--seeds 192.168.100.12
networks:
public:
ipv4_address: 192.168.100.13
healthcheck:
test: [ "CMD", "cqlsh", "192.168.100.13", "-e", "select * from system.local" ]
interval: 5s
timeout: 5s
retries: 18
depends_on:
node_2:
condition: service_healthy
node_4:
image: ${SCYLLA_IMAGE}
command: |
--smp 2
--memory 1G
--seeds 192.168.100.12
networks:
public:
ipv4_address: 192.168.100.14
healthcheck:
test: [ "CMD", "cqlsh", "192.168.100.14", "-e", "select * from system.local" ]
interval: 5s
timeout: 5s
retries: 18
depends_on:
node_3:
condition: service_healthy
networks:
public:
driver: bridge
ipam:
driver: default
config:
- subnet: 192.168.100.0/24

227
vendor/github.com/gocql/gocql/errors.go generated vendored Normal file
View File

@@ -0,0 +1,227 @@
package gocql
import "fmt"
// See CQL Binary Protocol v5, section 8 for more details.
// https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec
const (
// ErrCodeServer indicates unexpected error on server-side.
//
// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1246-L1247
ErrCodeServer = 0x0000
// ErrCodeProtocol indicates a protocol violation by some client message.
//
// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1248-L1250
ErrCodeProtocol = 0x000A
// ErrCodeCredentials indicates missing required authentication.
//
// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1251-L1254
ErrCodeCredentials = 0x0100
// ErrCodeUnavailable indicates unavailable error.
//
// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1255-L1265
ErrCodeUnavailable = 0x1000
// ErrCodeOverloaded returned in case of request on overloaded node coordinator.
//
// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1266-L1267
ErrCodeOverloaded = 0x1001
// ErrCodeBootstrapping returned from the coordinator node in bootstrapping phase.
//
// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1268-L1269
ErrCodeBootstrapping = 0x1002
// ErrCodeTruncate indicates truncation exception.
//
// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1270
ErrCodeTruncate = 0x1003
// ErrCodeWriteTimeout returned in case of timeout during the request write.
//
// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1271-L1304
ErrCodeWriteTimeout = 0x1100
// ErrCodeReadTimeout returned in case of timeout during the request read.
//
// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1305-L1321
ErrCodeReadTimeout = 0x1200
// ErrCodeReadFailure indicates request read error which is not covered by ErrCodeReadTimeout.
//
// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1322-L1340
ErrCodeReadFailure = 0x1300
// ErrCodeFunctionFailure indicates an error in user-defined function.
//
// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1341-L1347
ErrCodeFunctionFailure = 0x1400
// ErrCodeWriteFailure indicates request write error which is not covered by ErrCodeWriteTimeout.
//
// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1348-L1385
ErrCodeWriteFailure = 0x1500
// ErrCodeCDCWriteFailure is defined, but not yet documented in CQLv5 protocol.
//
// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1386
ErrCodeCDCWriteFailure = 0x1600
// ErrCodeCASWriteUnknown indicates only partially completed CAS operation.
//
// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1387-L1397
ErrCodeCASWriteUnknown = 0x1700
// ErrCodeSyntax indicates the syntax error in the query.
//
// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1399
ErrCodeSyntax = 0x2000
// ErrCodeUnauthorized indicates access rights violation by user on performed operation.
//
// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1400-L1401
ErrCodeUnauthorized = 0x2100
// ErrCodeInvalid indicates invalid query error which is not covered by ErrCodeSyntax.
//
// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1402
ErrCodeInvalid = 0x2200
// ErrCodeConfig indicates the configuration error.
//
// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1403
ErrCodeConfig = 0x2300
// ErrCodeAlreadyExists is returned for the requests creating the existing keyspace/table.
//
// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1404-L1413
ErrCodeAlreadyExists = 0x2400
// ErrCodeUnprepared returned from the host for prepared statement which is unknown.
//
// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1414-L1417
ErrCodeUnprepared = 0x2500
)
type RequestError interface {
Code() int
Message() string
Error() string
}
type errorFrame struct {
frameHeader
code int
message string
}
func (e errorFrame) Code() int {
return e.code
}
func (e errorFrame) Message() string {
return e.message
}
func (e errorFrame) Error() string {
return e.Message()
}
func (e errorFrame) String() string {
return fmt.Sprintf("[error code=%x message=%q]", e.code, e.message)
}
type RequestErrUnavailable struct {
errorFrame
Consistency Consistency
Required int
Alive int
}
func (e *RequestErrUnavailable) String() string {
return fmt.Sprintf("[request_error_unavailable consistency=%s required=%d alive=%d]", e.Consistency, e.Required, e.Alive)
}
type ErrorMap map[string]uint16
type RequestErrWriteTimeout struct {
errorFrame
Consistency Consistency
Received int
BlockFor int
WriteType string
}
type RequestErrWriteFailure struct {
errorFrame
Consistency Consistency
Received int
BlockFor int
NumFailures int
WriteType string
ErrorMap ErrorMap
}
type RequestErrCDCWriteFailure struct {
errorFrame
}
type RequestErrReadTimeout struct {
errorFrame
Consistency Consistency
Received int
BlockFor int
DataPresent byte
}
type RequestErrAlreadyExists struct {
errorFrame
Keyspace string
Table string
}
type RequestErrUnprepared struct {
errorFrame
StatementId []byte
}
type RequestErrReadFailure struct {
errorFrame
Consistency Consistency
Received int
BlockFor int
NumFailures int
DataPresent bool
ErrorMap ErrorMap
}
type RequestErrFunctionFailure struct {
errorFrame
Keyspace string
Function string
ArgTypes []string
}
// RequestErrCASWriteUnknown is distinct error for ErrCodeCasWriteUnknown.
//
// See https://github.com/apache/cassandra/blob/7337fc0/doc/native_protocol_v5.spec#L1387-L1397
type RequestErrCASWriteUnknown struct {
errorFrame
Consistency Consistency
Received int
BlockFor int
}
type UnknownServerError struct {
errorFrame
}
type OpType uint8
const (
OpTypeRead OpType = 0
OpTypeWrite OpType = 1
)
type RequestErrRateLimitReached struct {
errorFrame
OpType OpType
RejectedByCoordinator bool
}
func (e *RequestErrRateLimitReached) String() string {
var opType string
if e.OpType == OpTypeRead {
opType = "Read"
} else if e.OpType == OpTypeWrite {
opType = "Write"
} else {
opType = "Other"
}
return fmt.Sprintf("[request_error_rate_limit_reached OpType=%s RejectedByCoordinator=%t]", opType, e.RejectedByCoordinator)
}

256
vendor/github.com/gocql/gocql/events.go generated vendored Normal file
View File

@@ -0,0 +1,256 @@
package gocql
import (
"net"
"sync"
"time"
)
type eventDebouncer struct {
name string
timer *time.Timer
mu sync.Mutex
events []frame
callback func([]frame)
quit chan struct{}
logger StdLogger
}
func newEventDebouncer(name string, eventHandler func([]frame), logger StdLogger) *eventDebouncer {
e := &eventDebouncer{
name: name,
quit: make(chan struct{}),
timer: time.NewTimer(eventDebounceTime),
callback: eventHandler,
logger: logger,
}
e.timer.Stop()
go e.flusher()
return e
}
func (e *eventDebouncer) stop() {
e.quit <- struct{}{} // sync with flusher
close(e.quit)
}
func (e *eventDebouncer) flusher() {
for {
select {
case <-e.timer.C:
e.mu.Lock()
e.flush()
e.mu.Unlock()
case <-e.quit:
return
}
}
}
const (
eventBufferSize = 1000
eventDebounceTime = 1 * time.Second
)
// flush must be called with mu locked
func (e *eventDebouncer) flush() {
if len(e.events) == 0 {
return
}
// if the flush interval is faster than the callback then we will end up calling
// the callback multiple times, probably a bad idea. In this case we could drop
// frames?
go e.callback(e.events)
e.events = make([]frame, 0, eventBufferSize)
}
func (e *eventDebouncer) debounce(frame frame) {
e.mu.Lock()
e.timer.Reset(eventDebounceTime)
// TODO: probably need a warning to track if this threshold is too low
if len(e.events) < eventBufferSize {
e.events = append(e.events, frame)
} else {
e.logger.Printf("%s: buffer full, dropping event frame: %s", e.name, frame)
}
e.mu.Unlock()
}
func (s *Session) handleEvent(framer *framer) {
frame, err := framer.parseFrame()
if err != nil {
s.logger.Printf("gocql: unable to parse event frame: %v\n", err)
return
}
if gocqlDebug {
s.logger.Printf("gocql: handling frame: %v\n", frame)
}
switch f := frame.(type) {
case *schemaChangeKeyspace, *schemaChangeFunction,
*schemaChangeTable, *schemaChangeAggregate, *schemaChangeType:
s.schemaEvents.debounce(frame)
case *topologyChangeEventFrame, *statusChangeEventFrame:
s.nodeEvents.debounce(frame)
default:
s.logger.Printf("gocql: invalid event frame (%T): %v\n", f, f)
}
}
func (s *Session) handleSchemaEvent(frames []frame) {
// TODO: debounce events
for _, frame := range frames {
switch f := frame.(type) {
case *schemaChangeKeyspace:
s.metadataDescriber.clearSchema(f.keyspace)
s.handleKeyspaceChange(f.keyspace, f.change)
case *schemaChangeTable:
s.metadataDescriber.clearSchema(f.keyspace)
s.handleTableChange(f.keyspace, f.object, f.change)
case *schemaChangeAggregate:
s.metadataDescriber.clearSchema(f.keyspace)
case *schemaChangeFunction:
s.metadataDescriber.clearSchema(f.keyspace)
case *schemaChangeType:
s.metadataDescriber.clearSchema(f.keyspace)
}
}
}
func (s *Session) handleKeyspaceChange(keyspace, change string) {
s.control.awaitSchemaAgreement()
if change == "DROPPED" || change == "UPDATED" {
s.metadataDescriber.removeTabletsWithKeyspace(keyspace)
}
s.policy.KeyspaceChanged(KeyspaceUpdateEvent{Keyspace: keyspace, Change: change})
}
func (s *Session) handleTableChange(keyspace, table, change string) {
if change == "DROPPED" || change == "UPDATED" {
s.metadataDescriber.removeTabletsWithTable(keyspace, table)
}
}
// handleNodeEvent handles inbound status and topology change events.
//
// Status events are debounced by host IP; only the latest event is processed.
//
// Topology events are debounced by performing a single full topology refresh
// whenever any topology event comes in.
//
// Processing topology change events before status change events ensures
// that a NEW_NODE event is not dropped in favor of a newer UP event (which
// would itself be dropped/ignored, as the node is not yet known).
func (s *Session) handleNodeEvent(frames []frame) {
type nodeEvent struct {
change string
host net.IP
port int
}
topologyEventReceived := false
// status change events
sEvents := make(map[string]*nodeEvent)
for _, frame := range frames {
switch f := frame.(type) {
case *topologyChangeEventFrame:
topologyEventReceived = true
case *statusChangeEventFrame:
event, ok := sEvents[f.host.String()]
if !ok {
event = &nodeEvent{change: f.change, host: f.host, port: f.port}
sEvents[f.host.String()] = event
}
event.change = f.change
}
}
if topologyEventReceived && !s.cfg.Events.DisableTopologyEvents {
s.debounceRingRefresh()
}
for _, f := range sEvents {
if gocqlDebug {
s.logger.Printf("gocql: dispatching status change event: %+v\n", f)
}
// ignore events we received if they were disabled
// see https://github.com/gocql/gocql/issues/1591
switch f.change {
case "UP":
if !s.cfg.Events.DisableNodeStatusEvents {
s.handleNodeUp(f.host, f.port)
}
case "DOWN":
if !s.cfg.Events.DisableNodeStatusEvents {
s.handleNodeDown(f.host, f.port)
}
}
}
}
func (s *Session) handleNodeUp(eventIp net.IP, eventPort int) {
if gocqlDebug {
s.logger.Printf("gocql: Session.handleNodeUp: %s:%d\n", eventIp.String(), eventPort)
}
host, ok := s.hostSource.getHostByIP(eventIp.String())
if !ok {
s.debounceRingRefresh()
return
}
if s.cfg.filterHost(host) {
return
}
if d := host.Version().nodeUpDelay(); d > 0 {
time.Sleep(d)
}
s.startPoolFill(host)
}
func (s *Session) startPoolFill(host *HostInfo) {
// we let the pool call handleNodeConnected to change the host state
s.pool.addHost(host)
s.policy.AddHost(host)
}
func (s *Session) handleNodeConnected(host *HostInfo) {
if gocqlDebug {
s.logger.Printf("gocql: Session.handleNodeConnected: %s:%d\n", host.ConnectAddress(), host.Port())
}
host.setState(NodeUp)
if !s.cfg.filterHost(host) {
s.policy.HostUp(host)
}
}
func (s *Session) handleNodeDown(ip net.IP, port int) {
if gocqlDebug {
s.logger.Printf("gocql: Session.handleNodeDown: %s:%d\n", ip.String(), port)
}
host, ok := s.hostSource.getHostByIP(ip.String())
if ok {
host.setState(NodeDown)
if s.cfg.filterHost(host) {
return
}
s.policy.HostDown(host)
hostID := host.HostID()
s.pool.removeHost(hostID)
}
}

110
vendor/github.com/gocql/gocql/exec.go generated vendored Normal file
View File

@@ -0,0 +1,110 @@
package gocql
import (
"fmt"
)
// SingleHostQueryExecutor allows to quickly execute diagnostic queries while
// connected to only a single node.
// The executor opens only a single connection to a node and does not use
// connection pools.
// Consistency level used is ONE.
// Retry policy is applied, attempts are visible in query metrics but query
// observer is not notified.
type SingleHostQueryExecutor struct {
session *Session
control *controlConn
}
// Exec executes the query without returning any rows.
func (e SingleHostQueryExecutor) Exec(stmt string, values ...interface{}) error {
return e.control.query(stmt, values...).Close()
}
// Iter executes the query and returns an iterator capable of iterating
// over all results.
func (e SingleHostQueryExecutor) Iter(stmt string, values ...interface{}) *Iter {
return e.control.query(stmt, values...)
}
func (e SingleHostQueryExecutor) Close() {
if e.control != nil {
e.control.close()
}
if e.session != nil {
e.session.Close()
}
}
// NewSingleHostQueryExecutor creates a SingleHostQueryExecutor by connecting
// to one of the hosts specified in the ClusterConfig.
// If ProtoVersion is not specified version 4 is used.
// Caller is responsible for closing the executor after use.
func NewSingleHostQueryExecutor(cfg *ClusterConfig) (e SingleHostQueryExecutor, err error) {
// Check that hosts in the ClusterConfig is not empty
if len(cfg.Hosts) < 1 {
err = ErrNoHosts
return
}
c := *cfg
// If protocol version not set assume 4 and skip discovery
if c.ProtoVersion == 0 {
c.ProtoVersion = 4
}
// Close in case of error
defer func() {
if err != nil {
e.Close()
}
}()
// Create uninitialised session
c.disableInit = true
if e.session, err = NewSession(c); err != nil {
err = fmt.Errorf("new session: %w", err)
return
}
var hosts []*HostInfo
if hosts, err = addrsToHosts(c.Hosts, c.Port, c.Logger); err != nil {
err = fmt.Errorf("addrs to hosts: %w", err)
return
}
// Create control connection to one of the hosts
e.control = createControlConn(e.session)
// shuffle endpoints so not all drivers will connect to the same initial
// node.
hosts = shuffleHosts(hosts)
conncfg := *e.control.session.connCfg
conncfg.disableCoalesce = true
var conn *Conn
for _, host := range hosts {
conn, err = e.control.session.dial(e.control.session.ctx, host, &conncfg, e.control)
if err != nil {
e.control.session.logger.Printf("gocql: unable to dial control conn %v:%v: %v\n", host.ConnectAddress(), host.Port(), err)
continue
}
err = e.control.setupConn(conn)
if err == nil {
break
}
e.control.session.logger.Printf("gocql: unable setup control conn %v:%v: %v\n", host.ConnectAddress(), host.Port(), err)
conn.Close()
conn = nil
}
if conn == nil {
err = fmt.Errorf("setup: %w", err)
return
}
return
}

57
vendor/github.com/gocql/gocql/filters.go generated vendored Normal file
View File

@@ -0,0 +1,57 @@
package gocql
import "fmt"
// HostFilter interface is used when a host is discovered via server sent events.
type HostFilter interface {
// Called when a new host is discovered, returning true will cause the host
// to be added to the pools.
Accept(host *HostInfo) bool
}
// HostFilterFunc converts a func(host HostInfo) bool into a HostFilter
type HostFilterFunc func(host *HostInfo) bool
func (fn HostFilterFunc) Accept(host *HostInfo) bool {
return fn(host)
}
// AcceptAllFilter will accept all hosts
func AcceptAllFilter() HostFilter {
return HostFilterFunc(func(host *HostInfo) bool {
return true
})
}
func DenyAllFilter() HostFilter {
return HostFilterFunc(func(host *HostInfo) bool {
return false
})
}
// DataCentreHostFilter filters all hosts such that they are in the same data centre
// as the supplied data centre.
func DataCentreHostFilter(dataCentre string) HostFilter {
return HostFilterFunc(func(host *HostInfo) bool {
return host.DataCenter() == dataCentre
})
}
// WhiteListHostFilter filters incoming hosts by checking that their address is
// in the initial hosts whitelist.
func WhiteListHostFilter(hosts ...string) HostFilter {
hostInfos, err := addrsToHosts(hosts, 9042, nopLogger{})
if err != nil {
// dont want to panic here, but rather not break the API
panic(fmt.Errorf("unable to lookup host info from address: %v", err))
}
m := make(map[string]bool, len(hostInfos))
for _, host := range hostInfos {
m[host.ConnectAddress().String()] = true
}
return HostFilterFunc(func(host *HostInfo) bool {
return m[host.ConnectAddress().String()]
})
}

2119
vendor/github.com/gocql/gocql/frame.go generated vendored Normal file

File diff suppressed because it is too large Load Diff

34
vendor/github.com/gocql/gocql/fuzz.go generated vendored Normal file
View File

@@ -0,0 +1,34 @@
//go:build gofuzz
// +build gofuzz
package gocql
import "bytes"
func Fuzz(data []byte) int {
var bw bytes.Buffer
r := bytes.NewReader(data)
head, err := readHeader(r, make([]byte, 9))
if err != nil {
return 0
}
framer := newFramer(r, &bw, nil, byte(head.version))
err = framer.readFrame(&head)
if err != nil {
return 0
}
frame, err := framer.parseFrame()
if err != nil {
return 0
}
if frame != nil {
return 1
}
return 2
}

448
vendor/github.com/gocql/gocql/helpers.go generated vendored Normal file
View File

@@ -0,0 +1,448 @@
// Copyright (c) 2012 The gocql Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gocql
import (
"fmt"
"math/big"
"net"
"reflect"
"strings"
"time"
"gopkg.in/inf.v0"
)
type RowData struct {
Columns []string
Values []interface{}
}
func goType(t TypeInfo) (reflect.Type, error) {
switch t.Type() {
case TypeVarchar, TypeAscii, TypeInet, TypeText:
return reflect.TypeOf(*new(string)), nil
case TypeBigInt, TypeCounter:
return reflect.TypeOf(*new(int64)), nil
case TypeTime:
return reflect.TypeOf(*new(time.Duration)), nil
case TypeTimestamp:
return reflect.TypeOf(*new(time.Time)), nil
case TypeBlob:
return reflect.TypeOf(*new([]byte)), nil
case TypeBoolean:
return reflect.TypeOf(*new(bool)), nil
case TypeFloat:
return reflect.TypeOf(*new(float32)), nil
case TypeDouble:
return reflect.TypeOf(*new(float64)), nil
case TypeInt:
return reflect.TypeOf(*new(int)), nil
case TypeSmallInt:
return reflect.TypeOf(*new(int16)), nil
case TypeTinyInt:
return reflect.TypeOf(*new(int8)), nil
case TypeDecimal:
return reflect.TypeOf(*new(*inf.Dec)), nil
case TypeUUID, TypeTimeUUID:
return reflect.TypeOf(*new(UUID)), nil
case TypeList, TypeSet:
elemType, err := goType(t.(CollectionType).Elem)
if err != nil {
return nil, err
}
return reflect.SliceOf(elemType), nil
case TypeMap:
keyType, err := goType(t.(CollectionType).Key)
if err != nil {
return nil, err
}
valueType, err := goType(t.(CollectionType).Elem)
if err != nil {
return nil, err
}
return reflect.MapOf(keyType, valueType), nil
case TypeVarint:
return reflect.TypeOf(*new(*big.Int)), nil
case TypeTuple:
// what can we do here? all there is to do is to make a list of interface{}
tuple := t.(TupleTypeInfo)
return reflect.TypeOf(make([]interface{}, len(tuple.Elems))), nil
case TypeUDT:
return reflect.TypeOf(make(map[string]interface{})), nil
case TypeDate:
return reflect.TypeOf(*new(time.Time)), nil
case TypeDuration:
return reflect.TypeOf(*new(Duration)), nil
default:
return nil, fmt.Errorf("cannot create Go type for unknown CQL type %s", t)
}
}
func dereference(i interface{}) interface{} {
return reflect.Indirect(reflect.ValueOf(i)).Interface()
}
func getCassandraBaseType(name string) Type {
switch name {
case "ascii":
return TypeAscii
case "bigint":
return TypeBigInt
case "blob":
return TypeBlob
case "boolean":
return TypeBoolean
case "counter":
return TypeCounter
case "date":
return TypeDate
case "decimal":
return TypeDecimal
case "double":
return TypeDouble
case "duration":
return TypeDuration
case "float":
return TypeFloat
case "int":
return TypeInt
case "smallint":
return TypeSmallInt
case "tinyint":
return TypeTinyInt
case "time":
return TypeTime
case "timestamp":
return TypeTimestamp
case "uuid":
return TypeUUID
case "varchar":
return TypeVarchar
case "text":
return TypeText
case "varint":
return TypeVarint
case "timeuuid":
return TypeTimeUUID
case "inet":
return TypeInet
case "MapType":
return TypeMap
case "ListType":
return TypeList
case "SetType":
return TypeSet
case "TupleType":
return TypeTuple
default:
return TypeCustom
}
}
func getCassandraType(name string, logger StdLogger) TypeInfo {
if strings.HasPrefix(name, "frozen<") {
return getCassandraType(strings.TrimPrefix(name[:len(name)-1], "frozen<"), logger)
} else if strings.HasPrefix(name, "set<") {
return CollectionType{
NativeType: NativeType{typ: TypeSet},
Elem: getCassandraType(strings.TrimPrefix(name[:len(name)-1], "set<"), logger),
}
} else if strings.HasPrefix(name, "list<") {
return CollectionType{
NativeType: NativeType{typ: TypeList},
Elem: getCassandraType(strings.TrimPrefix(name[:len(name)-1], "list<"), logger),
}
} else if strings.HasPrefix(name, "map<") {
names := splitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "map<"))
if len(names) != 2 {
logger.Printf("Error parsing map type, it has %d subelements, expecting 2\n", len(names))
return NativeType{
typ: TypeCustom,
}
}
return CollectionType{
NativeType: NativeType{typ: TypeMap},
Key: getCassandraType(names[0], logger),
Elem: getCassandraType(names[1], logger),
}
} else if strings.HasPrefix(name, "tuple<") {
names := splitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "tuple<"))
types := make([]TypeInfo, len(names))
for i, name := range names {
types[i] = getCassandraType(name, logger)
}
return TupleTypeInfo{
NativeType: NativeType{typ: TypeTuple},
Elems: types,
}
} else {
return NativeType{
typ: getCassandraBaseType(name),
}
}
}
func splitCompositeTypes(name string) []string {
if !strings.Contains(name, "<") {
return strings.Split(name, ", ")
}
var parts []string
lessCount := 0
segment := ""
for _, char := range name {
if char == ',' && lessCount == 0 {
if segment != "" {
parts = append(parts, strings.TrimSpace(segment))
}
segment = ""
continue
}
segment += string(char)
if char == '<' {
lessCount++
} else if char == '>' {
lessCount--
}
}
if segment != "" {
parts = append(parts, strings.TrimSpace(segment))
}
return parts
}
func apacheToCassandraType(t string) string {
t = strings.Replace(t, apacheCassandraTypePrefix, "", -1)
t = strings.Replace(t, "(", "<", -1)
t = strings.Replace(t, ")", ">", -1)
types := strings.FieldsFunc(t, func(r rune) bool {
return r == '<' || r == '>' || r == ','
})
for _, typ := range types {
t = strings.Replace(t, typ, getApacheCassandraType(typ).String(), -1)
}
// This is done so it exactly matches what Cassandra returns
return strings.Replace(t, ",", ", ", -1)
}
func getApacheCassandraType(class string) Type {
switch strings.TrimPrefix(class, apacheCassandraTypePrefix) {
case "AsciiType":
return TypeAscii
case "LongType":
return TypeBigInt
case "BytesType":
return TypeBlob
case "BooleanType":
return TypeBoolean
case "CounterColumnType":
return TypeCounter
case "DecimalType":
return TypeDecimal
case "DoubleType":
return TypeDouble
case "FloatType":
return TypeFloat
case "Int32Type":
return TypeInt
case "ShortType":
return TypeSmallInt
case "ByteType":
return TypeTinyInt
case "TimeType":
return TypeTime
case "DateType", "TimestampType":
return TypeTimestamp
case "UUIDType", "LexicalUUIDType":
return TypeUUID
case "UTF8Type":
return TypeVarchar
case "IntegerType":
return TypeVarint
case "TimeUUIDType":
return TypeTimeUUID
case "InetAddressType":
return TypeInet
case "MapType":
return TypeMap
case "ListType":
return TypeList
case "SetType":
return TypeSet
case "TupleType":
return TypeTuple
case "DurationType":
return TypeDuration
default:
return TypeCustom
}
}
func (r *RowData) rowMap(m map[string]interface{}) {
for i, column := range r.Columns {
val := dereference(r.Values[i])
if valVal := reflect.ValueOf(val); valVal.Kind() == reflect.Slice {
valCopy := reflect.MakeSlice(valVal.Type(), valVal.Len(), valVal.Cap())
reflect.Copy(valCopy, valVal)
m[column] = valCopy.Interface()
} else {
m[column] = val
}
}
}
// TupeColumnName will return the column name of a tuple value in a column named
// c at index n. It should be used if a specific element within a tuple is needed
// to be extracted from a map returned from SliceMap or MapScan.
func TupleColumnName(c string, n int) string {
return fmt.Sprintf("%s[%d]", c, n)
}
func (iter *Iter) RowData() (RowData, error) {
if iter.err != nil {
return RowData{}, iter.err
}
columns := make([]string, 0, len(iter.Columns()))
values := make([]interface{}, 0, len(iter.Columns()))
for _, column := range iter.Columns() {
if c, ok := column.TypeInfo.(TupleTypeInfo); !ok {
val, err := column.TypeInfo.NewWithError()
if err != nil {
return RowData{}, err
}
columns = append(columns, column.Name)
values = append(values, val)
} else {
for i, elem := range c.Elems {
columns = append(columns, TupleColumnName(column.Name, i))
val, err := elem.NewWithError()
if err != nil {
return RowData{}, err
}
values = append(values, val)
}
}
}
rowData := RowData{
Columns: columns,
Values: values,
}
return rowData, nil
}
// TODO(zariel): is it worth exporting this?
func (iter *Iter) rowMap() (map[string]interface{}, error) {
if iter.err != nil {
return nil, iter.err
}
rowData, _ := iter.RowData()
iter.Scan(rowData.Values...)
m := make(map[string]interface{}, len(rowData.Columns))
rowData.rowMap(m)
return m, nil
}
// SliceMap is a helper function to make the API easier to use
// returns the data from the query in the form of []map[string]interface{}
func (iter *Iter) SliceMap() ([]map[string]interface{}, error) {
if iter.err != nil {
return nil, iter.err
}
// Not checking for the error because we just did
rowData, _ := iter.RowData()
dataToReturn := make([]map[string]interface{}, 0)
for iter.Scan(rowData.Values...) {
m := make(map[string]interface{}, len(rowData.Columns))
rowData.rowMap(m)
dataToReturn = append(dataToReturn, m)
}
if iter.err != nil {
return nil, iter.err
}
return dataToReturn, nil
}
// MapScan takes a map[string]interface{} and populates it with a row
// that is returned from cassandra.
//
// Each call to MapScan() must be called with a new map object.
// During the call to MapScan() any pointers in the existing map
// are replaced with non pointer types before the call returns
//
// iter := session.Query(`SELECT * FROM mytable`).Iter()
// for {
// // New map each iteration
// row := make(map[string]interface{})
// if !iter.MapScan(row) {
// break
// }
// // Do things with row
// if fullname, ok := row["fullname"]; ok {
// fmt.Printf("Full Name: %s\n", fullname)
// }
// }
//
// You can also pass pointers in the map before each call
//
// var fullName FullName // Implements gocql.Unmarshaler and gocql.Marshaler interfaces
// var address net.IP
// var age int
// iter := session.Query(`SELECT * FROM scan_map_table`).Iter()
// for {
// // New map each iteration
// row := map[string]interface{}{
// "fullname": &fullName,
// "age": &age,
// "address": &address,
// }
// if !iter.MapScan(row) {
// break
// }
// fmt.Printf("First: %s Age: %d Address: %q\n", fullName.FirstName, age, address)
// }
func (iter *Iter) MapScan(m map[string]interface{}) bool {
if iter.err != nil {
return false
}
// Not checking for the error because we just did
rowData, _ := iter.RowData()
for i, col := range rowData.Columns {
if dest, ok := m[col]; ok {
rowData.Values[i] = dest
}
}
if iter.Scan(rowData.Values...) {
rowData.rowMap(m)
return true
}
return false
}
func copyBytes(p []byte) []byte {
b := make([]byte, len(p))
copy(b, p)
return b
}
var failDNS = false
func LookupIP(host string) ([]net.IP, error) {
if failDNS {
return nil, &net.DNSError{}
}
return net.LookupIP(host)
}

722
vendor/github.com/gocql/gocql/host_source.go generated vendored Normal file
View File

@@ -0,0 +1,722 @@
package gocql
import (
"errors"
"fmt"
"net"
"strconv"
"strings"
"sync"
"time"
)
var ErrCannotFindHost = errors.New("cannot find host")
var ErrHostAlreadyExists = errors.New("host already exists")
type nodeState int32
func (n nodeState) String() string {
if n == NodeUp {
return "UP"
} else if n == NodeDown {
return "DOWN"
}
return fmt.Sprintf("UNKNOWN_%d", n)
}
const (
NodeUp nodeState = iota
NodeDown
)
type cassVersion struct {
Major, Minor, Patch int
}
func (c *cassVersion) Set(v string) error {
if v == "" {
return nil
}
return c.UnmarshalCQL(nil, []byte(v))
}
func (c *cassVersion) UnmarshalCQL(info TypeInfo, data []byte) error {
return c.unmarshal(data)
}
func (c *cassVersion) unmarshal(data []byte) error {
version := strings.TrimSuffix(string(data), "-SNAPSHOT")
version = strings.TrimPrefix(version, "v")
v := strings.Split(version, ".")
if len(v) < 2 {
return fmt.Errorf("invalid version string: %s", data)
}
var err error
c.Major, err = strconv.Atoi(v[0])
if err != nil {
return fmt.Errorf("invalid major version %v: %v", v[0], err)
}
c.Minor, err = strconv.Atoi(v[1])
if err != nil {
return fmt.Errorf("invalid minor version %v: %v", v[1], err)
}
if len(v) > 2 {
c.Patch, err = strconv.Atoi(v[2])
if err != nil {
return fmt.Errorf("invalid patch version %v: %v", v[2], err)
}
}
return nil
}
func (c cassVersion) Before(major, minor, patch int) bool {
// We're comparing us (cassVersion) with the provided version (major, minor, patch)
// We return true if our version is lower (comes before) than the provided one.
if c.Major < major {
return true
} else if c.Major == major {
if c.Minor < minor {
return true
} else if c.Minor == minor && c.Patch < patch {
return true
}
}
return false
}
func (c cassVersion) AtLeast(major, minor, patch int) bool {
return !c.Before(major, minor, patch)
}
func (c cassVersion) String() string {
return fmt.Sprintf("v%d.%d.%d", c.Major, c.Minor, c.Patch)
}
func (c cassVersion) nodeUpDelay() time.Duration {
if c.Major >= 2 && c.Minor >= 2 {
// CASSANDRA-8236
return 0
}
return 10 * time.Second
}
type HostInfo struct {
// TODO(zariel): reduce locking maybe, not all values will change, but to ensure
// that we are thread safe use a mutex to access all fields.
mu sync.RWMutex
hostname string
peer net.IP
broadcastAddress net.IP
listenAddress net.IP
rpcAddress net.IP
preferredIP net.IP
connectAddress net.IP
untranslatedConnectAddress net.IP
port int
dataCenter string
rack string
hostId string
workload string
graph bool
dseVersion string
partitioner string
clusterName string
version cassVersion
state nodeState
schemaVersion string
tokens []string
scyllaShardAwarePort uint16
scyllaShardAwarePortTLS uint16
}
func (h *HostInfo) Equal(host *HostInfo) bool {
if h == host {
// prevent rlock reentry
return true
}
return h.HostID() == host.HostID() && h.ConnectAddressAndPort() == host.ConnectAddressAndPort()
}
func (h *HostInfo) Peer() net.IP {
h.mu.RLock()
defer h.mu.RUnlock()
return h.peer
}
func (h *HostInfo) invalidConnectAddr() bool {
h.mu.RLock()
defer h.mu.RUnlock()
addr, _ := h.connectAddressLocked()
return !validIpAddr(addr)
}
func validIpAddr(addr net.IP) bool {
return addr != nil && !addr.IsUnspecified()
}
func (h *HostInfo) connectAddressLocked() (net.IP, string) {
if validIpAddr(h.connectAddress) {
return h.connectAddress, "connect_address"
} else if validIpAddr(h.rpcAddress) {
return h.rpcAddress, "rpc_adress"
} else if validIpAddr(h.preferredIP) {
// where does perferred_ip get set?
return h.preferredIP, "preferred_ip"
} else if validIpAddr(h.broadcastAddress) {
return h.broadcastAddress, "broadcast_address"
} else if validIpAddr(h.peer) {
return h.peer, "peer"
}
return net.IPv4zero, "invalid"
}
// nodeToNodeAddress returns address broadcasted between node to nodes.
// It's either `broadcast_address` if host info is read from system.local or `peer` if read from system.peers.
// This IP address is also part of CQL Event emitted on topology/status changes,
// but does not uniquely identify the node in case multiple nodes use the same IP address.
func (h *HostInfo) nodeToNodeAddress() net.IP {
h.mu.RLock()
defer h.mu.RUnlock()
if validIpAddr(h.broadcastAddress) {
return h.broadcastAddress
} else if validIpAddr(h.peer) {
return h.peer
}
return net.IPv4zero
}
// Returns the address that should be used to connect to the host.
// If you wish to override this, use an AddressTranslator or
// use a HostFilter to SetConnectAddress()
func (h *HostInfo) ConnectAddress() net.IP {
h.mu.RLock()
defer h.mu.RUnlock()
if addr, _ := h.connectAddressLocked(); validIpAddr(addr) {
return addr
}
panic(fmt.Sprintf("no valid connect address for host: %v. Is your cluster configured correctly?", h))
}
func (h *HostInfo) UntranslatedConnectAddress() net.IP {
h.mu.RLock()
defer h.mu.RUnlock()
if len(h.untranslatedConnectAddress) != 0 {
return h.untranslatedConnectAddress
}
if addr, _ := h.connectAddressLocked(); validIpAddr(addr) {
return addr
}
panic(fmt.Sprintf("no valid connect address for host: %v. Is your cluster configured correctly?", h))
}
func (h *HostInfo) SetConnectAddress(address net.IP) *HostInfo {
// TODO(zariel): should this not be exported?
h.mu.Lock()
defer h.mu.Unlock()
h.connectAddress = address
return h
}
func (h *HostInfo) BroadcastAddress() net.IP {
h.mu.RLock()
defer h.mu.RUnlock()
return h.broadcastAddress
}
func (h *HostInfo) ListenAddress() net.IP {
h.mu.RLock()
defer h.mu.RUnlock()
return h.listenAddress
}
func (h *HostInfo) RPCAddress() net.IP {
h.mu.RLock()
defer h.mu.RUnlock()
return h.rpcAddress
}
func (h *HostInfo) PreferredIP() net.IP {
h.mu.RLock()
defer h.mu.RUnlock()
return h.preferredIP
}
func (h *HostInfo) DataCenter() string {
h.mu.RLock()
dc := h.dataCenter
h.mu.RUnlock()
return dc
}
func (h *HostInfo) Rack() string {
h.mu.RLock()
rack := h.rack
h.mu.RUnlock()
return rack
}
func (h *HostInfo) HostID() string {
h.mu.RLock()
defer h.mu.RUnlock()
return h.hostId
}
func (h *HostInfo) SetHostID(hostID string) {
h.mu.Lock()
defer h.mu.Unlock()
h.hostId = hostID
}
func (h *HostInfo) WorkLoad() string {
h.mu.RLock()
defer h.mu.RUnlock()
return h.workload
}
func (h *HostInfo) Graph() bool {
h.mu.RLock()
defer h.mu.RUnlock()
return h.graph
}
func (h *HostInfo) DSEVersion() string {
h.mu.RLock()
defer h.mu.RUnlock()
return h.dseVersion
}
func (h *HostInfo) Partitioner() string {
h.mu.RLock()
defer h.mu.RUnlock()
return h.partitioner
}
func (h *HostInfo) ClusterName() string {
h.mu.RLock()
defer h.mu.RUnlock()
return h.clusterName
}
func (h *HostInfo) Version() cassVersion {
h.mu.RLock()
defer h.mu.RUnlock()
return h.version
}
func (h *HostInfo) State() nodeState {
h.mu.RLock()
defer h.mu.RUnlock()
return h.state
}
func (h *HostInfo) setState(state nodeState) *HostInfo {
h.mu.Lock()
defer h.mu.Unlock()
h.state = state
return h
}
func (h *HostInfo) Tokens() []string {
h.mu.RLock()
defer h.mu.RUnlock()
return h.tokens
}
func (h *HostInfo) Port() int {
h.mu.RLock()
defer h.mu.RUnlock()
return h.port
}
func (h *HostInfo) update(from *HostInfo) {
if h == from {
return
}
h.mu.Lock()
defer h.mu.Unlock()
from.mu.RLock()
defer from.mu.RUnlock()
// autogenerated do not update
if h.peer == nil {
h.peer = from.peer
}
if h.broadcastAddress == nil {
h.broadcastAddress = from.broadcastAddress
}
if h.listenAddress == nil {
h.listenAddress = from.listenAddress
}
if h.rpcAddress == nil {
h.rpcAddress = from.rpcAddress
}
if h.preferredIP == nil {
h.preferredIP = from.preferredIP
}
if h.connectAddress == nil {
h.connectAddress = from.connectAddress
}
if h.port == 0 {
h.port = from.port
}
if h.dataCenter == "" {
h.dataCenter = from.dataCenter
}
if h.rack == "" {
h.rack = from.rack
}
if h.hostId == "" {
h.hostId = from.hostId
}
if h.workload == "" {
h.workload = from.workload
}
if h.dseVersion == "" {
h.dseVersion = from.dseVersion
}
if h.partitioner == "" {
h.partitioner = from.partitioner
}
if h.clusterName == "" {
h.clusterName = from.clusterName
}
if h.version == (cassVersion{}) {
h.version = from.version
}
if h.tokens == nil {
h.tokens = from.tokens
}
}
func (h *HostInfo) IsUp() bool {
return h != nil && h.State() == NodeUp
}
func (h *HostInfo) IsBusy(s *Session) bool {
pool, ok := s.pool.getPool(h)
return ok && h != nil && pool.InFlight() >= MAX_IN_FLIGHT_THRESHOLD
}
func (h *HostInfo) HostnameAndPort() string {
// Fast path: in most cases hostname is not empty
var (
hostname string
port int
)
h.mu.RLock()
hostname = h.hostname
port = h.port
h.mu.RUnlock()
if hostname != "" {
return net.JoinHostPort(hostname, strconv.Itoa(port))
}
// Slow path: hostname is empty
h.mu.Lock()
defer h.mu.Unlock()
if h.hostname == "" { // recheck is hostname empty
// if yes - fill it
addr, _ := h.connectAddressLocked()
h.hostname = addr.String()
}
return net.JoinHostPort(h.hostname, strconv.Itoa(h.port))
}
func (h *HostInfo) Hostname() string {
// Fast path: in most cases hostname is not empty
var hostname string
h.mu.RLock()
hostname = h.hostname
h.mu.RUnlock()
if hostname != "" {
return hostname
}
// Slow path: hostname is empty
h.mu.Lock()
defer h.mu.Unlock()
if h.hostname == "" {
addr, _ := h.connectAddressLocked()
h.hostname = addr.String()
}
return h.hostname
}
func (h *HostInfo) ConnectAddressAndPort() string {
h.mu.Lock()
defer h.mu.Unlock()
addr, _ := h.connectAddressLocked()
return net.JoinHostPort(addr.String(), strconv.Itoa(h.port))
}
func (h *HostInfo) String() string {
h.mu.RLock()
defer h.mu.RUnlock()
connectAddr, source := h.connectAddressLocked()
return fmt.Sprintf("[HostInfo hostname=%q connectAddress=%q peer=%q rpc_address=%q broadcast_address=%q "+
"preferred_ip=%q connect_addr=%q connect_addr_source=%q "+
"port=%d data_centre=%q rack=%q host_id=%q version=%q state=%s num_tokens=%d]",
h.hostname, h.connectAddress, h.peer, h.rpcAddress, h.broadcastAddress, h.preferredIP,
connectAddr, source,
h.port, h.dataCenter, h.rack, h.hostId, h.version, h.state, len(h.tokens))
}
func (h *HostInfo) setScyllaSupported(s scyllaSupported) {
h.mu.Lock()
defer h.mu.Unlock()
h.scyllaShardAwarePort = s.shardAwarePort
h.scyllaShardAwarePortTLS = s.shardAwarePortSSL
}
// ScyllaShardAwarePort returns the shard aware port of this host.
// Returns zero if the shard aware port is not known.
func (h *HostInfo) ScyllaShardAwarePort() uint16 {
h.mu.RLock()
defer h.mu.RUnlock()
return h.scyllaShardAwarePort
}
// ScyllaShardAwarePortTLS returns the TLS-enabled shard aware port of this host.
// Returns zero if the shard aware port is not known.
func (h *HostInfo) ScyllaShardAwarePortTLS() uint16 {
h.mu.RLock()
defer h.mu.RUnlock()
return h.scyllaShardAwarePortTLS
}
// Returns true if we are using system_schema.keyspaces instead of system.schema_keyspaces
func checkSystemSchema(control controlConnection) (bool, error) {
iter := control.query("SELECT * FROM system_schema.keyspaces" + control.getSession().usingTimeoutClause)
if err := iter.err; err != nil {
if errf, ok := err.(*errorFrame); ok {
if errf.code == ErrCodeSyntax {
return false, nil
}
}
return false, err
}
return true, nil
}
// Given a map that represents a row from either system.local or system.peers
// return as much information as we can in *HostInfo
func hostInfoFromMap(row map[string]interface{}, host *HostInfo, translateAddressPort func(addr net.IP, port int) (net.IP, int)) (*HostInfo, error) {
const assertErrorMsg = "Assertion failed for %s"
var ok bool
// Default to our connected port if the cluster doesn't have port information
for key, value := range row {
switch key {
case "data_center":
host.dataCenter, ok = value.(string)
if !ok {
return nil, fmt.Errorf(assertErrorMsg, "data_center")
}
case "rack":
host.rack, ok = value.(string)
if !ok {
return nil, fmt.Errorf(assertErrorMsg, "rack")
}
case "host_id":
hostId, ok := value.(UUID)
if !ok {
return nil, fmt.Errorf(assertErrorMsg, "host_id")
}
host.hostId = hostId.String()
case "release_version":
version, ok := value.(string)
if !ok {
return nil, fmt.Errorf(assertErrorMsg, "release_version")
}
host.version.Set(version)
case "peer":
ip, ok := value.(string)
if !ok {
return nil, fmt.Errorf(assertErrorMsg, "peer")
}
host.peer = net.ParseIP(ip)
case "cluster_name":
host.clusterName, ok = value.(string)
if !ok {
return nil, fmt.Errorf(assertErrorMsg, "cluster_name")
}
case "partitioner":
host.partitioner, ok = value.(string)
if !ok {
return nil, fmt.Errorf(assertErrorMsg, "partitioner")
}
case "broadcast_address":
ip, ok := value.(string)
if !ok {
return nil, fmt.Errorf(assertErrorMsg, "broadcast_address")
}
host.broadcastAddress = net.ParseIP(ip)
case "preferred_ip":
ip, ok := value.(string)
if !ok {
return nil, fmt.Errorf(assertErrorMsg, "preferred_ip")
}
host.preferredIP = net.ParseIP(ip)
case "rpc_address":
ip, ok := value.(string)
if !ok {
return nil, fmt.Errorf(assertErrorMsg, "rpc_address")
}
host.rpcAddress = net.ParseIP(ip)
case "native_address":
ip, ok := value.(string)
if !ok {
return nil, fmt.Errorf(assertErrorMsg, "native_address")
}
host.rpcAddress = net.ParseIP(ip)
case "listen_address":
ip, ok := value.(string)
if !ok {
return nil, fmt.Errorf(assertErrorMsg, "listen_address")
}
host.listenAddress = net.ParseIP(ip)
case "native_port":
native_port, ok := value.(int)
if !ok {
return nil, fmt.Errorf(assertErrorMsg, "native_port")
}
host.port = native_port
case "workload":
host.workload, ok = value.(string)
if !ok {
return nil, fmt.Errorf(assertErrorMsg, "workload")
}
case "graph":
host.graph, ok = value.(bool)
if !ok {
return nil, fmt.Errorf(assertErrorMsg, "graph")
}
case "tokens":
host.tokens, ok = value.([]string)
if !ok {
return nil, fmt.Errorf(assertErrorMsg, "tokens")
}
case "dse_version":
host.dseVersion, ok = value.(string)
if !ok {
return nil, fmt.Errorf(assertErrorMsg, "dse_version")
}
case "schema_version":
schemaVersion, ok := value.(UUID)
if !ok {
return nil, fmt.Errorf(assertErrorMsg, "schema_version")
}
host.schemaVersion = schemaVersion.String()
}
// TODO(thrawn01): Add 'port'? once CASSANDRA-7544 is complete
// Not sure what the port field will be called until the JIRA issue is complete
}
host.untranslatedConnectAddress = host.ConnectAddress()
ip, port := translateAddressPort(host.untranslatedConnectAddress, host.port)
host.connectAddress = ip
host.port = port
return host, nil
}
func hostInfoFromIter(iter *Iter, connectAddress net.IP, defaultPort int, translateAddressPort func(addr net.IP, port int) (net.IP, int)) (*HostInfo, error) {
rows, err := iter.SliceMap()
if err != nil {
// TODO(zariel): make typed error
return nil, err
}
if len(rows) == 0 {
return nil, errors.New("query returned 0 rows")
}
host, err := hostInfoFromMap(rows[0], &HostInfo{connectAddress: connectAddress, port: defaultPort}, translateAddressPort)
if err != nil {
return nil, err
}
return host, nil
}
// debounceRingRefresh submits a ring refresh request to the ring refresh debouncer.
func (s *Session) debounceRingRefresh() {
s.ringRefresher.Debounce()
}
// refreshRing executes a ring refresh immediately and cancels pending debounce ring refresh requests.
func (s *Session) refreshRingNow() error {
err, ok := <-s.ringRefresher.RefreshNow()
if !ok {
return errors.New("could not refresh ring because stop was requested")
}
return err
}
func (s *Session) refreshRing() error {
hosts, partitioner, err := s.hostSource.GetHostsFromSystem()
if err != nil {
return err
}
prevHosts := s.hostSource.getHostsMap()
for _, h := range hosts {
if s.cfg.filterHost(h) {
continue
}
if host, ok := s.hostSource.addHostIfMissing(h); !ok {
s.startPoolFill(h)
} else {
// host (by hostID) already exists; determine if IP has changed
newHostID := h.HostID()
existing, ok := prevHosts[newHostID]
if !ok {
return fmt.Errorf("get existing host=%s from prevHosts: %w", h, ErrCannotFindHost)
}
if h.connectAddress.Equal(existing.connectAddress) && h.nodeToNodeAddress().Equal(existing.nodeToNodeAddress()) {
// no host IP change
host.update(h)
} else {
// host IP has changed
// remove old HostInfo (w/old IP)
s.removeHost(existing)
if _, alreadyExists := s.hostSource.addHostIfMissing(h); alreadyExists {
return fmt.Errorf("add new host=%s after removal: %w", h, ErrHostAlreadyExists)
}
// add new HostInfo (same hostID, new IP)
s.startPoolFill(h)
}
}
delete(prevHosts, h.HostID())
}
for _, host := range prevHosts {
s.metadataDescriber.removeTabletsWithHost(host)
s.removeHost(host)
}
s.policy.SetPartitioner(partitioner)
return nil
}

46
vendor/github.com/gocql/gocql/host_source_gen.go generated vendored Normal file
View File

@@ -0,0 +1,46 @@
//go:build genhostinfo
// +build genhostinfo
package main
import (
"fmt"
"reflect"
"sync"
"github.com/gocql/gocql"
)
func gen(clause, field string) {
fmt.Printf("if h.%s == %s {\n", field, clause)
fmt.Printf("\th.%s = from.%s\n", field, field)
fmt.Println("}")
}
func main() {
t := reflect.ValueOf(&gocql.HostInfo{}).Elem().Type()
mu := reflect.TypeOf(sync.RWMutex{})
for i := 0; i < t.NumField(); i++ {
f := t.Field(i)
if f.Type == mu {
continue
}
switch f.Type.Kind() {
case reflect.Slice:
gen("nil", f.Name)
case reflect.String:
gen(`""`, f.Name)
case reflect.Int:
gen("0", f.Name)
case reflect.Struct:
gen("("+f.Type.Name()+"{})", f.Name)
case reflect.Bool, reflect.Int32:
continue
default:
panic(fmt.Sprintf("unknown field: %s", f))
}
}
}

7
vendor/github.com/gocql/gocql/host_source_scylla.go generated vendored Normal file
View File

@@ -0,0 +1,7 @@
package gocql
func (h *HostInfo) SetDatacenter(dc string) {
h.mu.Lock()
defer h.mu.Unlock()
h.dataCenter = dc
}

0
vendor/github.com/gocql/gocql/install_test_deps.sh generated vendored Normal file
View File

54
vendor/github.com/gocql/gocql/integration.sh generated vendored Normal file
View File

@@ -0,0 +1,54 @@
#!/bin/bash
#
# Copyright (C) 2017 ScyllaDB
#
readonly SCYLLA_IMAGE=${SCYLLA_IMAGE}
set -eu -o pipefail
function scylla_up() {
local -r exec="docker compose exec -T"
echo "==> Running Scylla ${SCYLLA_IMAGE}"
docker pull ${SCYLLA_IMAGE}
docker compose up -d --wait || ( docker compose ps --format json | jq -M 'select(.Health == "unhealthy") | .Service' | xargs docker compose logs; exit 1 )
}
function scylla_down() {
echo "==> Stopping Scylla"
docker compose down
}
function scylla_restart() {
scylla_down
scylla_up
}
scylla_restart
sudo chmod 0777 /tmp/scylla/cql.m
readonly clusterSize=1
readonly multiNodeClusterSize=3
readonly scylla_liveset="192.168.100.11"
readonly scylla_tablet_liveset="192.168.100.12"
readonly cversion="3.11.4"
readonly proto=4
readonly args="-gocql.timeout=60s -proto=${proto} -rf=${clusterSize} -clusterSize=${clusterSize} -autowait=2000ms -compressor=snappy -gocql.cversion=${cversion} -cluster=${scylla_liveset}"
readonly tabletArgs="-gocql.timeout=60s -proto=${proto} -rf=1 -clusterSize=${multiNodeClusterSize} -autowait=2000ms -compressor=snappy -gocql.cversion=${cversion} -multiCluster=${scylla_tablet_liveset}"
if [[ "$*" == *"tablet"* ]];
then
echo "==> Running tablet tests with args: ${tabletArgs}"
go test -timeout=5m -race -tags="tablet" ${tabletArgs} ./...
fi
TAGS=$*
TAGS=${TAGS//"tablet"/}
if [ ! -z "$TAGS" ];
then
echo "==> Running ${TAGS} tests with args: ${args}"
go test -timeout=5m -race -tags="$TAGS" ${args} ./...
fi

127
vendor/github.com/gocql/gocql/internal/lru/lru.go generated vendored Normal file
View File

@@ -0,0 +1,127 @@
/*
Copyright 2015 To gocql authors
Copyright 2013 Google Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// Package lru implements an LRU cache.
package lru
import "container/list"
// Cache is an LRU cache. It is not safe for concurrent access.
//
// This cache has been forked from github.com/golang/groupcache/lru, but
// specialized with string keys to avoid the allocations caused by wrapping them
// in interface{}.
type Cache struct {
// MaxEntries is the maximum number of cache entries before
// an item is evicted. Zero means no limit.
MaxEntries int
// OnEvicted optionally specifies a callback function to be
// executed when an entry is purged from the cache.
OnEvicted func(key string, value interface{})
ll *list.List
cache map[string]*list.Element
}
type entry struct {
key string
value interface{}
}
// New creates a new Cache.
// If maxEntries is zero, the cache has no limit and it's assumed
// that eviction is done by the caller.
func New(maxEntries int) *Cache {
return &Cache{
MaxEntries: maxEntries,
ll: list.New(),
cache: make(map[string]*list.Element),
}
}
// Add adds a value to the cache.
func (c *Cache) Add(key string, value interface{}) {
if c.cache == nil {
c.cache = make(map[string]*list.Element)
c.ll = list.New()
}
if ee, ok := c.cache[key]; ok {
c.ll.MoveToFront(ee)
ee.Value.(*entry).value = value
return
}
ele := c.ll.PushFront(&entry{key, value})
c.cache[key] = ele
if c.MaxEntries != 0 && c.ll.Len() > c.MaxEntries {
c.RemoveOldest()
}
}
// Get looks up a key's value from the cache.
func (c *Cache) Get(key string) (value interface{}, ok bool) {
if c.cache == nil {
return
}
if ele, hit := c.cache[key]; hit {
c.ll.MoveToFront(ele)
return ele.Value.(*entry).value, true
}
return
}
// Remove removes the provided key from the cache.
func (c *Cache) Remove(key string) bool {
if c.cache == nil {
return false
}
if ele, hit := c.cache[key]; hit {
c.removeElement(ele)
return true
}
return false
}
// RemoveOldest removes the oldest item from the cache.
func (c *Cache) RemoveOldest() {
if c.cache == nil {
return
}
ele := c.ll.Back()
if ele != nil {
c.removeElement(ele)
}
}
func (c *Cache) removeElement(e *list.Element) {
c.ll.Remove(e)
kv := e.Value.(*entry)
delete(c.cache, kv.key)
if c.OnEvicted != nil {
c.OnEvicted(kv.key, kv.value)
}
}
// Len returns the number of items in the cache.
func (c *Cache) Len() int {
if c.cache == nil {
return 0
}
return c.ll.Len()
}

135
vendor/github.com/gocql/gocql/internal/murmur/murmur.go generated vendored Normal file
View File

@@ -0,0 +1,135 @@
package murmur
const (
c1 int64 = -8663945395140668459 // 0x87c37b91114253d5
c2 int64 = 5545529020109919103 // 0x4cf5ad432745937f
fmix1 int64 = -49064778989728563 // 0xff51afd7ed558ccd
fmix2 int64 = -4265267296055464877 // 0xc4ceb9fe1a85ec53
)
func fmix(n int64) int64 {
// cast to unsigned for logical right bitshift (to match C* MM3 implementation)
n ^= int64(uint64(n) >> 33)
n *= fmix1
n ^= int64(uint64(n) >> 33)
n *= fmix2
n ^= int64(uint64(n) >> 33)
return n
}
func block(p byte) int64 {
return int64(int8(p))
}
func rotl(x int64, r uint8) int64 {
// cast to unsigned for logical right bitshift (to match C* MM3 implementation)
return (x << r) | (int64)((uint64(x) >> (64 - r)))
}
func Murmur3H1(data []byte) int64 {
length := len(data)
var h1, h2, k1, k2 int64
// body
nBlocks := length / 16
for i := 0; i < nBlocks; i++ {
k1, k2 = getBlock(data, i)
k1 *= c1
k1 = rotl(k1, 31)
k1 *= c2
h1 ^= k1
h1 = rotl(h1, 27)
h1 += h2
h1 = h1*5 + 0x52dce729
k2 *= c2
k2 = rotl(k2, 33)
k2 *= c1
h2 ^= k2
h2 = rotl(h2, 31)
h2 += h1
h2 = h2*5 + 0x38495ab5
}
// tail
tail := data[nBlocks*16:]
k1 = 0
k2 = 0
switch length & 15 {
case 15:
k2 ^= block(tail[14]) << 48
fallthrough
case 14:
k2 ^= block(tail[13]) << 40
fallthrough
case 13:
k2 ^= block(tail[12]) << 32
fallthrough
case 12:
k2 ^= block(tail[11]) << 24
fallthrough
case 11:
k2 ^= block(tail[10]) << 16
fallthrough
case 10:
k2 ^= block(tail[9]) << 8
fallthrough
case 9:
k2 ^= block(tail[8])
k2 *= c2
k2 = rotl(k2, 33)
k2 *= c1
h2 ^= k2
fallthrough
case 8:
k1 ^= block(tail[7]) << 56
fallthrough
case 7:
k1 ^= block(tail[6]) << 48
fallthrough
case 6:
k1 ^= block(tail[5]) << 40
fallthrough
case 5:
k1 ^= block(tail[4]) << 32
fallthrough
case 4:
k1 ^= block(tail[3]) << 24
fallthrough
case 3:
k1 ^= block(tail[2]) << 16
fallthrough
case 2:
k1 ^= block(tail[1]) << 8
fallthrough
case 1:
k1 ^= block(tail[0])
k1 *= c1
k1 = rotl(k1, 31)
k1 *= c2
h1 ^= k1
}
h1 ^= int64(length)
h2 ^= int64(length)
h1 += h2
h2 += h1
h1 = fmix(h1)
h2 = fmix(h2)
h1 += h2
// the following is extraneous since h2 is discarded
// h2 += h1
return h1
}

View File

@@ -0,0 +1,12 @@
//go:build appengine || s390x
// +build appengine s390x
package murmur
import "encoding/binary"
func getBlock(data []byte, n int) (int64, int64) {
k1 := int64(binary.LittleEndian.Uint64(data[n*16:]))
k2 := int64(binary.LittleEndian.Uint64(data[(n*16)+8:]))
return k1, k2
}

View File

@@ -0,0 +1,16 @@
//go:build !appengine && !s390x
// +build !appengine,!s390x
package murmur
import (
"unsafe"
)
func getBlock(data []byte, n int) (int64, int64) {
block := (*[2]int64)(unsafe.Pointer(&data[n*16]))
k1 := block[0]
k2 := block[1]
return k1, k2
}

View File

@@ -0,0 +1,146 @@
package streams
import (
"math"
"strconv"
"sync/atomic"
)
const bucketBits = 64
// IDGenerator tracks and allocates streams which are in use.
type IDGenerator struct {
NumStreams int
inuseStreams int32
numBuckets uint32
// streams is a bitset where each bit represents a stream, a 1 implies in use
streams []uint64
offset uint32
}
func New(protocol int) *IDGenerator {
maxStreams := 128
if protocol > 2 {
maxStreams = 32768
}
return NewLimited(maxStreams)
}
func NewLimited(maxStreams int) *IDGenerator {
// Round up maxStreams to a nearest
// multiple of 64
maxStreams = ((maxStreams + 63) / 64) * 64
buckets := maxStreams / 64
// reserve stream 0
streams := make([]uint64, buckets)
streams[0] = 1 << 63
return &IDGenerator{
NumStreams: maxStreams,
streams: streams,
numBuckets: uint32(buckets),
offset: uint32(buckets) - 1,
}
}
func streamFromBucket(bucket, streamInBucket int) int {
return (bucket * bucketBits) + streamInBucket
}
func (s *IDGenerator) GetStream() (int, bool) {
// Reduce collisions by offsetting the starting point
offset := atomic.AddUint32(&s.offset, 1)
for i := uint32(0); i < s.numBuckets; i++ {
pos := int((i + offset) % s.numBuckets)
bucket := atomic.LoadUint64(&s.streams[pos])
if bucket == math.MaxUint64 {
// all streams in use
continue
}
for j := 0; j < bucketBits; j++ {
mask := uint64(1 << streamOffset(j))
for bucket&mask == 0 {
if atomic.CompareAndSwapUint64(&s.streams[pos], bucket, bucket|mask) {
atomic.AddInt32(&s.inuseStreams, 1)
return streamFromBucket(int(pos), j), true
}
bucket = atomic.LoadUint64(&s.streams[pos])
}
}
}
return 0, false
}
func bitfmt(b uint64) string {
return strconv.FormatUint(b, 16)
}
// returns the bucket offset of a given stream
func bucketOffset(i int) int {
return i / bucketBits
}
func streamOffset(stream int) uint64 {
return bucketBits - uint64(stream%bucketBits) - 1
}
func isSet(bits uint64, stream int) bool {
return bits>>streamOffset(stream)&1 == 1
}
func (s *IDGenerator) isSet(stream int) bool {
bits := atomic.LoadUint64(&s.streams[bucketOffset(stream)])
return isSet(bits, stream)
}
func (s *IDGenerator) String() string {
size := s.numBuckets * (bucketBits + 1)
buf := make([]byte, 0, size)
for i := 0; i < int(s.numBuckets); i++ {
bits := atomic.LoadUint64(&s.streams[i])
buf = append(buf, bitfmt(bits)...)
buf = append(buf, ' ')
}
return string(buf[: size-1 : size-1])
}
func (s *IDGenerator) Clear(stream int) (inuse bool) {
offset := bucketOffset(stream)
bucket := atomic.LoadUint64(&s.streams[offset])
mask := uint64(1) << streamOffset(stream)
if bucket&mask != mask {
// already cleared
return false
}
for !atomic.CompareAndSwapUint64(&s.streams[offset], bucket, bucket & ^mask) {
bucket = atomic.LoadUint64(&s.streams[offset])
if bucket&mask != mask {
// already cleared
return false
}
}
// TODO: make this account for 0 stream being reserved
if atomic.AddInt32(&s.inuseStreams, -1) < 0 {
// TODO(zariel): remove this
panic("negative streams inuse")
}
return true
}
func (s *IDGenerator) Available() int {
return s.NumStreams - int(atomic.LoadInt32(&s.inuseStreams)) - 1
}
func (s *IDGenerator) InUse() int {
return int(atomic.LoadInt32(&s.inuseStreams))
}

40
vendor/github.com/gocql/gocql/logger.go generated vendored Normal file
View File

@@ -0,0 +1,40 @@
package gocql
import (
"bytes"
"fmt"
"log"
)
type StdLogger interface {
Print(v ...interface{})
Printf(format string, v ...interface{})
Println(v ...interface{})
}
type nopLogger struct{}
func (n nopLogger) Print(_ ...interface{}) {}
func (n nopLogger) Printf(_ string, _ ...interface{}) {}
func (n nopLogger) Println(_ ...interface{}) {}
type testLogger struct {
capture bytes.Buffer
}
func (l *testLogger) Print(v ...interface{}) { fmt.Fprint(&l.capture, v...) }
func (l *testLogger) Printf(format string, v ...interface{}) { fmt.Fprintf(&l.capture, format, v...) }
func (l *testLogger) Println(v ...interface{}) { fmt.Fprintln(&l.capture, v...) }
func (l *testLogger) String() string { return l.capture.String() }
type defaultLogger struct{}
func (l *defaultLogger) Print(v ...interface{}) { log.Print(v...) }
func (l *defaultLogger) Printf(format string, v ...interface{}) { log.Printf(format, v...) }
func (l *defaultLogger) Println(v ...interface{}) { log.Println(v...) }
// Logger for logging messages.
// Deprecated: Use ClusterConfig.Logger instead.
var Logger StdLogger = &defaultLogger{}

1846
vendor/github.com/gocql/gocql/marshal.go generated vendored Normal file

File diff suppressed because it is too large Load Diff

1466
vendor/github.com/gocql/gocql/metadata_cassandra.go generated vendored Normal file

File diff suppressed because it is too large Load Diff

1102
vendor/github.com/gocql/gocql/metadata_scylla.go generated vendored Normal file

File diff suppressed because it is too large Load Diff

1326
vendor/github.com/gocql/gocql/policies.go generated vendored Normal file

File diff suppressed because it is too large Load Diff

78
vendor/github.com/gocql/gocql/prepared_cache.go generated vendored Normal file
View File

@@ -0,0 +1,78 @@
package gocql
import (
"bytes"
"sync"
"github.com/gocql/gocql/internal/lru"
)
const defaultMaxPreparedStmts = 1000
// preparedLRU is the prepared statement cache
type preparedLRU struct {
mu sync.Mutex
lru *lru.Cache
}
func (p *preparedLRU) clear() {
p.mu.Lock()
defer p.mu.Unlock()
for p.lru.Len() > 0 {
p.lru.RemoveOldest()
}
}
func (p *preparedLRU) add(key string, val *inflightPrepare) {
p.mu.Lock()
defer p.mu.Unlock()
p.lru.Add(key, val)
}
func (p *preparedLRU) remove(key string) bool {
p.mu.Lock()
defer p.mu.Unlock()
return p.lru.Remove(key)
}
func (p *preparedLRU) execIfMissing(key string, fn func(lru *lru.Cache) *inflightPrepare) (*inflightPrepare, bool) {
p.mu.Lock()
defer p.mu.Unlock()
val, ok := p.lru.Get(key)
if ok {
return val.(*inflightPrepare), true
}
return fn(p.lru), false
}
func (p *preparedLRU) keyFor(hostID, keyspace, statement string) string {
// TODO: we should just use a struct for the key in the map
return hostID + keyspace + statement
}
func (p *preparedLRU) evictPreparedID(key string, id []byte) {
p.mu.Lock()
defer p.mu.Unlock()
val, ok := p.lru.Get(key)
if !ok {
return
}
ifp, ok := val.(*inflightPrepare)
if !ok {
return
}
select {
case <-ifp.done:
if bytes.Equal(id, ifp.preparedStatment.id) {
p.lru.Remove(key)
}
default:
}
}

238
vendor/github.com/gocql/gocql/query_executor.go generated vendored Normal file
View File

@@ -0,0 +1,238 @@
package gocql
import (
"context"
"errors"
"sync"
"time"
)
type ExecutableQuery interface {
borrowForExecution() // Used to ensure that the query stays alive for lifetime of a particular execution goroutine.
releaseAfterExecution() // Used when a goroutine finishes its execution attempts, either with ok result or an error.
execute(ctx context.Context, conn *Conn) *Iter
attempt(keyspace string, end, start time.Time, iter *Iter, host *HostInfo)
retryPolicy() RetryPolicy
speculativeExecutionPolicy() SpeculativeExecutionPolicy
GetRoutingKey() ([]byte, error)
Keyspace() string
Table() string
IsIdempotent() bool
IsLWT() bool
GetCustomPartitioner() Partitioner
withContext(context.Context) ExecutableQuery
RetryableQuery
GetSession() *Session
}
type queryExecutor struct {
pool *policyConnPool
policy HostSelectionPolicy
}
func (q *queryExecutor) attemptQuery(ctx context.Context, qry ExecutableQuery, conn *Conn) *Iter {
start := time.Now()
iter := qry.execute(ctx, conn)
end := time.Now()
qry.attempt(q.pool.keyspace, end, start, iter, conn.host)
return iter
}
func (q *queryExecutor) speculate(ctx context.Context, qry ExecutableQuery, sp SpeculativeExecutionPolicy,
hostIter NextHost, results chan *Iter) *Iter {
ticker := time.NewTicker(sp.Delay())
defer ticker.Stop()
for i := 0; i < sp.Attempts(); i++ {
select {
case <-ticker.C:
qry.borrowForExecution() // ensure liveness in case of executing Query to prevent races with Query.Release().
go q.run(ctx, qry, hostIter, results)
case <-ctx.Done():
return &Iter{err: ctx.Err()}
case iter := <-results:
return iter
}
}
return nil
}
func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) {
hostIter := q.policy.Pick(qry)
// check if the query is not marked as idempotent, if
// it is, we force the policy to NonSpeculative
sp := qry.speculativeExecutionPolicy()
if !qry.IsIdempotent() || sp.Attempts() == 0 {
return q.do(qry.Context(), qry, hostIter), nil
}
// When speculative execution is enabled, we could be accessing the host iterator from multiple goroutines below.
// To ensure we don't call it concurrently, we wrap the returned NextHost function here to synchronize access to it.
var mu sync.Mutex
origHostIter := hostIter
hostIter = func() SelectedHost {
mu.Lock()
defer mu.Unlock()
return origHostIter()
}
ctx, cancel := context.WithCancel(qry.Context())
defer cancel()
results := make(chan *Iter, 1)
// Launch the main execution
qry.borrowForExecution() // ensure liveness in case of executing Query to prevent races with Query.Release().
go q.run(ctx, qry, hostIter, results)
// The speculative executions are launched _in addition_ to the main
// execution, on a timer. So Speculation{2} would make 3 executions running
// in total.
if iter := q.speculate(ctx, qry, sp, hostIter, results); iter != nil {
return iter, nil
}
select {
case iter := <-results:
return iter, nil
case <-ctx.Done():
return &Iter{err: ctx.Err()}, nil
}
}
func (q *queryExecutor) do(ctx context.Context, qry ExecutableQuery, hostIter NextHost) *Iter {
rt := qry.retryPolicy()
if rt == nil {
rt = &SimpleRetryPolicy{3}
}
lwtRT, isRTSupportsLWT := rt.(LWTRetryPolicy)
var getShouldRetry func(qry RetryableQuery) bool
var getRetryType func(error) RetryType
if isRTSupportsLWT && qry.IsLWT() {
getShouldRetry = lwtRT.AttemptLWT
getRetryType = lwtRT.GetRetryTypeLWT
} else {
getShouldRetry = rt.Attempt
getRetryType = rt.GetRetryType
}
var potentiallyExecuted bool
execute := func(qry ExecutableQuery, selectedHost SelectedHost) (iter *Iter, retry RetryType) {
host := selectedHost.Info()
if host == nil || !host.IsUp() {
return &Iter{
err: &QueryError{
err: ErrHostDown,
potentiallyExecuted: potentiallyExecuted,
},
}, RetryNextHost
}
pool, ok := q.pool.getPool(host)
if !ok {
return &Iter{
err: &QueryError{
err: ErrNoPool,
potentiallyExecuted: potentiallyExecuted,
},
}, RetryNextHost
}
conn := pool.Pick(selectedHost.Token(), qry)
if conn == nil {
return &Iter{
err: &QueryError{
err: ErrNoConnectionsInPool,
potentiallyExecuted: potentiallyExecuted,
},
}, RetryNextHost
}
iter = q.attemptQuery(ctx, qry, conn)
iter.host = selectedHost.Info()
// Update host
if iter.err == nil {
return iter, RetryType(255)
}
switch {
case errors.Is(iter.err, context.Canceled),
errors.Is(iter.err, context.DeadlineExceeded):
selectedHost.Mark(nil)
potentiallyExecuted = true
retry = Rethrow
default:
selectedHost.Mark(iter.err)
retry = RetryType(255) // Don't enforce retry and get it from retry policy
}
var qErr *QueryError
if errors.As(iter.err, &qErr) {
potentiallyExecuted = potentiallyExecuted && qErr.PotentiallyExecuted()
qErr.potentiallyExecuted = potentiallyExecuted
qErr.isIdempotent = qry.IsIdempotent()
iter.err = qErr
} else {
iter.err = &QueryError{
err: iter.err,
potentiallyExecuted: potentiallyExecuted,
isIdempotent: qry.IsIdempotent(),
}
}
return iter, retry
}
var lastErr error
selectedHost := hostIter()
for selectedHost != nil {
iter, retryType := execute(qry, selectedHost)
if iter.err == nil {
return iter
}
lastErr = iter.err
// Exit if retry policy decides to not retry anymore
if retryType == RetryType(255) {
if !getShouldRetry(qry) {
return iter
}
retryType = getRetryType(iter.err)
}
// If query is unsuccessful, check the error with RetryPolicy to retry
switch retryType {
case Retry:
// retry on the same host
continue
case Rethrow, Ignore:
return iter
case RetryNextHost:
// retry on the next host
selectedHost = hostIter()
continue
default:
// Undefined? Return nil and error, this will panic in the requester
return &Iter{err: ErrUnknownRetryType}
}
}
if lastErr != nil {
return &Iter{err: lastErr}
}
return &Iter{err: ErrNoConnections}
}
func (q *queryExecutor) run(ctx context.Context, qry ExecutableQuery, hostIter NextHost, results chan<- *Iter) {
select {
case results <- q.do(ctx, qry, hostIter):
case <-ctx.Done():
}
qry.releaseAfterExecution()
}

544
vendor/github.com/gocql/gocql/recreate.go generated vendored Normal file
View File

@@ -0,0 +1,544 @@
//go:build !cassandra
// +build !cassandra
// Copyright (C) 2017 ScyllaDB
package gocql
import (
"encoding/binary"
"encoding/json"
"fmt"
"io"
"sort"
"strconv"
"strings"
"text/template"
)
// ToCQL returns a CQL query that ca be used to recreate keyspace with all
// user defined types, tables, indexes, functions, aggregates and views associated
// with this keyspace.
func (km *KeyspaceMetadata) ToCQL() (string, error) {
// Be aware that `CreateStmts` is not only a cache for ToCQL,
// but it also can be populated from response to `DESCRIBE KEYSPACE %s WITH INTERNALS`
if len(km.CreateStmts) != 0 {
return km.CreateStmts, nil
}
var sb strings.Builder
if err := km.keyspaceToCQL(&sb); err != nil {
return "", err
}
sortedTypes := km.typesSortedTopologically()
for _, tm := range sortedTypes {
if err := km.userTypeToCQL(&sb, tm); err != nil {
return "", err
}
}
for _, tm := range km.Tables {
if err := km.tableToCQL(&sb, km.Name, tm); err != nil {
return "", err
}
}
for _, im := range km.Indexes {
if err := km.indexToCQL(&sb, im); err != nil {
return "", err
}
}
for _, fm := range km.Functions {
if err := km.functionToCQL(&sb, km.Name, fm); err != nil {
return "", err
}
}
for _, am := range km.Aggregates {
if err := km.aggregateToCQL(&sb, am); err != nil {
return "", err
}
}
for _, vm := range km.Views {
if err := km.viewToCQL(&sb, vm); err != nil {
return "", err
}
}
km.CreateStmts = sb.String()
return km.CreateStmts, nil
}
func (km *KeyspaceMetadata) typesSortedTopologically() []*TypeMetadata {
sortedTypes := make([]*TypeMetadata, 0, len(km.Types))
for _, tm := range km.Types {
sortedTypes = append(sortedTypes, tm)
}
sort.Slice(sortedTypes, func(i, j int) bool {
for _, ft := range sortedTypes[j].FieldTypes {
if strings.Contains(ft, sortedTypes[i].Name) {
return true
}
}
return false
})
return sortedTypes
}
var tableCQLTemplate = template.Must(template.New("table").
Funcs(map[string]interface{}{
"escape": cqlHelpers.escape,
"tableColumnToCQL": cqlHelpers.tableColumnToCQL,
"tablePropertiesToCQL": cqlHelpers.tablePropertiesToCQL,
}).
Parse(`
CREATE TABLE {{ .KeyspaceName }}.{{ .Tm.Name }} (
{{ tableColumnToCQL .Tm }}
) WITH {{ tablePropertiesToCQL .Tm.ClusteringColumns .Tm.Options .Tm.Flags .Tm.Extensions }};
`))
func (km *KeyspaceMetadata) tableToCQL(w io.Writer, kn string, tm *TableMetadata) error {
if err := tableCQLTemplate.Execute(w, map[string]interface{}{
"Tm": tm,
"KeyspaceName": kn,
}); err != nil {
return err
}
return nil
}
var functionTemplate = template.Must(template.New("functions").
Funcs(map[string]interface{}{
"escape": cqlHelpers.escape,
"zip": cqlHelpers.zip,
"stripFrozen": cqlHelpers.stripFrozen,
}).
Parse(`
CREATE FUNCTION {{ .keyspaceName }}.{{ .fm.Name }} (
{{- range $i, $args := zip .fm.ArgumentNames .fm.ArgumentTypes }}
{{- if ne $i 0 }}, {{ end }}
{{- (index $args 0) }}
{{ stripFrozen (index $args 1) }}
{{- end -}})
{{ if .fm.CalledOnNullInput }}CALLED{{ else }}RETURNS NULL{{ end }} ON NULL INPUT
RETURNS {{ .fm.ReturnType }}
LANGUAGE {{ .fm.Language }}
AS $${{ .fm.Body }}$$;
`))
func (km *KeyspaceMetadata) functionToCQL(w io.Writer, keyspaceName string, fm *FunctionMetadata) error {
if err := functionTemplate.Execute(w, map[string]interface{}{
"fm": fm,
"keyspaceName": keyspaceName,
}); err != nil {
return err
}
return nil
}
var viewTemplate = template.Must(template.New("views").
Funcs(map[string]interface{}{
"zip": cqlHelpers.zip,
"partitionKeyString": cqlHelpers.partitionKeyString,
"tablePropertiesToCQL": cqlHelpers.tablePropertiesToCQL,
}).
Parse(`
CREATE MATERIALIZED VIEW {{ .vm.KeyspaceName }}.{{ .vm.ViewName }} AS
SELECT {{ if .vm.IncludeAllColumns }}*{{ else }}
{{- range $i, $col := .vm.OrderedColumns }}
{{- if ne $i 0 }}, {{ end }}
{{ $col }}
{{- end }}
{{- end }}
FROM {{ .vm.KeyspaceName }}.{{ .vm.BaseTableName }}
WHERE {{ .vm.WhereClause }}
PRIMARY KEY ({{ partitionKeyString .vm.PartitionKey .vm.ClusteringColumns }})
WITH {{ tablePropertiesToCQL .vm.ClusteringColumns .vm.Options .flags .vm.Extensions }};
`))
func (km *KeyspaceMetadata) viewToCQL(w io.Writer, vm *ViewMetadata) error {
if err := viewTemplate.Execute(w, map[string]interface{}{
"vm": vm,
"flags": []string{},
}); err != nil {
return err
}
return nil
}
var aggregatesTemplate = template.Must(template.New("aggregate").
Funcs(map[string]interface{}{
"stripFrozen": cqlHelpers.stripFrozen,
}).
Parse(`
CREATE AGGREGATE {{ .Keyspace }}.{{ .Name }}(
{{- range $i, $arg := .ArgumentTypes }}
{{- if ne $i 0 }}, {{ end }}
{{ stripFrozen $arg }}
{{- end -}})
SFUNC {{ .StateFunc.Name }}
STYPE {{ stripFrozen .StateType }}
{{- if ne .FinalFunc.Name "" }}
FINALFUNC {{ .FinalFunc.Name }}
{{- end -}}
{{- if ne .InitCond "" }}
INITCOND {{ .InitCond }}
{{- end -}}
;
`))
func (km *KeyspaceMetadata) aggregateToCQL(w io.Writer, am *AggregateMetadata) error {
if err := aggregatesTemplate.Execute(w, am); err != nil {
return err
}
return nil
}
var typeCQLTemplate = template.Must(template.New("types").
Funcs(map[string]interface{}{
"zip": cqlHelpers.zip,
}).
Parse(`
CREATE TYPE {{ .Keyspace }}.{{ .Name }} (
{{- range $i, $fields := zip .FieldNames .FieldTypes }} {{- if ne $i 0 }},{{ end }}
{{ index $fields 0 }} {{ index $fields 1 }}
{{- end }}
);
`))
func (km *KeyspaceMetadata) userTypeToCQL(w io.Writer, tm *TypeMetadata) error {
if err := typeCQLTemplate.Execute(w, tm); err != nil {
return err
}
return nil
}
func (km *KeyspaceMetadata) indexToCQL(w io.Writer, im *IndexMetadata) error {
// Scylla doesn't support any custom indexes
if im.Kind == IndexKindCustom {
return nil
}
options := im.Options
indexTarget := options["target"]
// secondary index
si := struct {
ClusteringKeys []string `json:"ck"`
PartitionKeys []string `json:"pk"`
}{}
if err := json.Unmarshal([]byte(indexTarget), &si); err == nil {
indexTarget = fmt.Sprintf("(%s), %s",
strings.Join(si.PartitionKeys, ","),
strings.Join(si.ClusteringKeys, ","),
)
}
_, err := fmt.Fprintf(w, "\nCREATE INDEX %s ON %s.%s (%s);\n",
im.Name,
im.KeyspaceName,
im.TableName,
indexTarget,
)
if err != nil {
return err
}
return nil
}
var keyspaceCQLTemplate = template.Must(template.New("keyspace").
Funcs(map[string]interface{}{
"escape": cqlHelpers.escape,
"fixStrategy": cqlHelpers.fixStrategy,
}).
Parse(`CREATE KEYSPACE {{ .Name }} WITH replication = {
'class': {{ escape ( fixStrategy .StrategyClass) }}
{{- range $key, $value := .StrategyOptions }},
{{ escape $key }}: {{ escape $value }}
{{- end }}
}{{ if not .DurableWrites }} AND durable_writes = 'false'{{ end }};
`))
func (km *KeyspaceMetadata) keyspaceToCQL(w io.Writer) error {
if err := keyspaceCQLTemplate.Execute(w, km); err != nil {
return err
}
return nil
}
func contains(in []string, v string) bool {
for _, e := range in {
if e == v {
return true
}
}
return false
}
type toCQLHelpers struct{}
var cqlHelpers = toCQLHelpers{}
func (h toCQLHelpers) zip(a []string, b []string) [][]string {
m := make([][]string, len(a))
for i := range a {
m[i] = []string{a[i], b[i]}
}
return m
}
func (h toCQLHelpers) escape(e interface{}) string {
switch v := e.(type) {
case int, float64:
return fmt.Sprint(v)
case bool:
if v {
return "true"
}
return "false"
case string:
return "'" + strings.ReplaceAll(v, "'", "''") + "'"
case []byte:
return string(v)
}
return ""
}
func (h toCQLHelpers) stripFrozen(v string) string {
return strings.TrimSuffix(strings.TrimPrefix(v, "frozen<"), ">")
}
func (h toCQLHelpers) fixStrategy(v string) string {
return strings.TrimPrefix(v, "org.apache.cassandra.locator.")
}
func (h toCQLHelpers) fixQuote(v string) string {
return strings.ReplaceAll(v, `"`, `'`)
}
func (h toCQLHelpers) tableOptionsToCQL(ops TableMetadataOptions) ([]string, error) {
opts := map[string]interface{}{
"bloom_filter_fp_chance": ops.BloomFilterFpChance,
"comment": ops.Comment,
"crc_check_chance": ops.CrcCheckChance,
"default_time_to_live": ops.DefaultTimeToLive,
"gc_grace_seconds": ops.GcGraceSeconds,
"max_index_interval": ops.MaxIndexInterval,
"memtable_flush_period_in_ms": ops.MemtableFlushPeriodInMs,
"min_index_interval": ops.MinIndexInterval,
"speculative_retry": ops.SpeculativeRetry,
}
var err error
opts["caching"], err = json.Marshal(ops.Caching)
if err != nil {
return nil, err
}
opts["compaction"], err = json.Marshal(ops.Compaction)
if err != nil {
return nil, err
}
opts["compression"], err = json.Marshal(ops.Compression)
if err != nil {
return nil, err
}
cdc, err := json.Marshal(ops.CDC)
if err != nil {
return nil, err
}
if string(cdc) != "null" {
opts["cdc"] = cdc
}
if ops.InMemory {
opts["in_memory"] = ops.InMemory
}
out := make([]string, 0, len(opts))
for key, opt := range opts {
out = append(out, fmt.Sprintf("%s = %s", key, h.fixQuote(h.escape(opt))))
}
sort.Strings(out)
return out, nil
}
func (h toCQLHelpers) tableExtensionsToCQL(extensions map[string]interface{}) ([]string, error) {
exts := map[string]interface{}{}
if blob, ok := extensions["scylla_encryption_options"]; ok {
encOpts := &scyllaEncryptionOptions{}
if err := encOpts.UnmarshalBinary(blob.([]byte)); err != nil {
return nil, err
}
var err error
exts["scylla_encryption_options"], err = json.Marshal(encOpts)
if err != nil {
return nil, err
}
}
out := make([]string, 0, len(exts))
for key, ext := range exts {
out = append(out, fmt.Sprintf("%s = %s", key, h.fixQuote(h.escape(ext))))
}
sort.Strings(out)
return out, nil
}
func (h toCQLHelpers) tablePropertiesToCQL(cks []*ColumnMetadata, opts TableMetadataOptions,
flags []string, extensions map[string]interface{}) (string, error) {
var sb strings.Builder
var properties []string
compactStorage := len(flags) > 0 && (contains(flags, TableFlagDense) ||
contains(flags, TableFlagSuper) ||
!contains(flags, TableFlagCompound))
if compactStorage {
properties = append(properties, "COMPACT STORAGE")
}
if len(cks) > 0 {
var inner []string
for _, col := range cks {
inner = append(inner, fmt.Sprintf("%s %s", col.Name, col.ClusteringOrder))
}
properties = append(properties, fmt.Sprintf("CLUSTERING ORDER BY (%s)", strings.Join(inner, ", ")))
}
options, err := h.tableOptionsToCQL(opts)
if err != nil {
return "", err
}
properties = append(properties, options...)
exts, err := h.tableExtensionsToCQL(extensions)
if err != nil {
return "", err
}
properties = append(properties, exts...)
sb.WriteString(strings.Join(properties, "\n AND "))
return sb.String(), nil
}
func (h toCQLHelpers) tableColumnToCQL(tm *TableMetadata) string {
var sb strings.Builder
var columns []string
for _, cn := range tm.OrderedColumns {
cm := tm.Columns[cn]
column := fmt.Sprintf("%s %s", cn, cm.Type)
if cm.Kind == ColumnStatic {
column += " static"
}
columns = append(columns, column)
}
if len(tm.PartitionKey) == 1 && len(tm.ClusteringColumns) == 0 && len(columns) > 0 {
columns[0] += " PRIMARY KEY"
}
sb.WriteString(strings.Join(columns, ",\n "))
if len(tm.PartitionKey) > 1 || len(tm.ClusteringColumns) > 0 {
sb.WriteString(",\n PRIMARY KEY (")
sb.WriteString(h.partitionKeyString(tm.PartitionKey, tm.ClusteringColumns))
sb.WriteRune(')')
}
return sb.String()
}
func (h toCQLHelpers) partitionKeyString(pks, cks []*ColumnMetadata) string {
var sb strings.Builder
if len(pks) > 1 {
sb.WriteRune('(')
for i, pk := range pks {
if i != 0 {
sb.WriteString(", ")
}
sb.WriteString(pk.Name)
}
sb.WriteRune(')')
} else {
sb.WriteString(pks[0].Name)
}
if len(cks) > 0 {
sb.WriteString(", ")
for i, ck := range cks {
if i != 0 {
sb.WriteString(", ")
}
sb.WriteString(ck.Name)
}
}
return sb.String()
}
type scyllaEncryptionOptions struct {
CipherAlgorithm string `json:"cipher_algorithm"`
SecretKeyStrength int `json:"secret_key_strength"`
KeyProvider string `json:"key_provider"`
SecretKeyFile string `json:"secret_key_file"`
}
// UnmarshalBinary deserializes blob into scyllaEncryptionOptions.
// Format:
// - 4 bytes - size of KV map
// Size times:
// - 4 bytes - length of key
// - len_of_key bytes - key
// - 4 bytes - length of value
// - len_of_value bytes - value
func (enc *scyllaEncryptionOptions) UnmarshalBinary(data []byte) error {
size := binary.LittleEndian.Uint32(data[0:4])
m := make(map[string]string, size)
off := uint32(4)
for i := uint32(0); i < size; i++ {
keyLen := binary.LittleEndian.Uint32(data[off : off+4])
off += 4
key := string(data[off : off+keyLen])
off += keyLen
valueLen := binary.LittleEndian.Uint32(data[off : off+4])
off += 4
value := string(data[off : off+valueLen])
off += valueLen
m[key] = value
}
enc.CipherAlgorithm = m["cipher_algorithm"]
enc.KeyProvider = m["key_provider"]
enc.SecretKeyFile = m["secret_key_file"]
if secretKeyStrength, ok := m["secret_key_strength"]; ok {
sks, err := strconv.Atoi(secretKeyStrength)
if err != nil {
return err
}
enc.SecretKeyStrength = sks
}
return nil
}

275
vendor/github.com/gocql/gocql/ring_describer.go generated vendored Normal file
View File

@@ -0,0 +1,275 @@
package gocql
import (
"context"
"errors"
"fmt"
"net"
"sync"
)
// Polls system.peers at a specific interval to find new hosts
type ringDescriber struct {
control controlConnection
cfg *ClusterConfig
logger StdLogger
mu sync.RWMutex
prevHosts []*HostInfo
prevPartitioner string
// hosts are the set of all hosts in the cassandra ring that we know of.
// key of map is host_id.
hosts map[string]*HostInfo
// hostIPToUUID maps host native address to host_id.
hostIPToUUID map[string]string
}
func (r *ringDescriber) setControlConn(c controlConnection) {
r.mu.Lock()
defer r.mu.Unlock()
r.control = c
}
// Ask the control node for the local host info
func (r *ringDescriber) getLocalHostInfo(conn ConnInterface) (*HostInfo, error) {
iter := conn.querySystem(context.TODO(), qrySystemLocal)
if iter == nil {
return nil, errNoControl
}
host, err := hostInfoFromIter(iter, nil, r.cfg.Port, r.cfg.translateAddressPort)
if err != nil {
return nil, fmt.Errorf("could not retrieve local host info: %w", err)
}
return host, nil
}
// Ask the control node for host info on all it's known peers
func (r *ringDescriber) getClusterPeerInfo(localHost *HostInfo, c ConnInterface) ([]*HostInfo, error) {
var iter *Iter
if c.getIsSchemaV2() {
iter = c.querySystem(context.TODO(), qrySystemPeersV2)
} else {
iter = c.querySystem(context.TODO(), qrySystemPeers)
}
if iter == nil {
return nil, errNoControl
}
rows, err := iter.SliceMap()
if err != nil {
// TODO(zariel): make typed error
return nil, fmt.Errorf("unable to fetch peer host info: %s", err)
}
return getPeersFromQuerySystemPeers(rows, r.cfg.Port, r.cfg.translateAddressPort, r.logger)
}
func getPeersFromQuerySystemPeers(querySystemPeerRows []map[string]interface{}, port int, translateAddressPort func(addr net.IP, port int) (net.IP, int), logger StdLogger) ([]*HostInfo, error) {
var peers []*HostInfo
for _, row := range querySystemPeerRows {
// extract all available info about the peer
host, err := hostInfoFromMap(row, &HostInfo{port: port}, translateAddressPort)
if err != nil {
return nil, err
} else if !isValidPeer(host) {
// If it's not a valid peer
logger.Printf("Found invalid peer '%s' "+
"Likely due to a gossip or snitch issue, this host will be ignored", host)
continue
} else if isZeroToken(host) {
continue
}
peers = append(peers, host)
}
return peers, nil
}
// Return true if the host is a valid peer
func isValidPeer(host *HostInfo) bool {
return !(len(host.RPCAddress()) == 0 ||
host.hostId == "" ||
host.dataCenter == "" ||
host.rack == "")
}
func isZeroToken(host *HostInfo) bool {
return len(host.tokens) == 0
}
// GetHostsFromSystem returns a list of hosts found via queries to system.local and system.peers
func (r *ringDescriber) GetHostsFromSystem() ([]*HostInfo, string, error) {
r.mu.Lock()
defer r.mu.Unlock()
if r.control == nil {
return r.prevHosts, r.prevPartitioner, errNoControl
}
ch := r.control.getConn()
localHost, err := r.getLocalHostInfo(ch.conn)
if err != nil {
return r.prevHosts, r.prevPartitioner, err
}
peerHosts, err := r.getClusterPeerInfo(localHost, ch.conn)
if err != nil {
return r.prevHosts, r.prevPartitioner, err
}
var hosts []*HostInfo
if !isZeroToken(localHost) {
hosts = []*HostInfo{localHost}
}
hosts = append(hosts, peerHosts...)
var partitioner string
if len(hosts) > 0 {
partitioner = hosts[0].Partitioner()
}
r.prevHosts = hosts
r.prevPartitioner = partitioner
return hosts, partitioner, nil
}
// Given an ip/port return HostInfo for the specified ip/port
func (r *ringDescriber) getHostInfo(hostID UUID) (*HostInfo, error) {
var host *HostInfo
for _, table := range []string{"system.peers", "system.local"} {
ch := r.control.getConn()
var iter *Iter
if ch.host.HostID() == hostID.String() {
host = ch.host
iter = nil
}
if table == "system.peers" {
if ch.conn.getIsSchemaV2() {
iter = ch.conn.querySystem(context.TODO(), qrySystemPeersV2)
} else {
iter = ch.conn.querySystem(context.TODO(), qrySystemPeers)
}
} else {
iter = ch.conn.query(context.TODO(), fmt.Sprintf("SELECT * FROM %s", table))
}
if iter != nil {
rows, err := iter.SliceMap()
if err != nil {
return nil, err
}
for _, row := range rows {
h, err := hostInfoFromMap(row, &HostInfo{port: r.cfg.Port}, r.cfg.translateAddressPort)
if err != nil {
return nil, err
}
if h.HostID() == hostID.String() {
host = h
break
}
}
}
}
if host == nil {
return nil, errors.New("unable to fetch host info: invalid control connection")
} else if host.invalidConnectAddr() {
return nil, fmt.Errorf("host ConnectAddress invalid ip=%v: %v", host.connectAddress, host)
}
return host, nil
}
func (r *ringDescriber) getHostByIP(ip string) (*HostInfo, bool) {
r.mu.RLock()
defer r.mu.RUnlock()
hi, ok := r.hostIPToUUID[ip]
return r.hosts[hi], ok
}
func (r *ringDescriber) getHost(hostID string) *HostInfo {
r.mu.RLock()
host := r.hosts[hostID]
r.mu.RUnlock()
return host
}
func (r *ringDescriber) getHostsList() []*HostInfo {
r.mu.RLock()
hosts := make([]*HostInfo, 0, len(r.hosts))
for _, host := range r.hosts {
hosts = append(hosts, host)
}
r.mu.RUnlock()
return hosts
}
func (r *ringDescriber) getHostsMap() map[string]*HostInfo {
r.mu.RLock()
hosts := make(map[string]*HostInfo, len(r.hosts))
for k, v := range r.hosts {
hosts[k] = v
}
r.mu.RUnlock()
return hosts
}
func (r *ringDescriber) addOrUpdate(host *HostInfo) *HostInfo {
if existingHost, ok := r.addHostIfMissing(host); ok {
existingHost.update(host)
host = existingHost
}
return host
}
func (r *ringDescriber) addHostIfMissing(host *HostInfo) (*HostInfo, bool) {
if host.invalidConnectAddr() {
panic(fmt.Sprintf("invalid host: %v", host))
}
hostID := host.HostID()
r.mu.Lock()
if r.hosts == nil {
r.hosts = make(map[string]*HostInfo)
}
if r.hostIPToUUID == nil {
r.hostIPToUUID = make(map[string]string)
}
existing, ok := r.hosts[hostID]
if !ok {
r.hosts[hostID] = host
r.hostIPToUUID[host.nodeToNodeAddress().String()] = hostID
existing = host
}
r.mu.Unlock()
return existing, ok
}
func (r *ringDescriber) removeHost(hostID string) bool {
r.mu.Lock()
if r.hosts == nil {
r.hosts = make(map[string]*HostInfo)
}
if r.hostIPToUUID == nil {
r.hostIPToUUID = make(map[string]string)
}
h, ok := r.hosts[hostID]
if ok {
delete(r.hostIPToUUID, h.nodeToNodeAddress().String())
}
delete(r.hosts, hostID)
r.mu.Unlock()
return ok
}

874
vendor/github.com/gocql/gocql/scylla.go generated vendored Normal file
View File

@@ -0,0 +1,874 @@
package gocql
import (
"context"
"crypto/tls"
"errors"
"fmt"
"math"
"net"
"strconv"
"strings"
"sync/atomic"
"syscall"
"time"
)
// scyllaSupported represents Scylla connection options as sent in SUPPORTED
// frame.
// FIXME: Should also follow `cqlProtocolExtension` interface.
type scyllaSupported struct {
shard int
nrShards int
msbIgnore uint64
partitioner string
shardingAlgorithm string
shardAwarePort uint16
shardAwarePortSSL uint16
lwtFlagMask int
}
// CQL Protocol extension interface for Scylla.
// Each extension is identified by a name and defines a way to serialize itself
// in STARTUP message payload.
type cqlProtocolExtension interface {
name() string
serialize() map[string]string
}
func findCQLProtoExtByName(exts []cqlProtocolExtension, name string) cqlProtocolExtension {
for i := range exts {
if exts[i].name() == name {
return exts[i]
}
}
return nil
}
// Top-level keys used for serialization/deserialization of CQL protocol
// extensions in SUPPORTED/STARTUP messages.
// Each key identifies a single extension.
const (
lwtAddMetadataMarkKey = "SCYLLA_LWT_ADD_METADATA_MARK"
rateLimitError = "SCYLLA_RATE_LIMIT_ERROR"
tabletsRoutingV1 = "TABLETS_ROUTING_V1"
)
// "tabletsRoutingV1" CQL Protocol Extension.
// This extension, if enabled (properly negotiated), allows Scylla server
// to send a tablet information in `custom_payload`.
//
// Implements cqlProtocolExtension interface.
type tabletsRoutingV1Ext struct {
}
var _ cqlProtocolExtension = &tabletsRoutingV1Ext{}
// Factory function to deserialize and create an `tabletsRoutingV1Ext` instance
// from SUPPORTED message payload.
func newTabletsRoutingV1Ext(supported map[string][]string) *tabletsRoutingV1Ext {
if _, found := supported[tabletsRoutingV1]; found {
return &tabletsRoutingV1Ext{}
}
return nil
}
func (ext *tabletsRoutingV1Ext) serialize() map[string]string {
return map[string]string{
tabletsRoutingV1: "",
}
}
func (ext *tabletsRoutingV1Ext) name() string {
return tabletsRoutingV1
}
// "Rate limit" CQL Protocol Extension.
// This extension, if enabled (properly negotiated), allows Scylla server
// to send a special kind of error.
//
// Implements cqlProtocolExtension interface.
type rateLimitExt struct {
rateLimitErrorCode int
}
var _ cqlProtocolExtension = &rateLimitExt{}
// Factory function to deserialize and create an `rateLimitExt` instance
// from SUPPORTED message payload.
func newRateLimitExt(supported map[string][]string) *rateLimitExt {
const rateLimitErrorCode = "ERROR_CODE"
if v, found := supported[rateLimitError]; found {
for i := range v {
splitVal := strings.Split(v[i], "=")
if splitVal[0] == rateLimitErrorCode {
var (
err error
errorCode int
)
if errorCode, err = strconv.Atoi(splitVal[1]); err != nil {
if gocqlDebug {
Logger.Printf("scylla: failed to parse %s value %v: %s", rateLimitErrorCode, splitVal[1], err)
return nil
}
}
return &rateLimitExt{
rateLimitErrorCode: errorCode,
}
}
}
}
return nil
}
func (ext *rateLimitExt) serialize() map[string]string {
return map[string]string{
rateLimitError: "",
}
}
func (ext *rateLimitExt) name() string {
return rateLimitError
}
// "LWT prepared statements metadata mark" CQL Protocol Extension.
// This extension, if enabled (properly negotiated), allows Scylla server
// to set a special bit in prepared statements metadata, which would indicate
// whether the statement at hand is LWT statement or not.
//
// This is further used to consistently choose primary replicas in a predefined
// order for these queries, which can reduce contention over hot keys and thus
// increase LWT performance.
//
// Implements cqlProtocolExtension interface.
type lwtAddMetadataMarkExt struct {
lwtOptMetaBitMask int
}
var _ cqlProtocolExtension = &lwtAddMetadataMarkExt{}
// Factory function to deserialize and create an `lwtAddMetadataMarkExt` instance
// from SUPPORTED message payload.
func newLwtAddMetaMarkExt(supported map[string][]string) *lwtAddMetadataMarkExt {
const lwtOptMetaBitMaskKey = "LWT_OPTIMIZATION_META_BIT_MASK"
if v, found := supported[lwtAddMetadataMarkKey]; found {
for i := range v {
splitVal := strings.Split(v[i], "=")
if splitVal[0] == lwtOptMetaBitMaskKey {
var (
err error
bitMask int
)
if bitMask, err = strconv.Atoi(splitVal[1]); err != nil {
if gocqlDebug {
Logger.Printf("scylla: failed to parse %s value %v: %s", lwtOptMetaBitMaskKey, splitVal[1], err)
return nil
}
}
return &lwtAddMetadataMarkExt{
lwtOptMetaBitMask: bitMask,
}
}
}
}
return nil
}
func (ext *lwtAddMetadataMarkExt) serialize() map[string]string {
return map[string]string{
lwtAddMetadataMarkKey: fmt.Sprintf("LWT_OPTIMIZATION_META_BIT_MASK=%d", ext.lwtOptMetaBitMask),
}
}
func (ext *lwtAddMetadataMarkExt) name() string {
return lwtAddMetadataMarkKey
}
func parseSupported(supported map[string][]string) scyllaSupported {
const (
scyllaShard = "SCYLLA_SHARD"
scyllaNrShards = "SCYLLA_NR_SHARDS"
scyllaPartitioner = "SCYLLA_PARTITIONER"
scyllaShardingAlgorithm = "SCYLLA_SHARDING_ALGORITHM"
scyllaShardingIgnoreMSB = "SCYLLA_SHARDING_IGNORE_MSB"
scyllaShardAwarePort = "SCYLLA_SHARD_AWARE_PORT"
scyllaShardAwarePortSSL = "SCYLLA_SHARD_AWARE_PORT_SSL"
)
var (
si scyllaSupported
err error
)
if s, ok := supported[scyllaShard]; ok {
if si.shard, err = strconv.Atoi(s[0]); err != nil {
if gocqlDebug {
Logger.Printf("scylla: failed to parse %s value %v: %s", scyllaShard, s, err)
}
}
}
if s, ok := supported[scyllaNrShards]; ok {
if si.nrShards, err = strconv.Atoi(s[0]); err != nil {
if gocqlDebug {
Logger.Printf("scylla: failed to parse %s value %v: %s", scyllaNrShards, s, err)
}
}
}
if s, ok := supported[scyllaShardingIgnoreMSB]; ok {
if si.msbIgnore, err = strconv.ParseUint(s[0], 10, 64); err != nil {
if gocqlDebug {
Logger.Printf("scylla: failed to parse %s value %v: %s", scyllaShardingIgnoreMSB, s, err)
}
}
}
if s, ok := supported[scyllaPartitioner]; ok {
si.partitioner = s[0]
}
if s, ok := supported[scyllaShardingAlgorithm]; ok {
si.shardingAlgorithm = s[0]
}
if s, ok := supported[scyllaShardAwarePort]; ok {
if shardAwarePort, err := strconv.ParseUint(s[0], 10, 16); err != nil {
if gocqlDebug {
Logger.Printf("scylla: failed to parse %s value %v: %s", scyllaShardAwarePort, s, err)
}
} else {
si.shardAwarePort = uint16(shardAwarePort)
}
}
if s, ok := supported[scyllaShardAwarePortSSL]; ok {
if shardAwarePortSSL, err := strconv.ParseUint(s[0], 10, 16); err != nil {
if gocqlDebug {
Logger.Printf("scylla: failed to parse %s value %v: %s", scyllaShardAwarePortSSL, s, err)
}
} else {
si.shardAwarePortSSL = uint16(shardAwarePortSSL)
}
}
if si.partitioner != "org.apache.cassandra.dht.Murmur3Partitioner" || si.shardingAlgorithm != "biased-token-round-robin" || si.nrShards == 0 || si.msbIgnore == 0 {
if gocqlDebug {
Logger.Printf("scylla: unsupported sharding configuration, partitioner=%s, algorithm=%s, no_shards=%d, msb_ignore=%d",
si.partitioner, si.shardingAlgorithm, si.nrShards, si.msbIgnore)
}
return scyllaSupported{}
}
return si
}
func parseCQLProtocolExtensions(supported map[string][]string) []cqlProtocolExtension {
exts := []cqlProtocolExtension{}
lwtExt := newLwtAddMetaMarkExt(supported)
if lwtExt != nil {
exts = append(exts, lwtExt)
}
rateLimitExt := newRateLimitExt(supported)
if rateLimitExt != nil {
exts = append(exts, rateLimitExt)
}
tabletsExt := newTabletsRoutingV1Ext(supported)
if tabletsExt != nil {
exts = append(exts, tabletsExt)
}
return exts
}
// isScyllaConn checks if conn is suitable for scyllaConnPicker.
func (conn *Conn) isScyllaConn() bool {
return conn.getScyllaSupported().nrShards != 0
}
// scyllaConnPicker is a specialised ConnPicker that selects connections based
// on token trying to get connection to a shard containing the given token.
// A list of excess connections is maintained to allow for lazy closing of
// connections to already opened shards. Keeping excess connections open helps
// reaching equilibrium faster since the likelihood of hitting the same shard
// decreases with the number of connections to the shard.
//
// scyllaConnPicker keeps track of the details about the shard-aware port.
// When used as a Dialer, it connects to the shard-aware port instead of the
// regular port (if the node supports it). For each subsequent connection
// it tries to make, the shard that it aims to connect to is chosen
// in a round-robin fashion.
type scyllaConnPicker struct {
address string
hostId string
shardAwareAddress string
conns []*Conn
excessConns []*Conn
nrConns int
nrShards int
msbIgnore uint64
pos uint64
lastAttemptedShard int
shardAwarePortDisabled bool
// Used to disable new connections to the shard-aware port temporarily
disableShardAwarePortUntil *atomic.Value
}
func newScyllaConnPicker(conn *Conn) *scyllaConnPicker {
addr := conn.Address()
hostId := conn.host.hostId
if conn.scyllaSupported.nrShards == 0 {
panic(fmt.Sprintf("scylla: %s not a sharded connection", addr))
}
if gocqlDebug {
Logger.Printf("scylla: %s new conn picker sharding options %+v", addr, conn.scyllaSupported)
}
var shardAwarePort uint16
if conn.session.connCfg.tlsConfig != nil {
shardAwarePort = conn.scyllaSupported.shardAwarePortSSL
} else {
shardAwarePort = conn.scyllaSupported.shardAwarePort
}
var shardAwareAddress string
if shardAwarePort != 0 {
tIP, tPort := conn.session.cfg.translateAddressPort(conn.host.UntranslatedConnectAddress(), int(shardAwarePort))
shardAwareAddress = net.JoinHostPort(tIP.String(), strconv.Itoa(tPort))
}
return &scyllaConnPicker{
address: addr,
hostId: hostId,
shardAwareAddress: shardAwareAddress,
nrShards: conn.scyllaSupported.nrShards,
msbIgnore: conn.scyllaSupported.msbIgnore,
lastAttemptedShard: 0,
shardAwarePortDisabled: conn.session.cfg.DisableShardAwarePort,
disableShardAwarePortUntil: new(atomic.Value),
}
}
func (p *scyllaConnPicker) Pick(t Token, qry ExecutableQuery) *Conn {
if len(p.conns) == 0 {
return nil
}
if t == nil {
return p.leastBusyConn()
}
mmt, ok := t.(int64Token)
// double check if that's murmur3 token
if !ok {
return nil
}
idx := -1
for _, conn := range p.conns {
if conn == nil {
continue
}
if qry != nil && conn.isTabletSupported() {
tablets := conn.session.getTablets()
// Search for tablets with Keyspace and Table from the Query
l, r := tablets.findTablets(qry.Keyspace(), qry.Table())
if l != -1 {
tablet := tablets.findTabletForToken(mmt, l, r)
for _, replica := range tablet.replicas {
if replica.hostId.String() == p.hostId {
idx = replica.shardId
}
}
}
}
break
}
if idx == -1 {
idx = p.shardOf(mmt)
}
if c := p.conns[idx]; c != nil {
// We have this shard's connection
// so let's give it to the caller.
// But only if it's not loaded too much and load is well distributed.
if qry != nil && qry.IsLWT() {
return c
}
return p.maybeReplaceWithLessBusyConnection(c)
}
return p.leastBusyConn()
}
func (p *scyllaConnPicker) maybeReplaceWithLessBusyConnection(c *Conn) *Conn {
if !isHeavyLoaded(c) {
return c
}
alternative := p.leastBusyConn()
if alternative == nil || alternative.AvailableStreams()*120 > c.AvailableStreams()*100 {
return c
} else {
return alternative
}
}
func isHeavyLoaded(c *Conn) bool {
return c.streams.NumStreams/2 > c.AvailableStreams()
}
func (p *scyllaConnPicker) leastBusyConn() *Conn {
var (
leastBusyConn *Conn
streamsAvailable int
)
idx := int(atomic.AddUint64(&p.pos, 1))
// find the conn which has the most available streams, this is racy
for i := range p.conns {
if conn := p.conns[(idx+i)%len(p.conns)]; conn != nil {
if streams := conn.AvailableStreams(); streams > streamsAvailable {
leastBusyConn = conn
streamsAvailable = streams
}
}
}
return leastBusyConn
}
func (p *scyllaConnPicker) shardOf(token int64Token) int {
shards := uint64(p.nrShards)
z := uint64(token+math.MinInt64) << p.msbIgnore
lo := z & 0xffffffff
hi := (z >> 32) & 0xffffffff
mul1 := lo * shards
mul2 := hi * shards
sum := (mul1 >> 32) + mul2
return int(sum >> 32)
}
func (p *scyllaConnPicker) Put(conn *Conn) {
var (
nrShards = conn.scyllaSupported.nrShards
shard = conn.scyllaSupported.shard
)
if nrShards == 0 {
panic(fmt.Sprintf("scylla: %s not a sharded connection", p.address))
}
if nrShards != len(p.conns) {
if nrShards != p.nrShards {
panic(fmt.Sprintf("scylla: %s invalid number of shards", p.address))
}
conns := p.conns
p.conns = make([]*Conn, nrShards, nrShards)
copy(p.conns, conns)
}
if c := p.conns[shard]; c != nil {
if conn.addr == p.shardAwareAddress {
// A connection made to the shard-aware port resulted in duplicate
// connection to the same shard being made. Because this is never
// intentional, it suggests that a NAT or AddressTranslator
// changes the source port along the way, therefore we can't trust
// the shard-aware port to return connection to the shard
// that we requested. Fall back to non-shard-aware port for some time.
Logger.Printf(
"scylla: %s connection to shard-aware address %s resulted in wrong shard being assigned; please check that you are not behind a NAT or AddressTranslater which changes source ports; falling back to non-shard-aware port for %v",
p.address,
p.shardAwareAddress,
scyllaShardAwarePortFallbackDuration,
)
until := time.Now().Add(scyllaShardAwarePortFallbackDuration)
p.disableShardAwarePortUntil.Store(until)
// Connections to shard-aware port do not influence how shards
// are chosen for the non-shard-aware port, therefore it can be
// closed immediately
closeConns(conn)
} else {
p.excessConns = append(p.excessConns, conn)
if gocqlDebug {
Logger.Printf("scylla: %s put shard %d excess connection total: %d missing: %d excess: %d", p.address, shard, p.nrConns, p.nrShards-p.nrConns, len(p.excessConns))
}
}
} else {
p.conns[shard] = conn
p.nrConns++
if gocqlDebug {
Logger.Printf("scylla: %s put shard %d connection total: %d missing: %d", p.address, shard, p.nrConns, p.nrShards-p.nrConns)
}
}
if p.shouldCloseExcessConns() {
p.closeExcessConns()
}
}
func (p *scyllaConnPicker) shouldCloseExcessConns() bool {
const maxExcessConnsFactor = 10
if p.nrConns >= p.nrShards {
return true
}
return len(p.excessConns) > maxExcessConnsFactor*p.nrShards
}
func (p *scyllaConnPicker) Remove(conn *Conn) {
shard := conn.scyllaSupported.shard
if conn.scyllaSupported.nrShards == 0 {
// It is possible for Remove to be called before the connection is added to the pool.
// Ignoring these connections here is safe.
if gocqlDebug {
Logger.Printf("scylla: %s has unknown sharding state, ignoring it", p.address)
}
return
}
if gocqlDebug {
Logger.Printf("scylla: %s remove shard %d connection", p.address, shard)
}
if p.conns[shard] != nil {
p.conns[shard] = nil
p.nrConns--
}
}
func (p *scyllaConnPicker) InFlight() int {
result := 0
for _, conn := range p.conns {
if conn != nil {
result = result + (conn.streams.InUse())
}
}
return result
}
func (p *scyllaConnPicker) Size() (int, int) {
return p.nrConns, p.nrShards - p.nrConns
}
func (p *scyllaConnPicker) Close() {
p.closeConns()
p.closeExcessConns()
}
func (p *scyllaConnPicker) closeConns() {
if len(p.conns) == 0 {
if gocqlDebug {
Logger.Printf("scylla: %s no connections to close", p.address)
}
return
}
conns := p.conns
p.conns = nil
p.nrConns = 0
if gocqlDebug {
Logger.Printf("scylla: %s closing %d connections", p.address, len(conns))
}
go closeConns(conns...)
}
func (p *scyllaConnPicker) closeExcessConns() {
if len(p.excessConns) == 0 {
if gocqlDebug {
Logger.Printf("scylla: %s no excess connections to close", p.address)
}
return
}
conns := p.excessConns
p.excessConns = nil
if gocqlDebug {
Logger.Printf("scylla: %s closing %d excess connections", p.address, len(conns))
}
go closeConns(conns...)
}
// Closing must be done outside of hostConnPool lock. If holding a lock
// a deadlock can occur when closing one of the connections returns error on close.
// See scylladb/gocql#53.
func closeConns(conns ...*Conn) {
for _, conn := range conns {
if conn != nil {
conn.Close()
}
}
}
// NextShard returns the shardID to connect to.
// nrShard specifies how many shards the host has.
// If nrShards is zero, the caller shouldn't use shard-aware port.
func (p *scyllaConnPicker) NextShard() (shardID, nrShards int) {
if p.shardAwarePortDisabled {
return 0, 0
}
disableUntil, _ := p.disableShardAwarePortUntil.Load().(time.Time)
if time.Now().Before(disableUntil) {
// There is suspicion that the shard-aware-port is not reachable
// or misconfigured, fall back to the non-shard-aware port
return 0, 0
}
// Find the shard without a connection
// It's important to start counting from 1 here because we want
// to consider the next shard after the previously attempted one
for i := 1; i <= p.nrShards; i++ {
shardID := (p.lastAttemptedShard + i) % p.nrShards
if p.conns == nil || p.conns[shardID] == nil {
p.lastAttemptedShard = shardID
return shardID, p.nrShards
}
}
// We did not find an unallocated shard
// We will dial the non-shard-aware port
return 0, 0
}
// ShardDialer is like HostDialer but is shard-aware.
// If the driver wants to connect to a specific shard, it will call DialShard,
// otherwise it will call DialHost.
type ShardDialer interface {
HostDialer
// DialShard establishes a connection to the specified shard ID out of nrShards.
// The returned connection must be directly usable for CQL protocol,
// specifically DialShard is responsible also for setting up the TLS session if needed.
DialShard(ctx context.Context, host *HostInfo, shardID, nrShards int) (*DialedHost, error)
}
// A dialer which dials a particular shard
type scyllaDialer struct {
dialer Dialer
logger StdLogger
tlsConfig *tls.Config
cfg *ClusterConfig
}
const scyllaShardAwarePortFallbackDuration time.Duration = 5 * time.Minute
func (sd *scyllaDialer) DialHost(ctx context.Context, host *HostInfo) (*DialedHost, error) {
ip := host.ConnectAddress()
port := host.Port()
if !validIpAddr(ip) {
return nil, fmt.Errorf("host missing connect ip address: %v", ip)
} else if port == 0 {
return nil, fmt.Errorf("host missing port: %v", port)
}
addr := host.HostnameAndPort()
conn, err := sd.dialer.DialContext(ctx, "tcp", addr)
if err != nil {
return nil, err
}
return WrapTLS(ctx, conn, addr, sd.tlsConfig)
}
func (sd *scyllaDialer) DialShard(ctx context.Context, host *HostInfo, shardID, nrShards int) (*DialedHost, error) {
ip := host.ConnectAddress()
port := host.Port()
if !validIpAddr(ip) {
return nil, fmt.Errorf("host missing connect ip address: %v", ip)
} else if port == 0 {
return nil, fmt.Errorf("host missing port: %v", port)
}
iter := newScyllaPortIterator(shardID, nrShards)
addr := host.HostnameAndPort()
var shardAwarePort uint16
if sd.tlsConfig != nil {
shardAwarePort = host.ScyllaShardAwarePortTLS()
} else {
shardAwarePort = host.ScyllaShardAwarePort()
}
var shardAwareAddress string
if shardAwarePort != 0 {
tIP, tPort := sd.cfg.translateAddressPort(host.UntranslatedConnectAddress(), int(shardAwarePort))
shardAwareAddress = net.JoinHostPort(tIP.String(), strconv.Itoa(tPort))
}
if gocqlDebug {
sd.logger.Printf("scylla: connecting to shard %d", shardID)
}
conn, err := sd.dialShardAware(ctx, addr, shardAwareAddress, iter)
if err != nil {
return nil, err
}
return WrapTLS(ctx, conn, addr, sd.tlsConfig)
}
func (sd *scyllaDialer) dialShardAware(ctx context.Context, addr, shardAwareAddr string, iter *scyllaPortIterator) (net.Conn, error) {
for {
port, ok := iter.Next()
if !ok {
// We exhausted ports to connect from. Try the non-shard-aware port.
return sd.dialer.DialContext(ctx, "tcp", addr)
}
ctxWithPort := context.WithValue(ctx, scyllaSourcePortCtx{}, port)
conn, err := sd.dialer.DialContext(ctxWithPort, "tcp", shardAwareAddr)
if isLocalAddrInUseErr(err) {
// This indicates that the source port is already in use
// We can immediately retry with another source port for this shard
continue
} else if err != nil {
conn, err := sd.dialer.DialContext(ctx, "tcp", addr)
if err == nil {
// We failed to connect to the shard-aware port, but succeeded
// in connecting to the non-shard-aware port. This might
// indicate that the shard-aware port is just not reachable,
// but we may also be unlucky and the node became reachable
// just after we tried the first connection.
// We can't avoid false positives here, so I'm putting it
// behind a debug flag.
if gocqlDebug {
sd.logger.Printf(
"scylla: %s couldn't connect to shard-aware address while the non-shard-aware address %s is available; this might be an issue with ",
addr,
shardAwareAddr,
)
}
}
return conn, err
}
return conn, err
}
}
// ErrScyllaSourcePortAlreadyInUse An error value which can returned from
// a custom dialer implementation to indicate that the requested source port
// to dial from is already in use
var ErrScyllaSourcePortAlreadyInUse = errors.New("scylla: source port is already in use")
func isLocalAddrInUseErr(err error) bool {
return errors.Is(err, syscall.EADDRINUSE) || errors.Is(err, ErrScyllaSourcePortAlreadyInUse)
}
// ScyllaShardAwareDialer wraps a net.Dialer, but uses a source port specified by gocql when connecting.
//
// Unlike in the case standard native transport ports, gocql can choose which shard will handle
// a new connection by connecting from a specific source port. If you are using your own net.Dialer
// in ClusterConfig, you can use ScyllaShardAwareDialer to "upgrade" it so that it connects
// from the source port chosen by gocql.
//
// Please note that ScyllaShardAwareDialer overwrites the LocalAddr field in order to choose
// the right source port for connection.
type ScyllaShardAwareDialer struct {
net.Dialer
}
func (d *ScyllaShardAwareDialer) DialContext(ctx context.Context, network, addr string) (conn net.Conn, err error) {
sourcePort := ScyllaGetSourcePort(ctx)
if sourcePort == 0 {
return d.Dialer.DialContext(ctx, network, addr)
}
dialerWithLocalAddr := d.Dialer
dialerWithLocalAddr.LocalAddr, err = net.ResolveTCPAddr(network, fmt.Sprintf(":%d", sourcePort))
if err != nil {
return nil, err
}
return dialerWithLocalAddr.DialContext(ctx, network, addr)
}
type scyllaPortIterator struct {
currentPort int
shardCount int
}
const (
scyllaPortBasedBalancingMin = 0x8000
scyllaPortBasedBalancingMax = 0xFFFF
)
func newScyllaPortIterator(shardID, shardCount int) *scyllaPortIterator {
if shardCount == 0 {
panic("shardCount cannot be 0")
}
// Find the smallest port p such that p >= min and p % shardCount == shardID
port := scyllaPortBasedBalancingMin - scyllaShardForSourcePort(scyllaPortBasedBalancingMin, shardCount) + shardID
if port < scyllaPortBasedBalancingMin {
port += shardCount
}
return &scyllaPortIterator{
currentPort: port,
shardCount: shardCount,
}
}
func (spi *scyllaPortIterator) Next() (uint16, bool) {
if spi == nil {
return 0, false
}
p := spi.currentPort
if p > scyllaPortBasedBalancingMax {
return 0, false
}
spi.currentPort += spi.shardCount
return uint16(p), true
}
func scyllaShardForSourcePort(sourcePort uint16, shardCount int) int {
return int(sourcePort) % shardCount
}
type scyllaSourcePortCtx struct{}
// ScyllaGetSourcePort returns the source port that should be used when connecting to a node.
//
// Unlike in the case standard native transport ports, gocql can choose which shard will handle
// a new connection at the shard-aware port by connecting from a specific source port. Therefore,
// if you are using a custom Dialer and your nodes expose shard-aware ports, your dialer should
// use the source port specified by gocql.
//
// If this function returns 0, then your dialer can use any source port.
//
// If you aren't using a custom dialer, gocql will use a default one which uses appropriate source port.
// If you are using net.Dialer, consider wrapping it in a gocql.ScyllaShardAwareDialer.
func ScyllaGetSourcePort(ctx context.Context) uint16 {
sourcePort, _ := ctx.Value(scyllaSourcePortCtx{}).(uint16)
return sourcePort
}
// Returns a partitioner specific to the table, or "nil"
// if the cluster-global partitioner should be used
func scyllaGetTablePartitioner(session *Session, keyspaceName, tableName string) (Partitioner, error) {
isCdc, err := scyllaIsCdcTable(session, keyspaceName, tableName)
if err != nil {
return nil, err
}
if isCdc {
return scyllaCDCPartitioner{}, nil
}
return nil, nil
}

93
vendor/github.com/gocql/gocql/scylla_cdc.go generated vendored Normal file
View File

@@ -0,0 +1,93 @@
package gocql
import (
"encoding/binary"
"math"
"strings"
)
// cdc partitioner
const (
scyllaCDCPartitionerName = "CDCPartitioner"
scyllaCDCPartitionerFullName = "com.scylladb.dht.CDCPartitioner"
scyllaCDCPartitionKeyLength = 16
scyllaCDCVersionMask = 0x0F
scyllaCDCMinSupportedVersion = 1
scyllaCDCMaxSupportedVersion = 1
scyllaCDCMinToken = int64Token(math.MinInt64)
scyllaCDCLogTableNameSuffix = "_scylla_cdc_log"
scyllaCDCExtensionName = "cdc"
)
type scyllaCDCPartitioner struct{}
var _ Partitioner = scyllaCDCPartitioner{}
func (p scyllaCDCPartitioner) Name() string {
return scyllaCDCPartitionerName
}
func (p scyllaCDCPartitioner) Hash(partitionKey []byte) Token {
if len(partitionKey) < 8 {
// The key is too short to extract any sensible token,
// so return the min token instead
if gocqlDebug {
Logger.Printf("scylla: cdc partition key too short: %d < 8", len(partitionKey))
}
return scyllaCDCMinToken
}
upperQword := binary.BigEndian.Uint64(partitionKey[0:])
if gocqlDebug {
// In debug mode, do some more checks
if len(partitionKey) != scyllaCDCPartitionKeyLength {
// The token has unrecognized format, but the first quadword
// should be the token value that we want
Logger.Printf("scylla: wrong size of cdc partition key: %d", len(partitionKey))
}
lowerQword := binary.BigEndian.Uint64(partitionKey[8:])
version := lowerQword & scyllaCDCVersionMask
if version < scyllaCDCMinSupportedVersion || version > scyllaCDCMaxSupportedVersion {
// We don't support this version yet,
// the token may be wrong
Logger.Printf(
"scylla: unsupported version: %d is not in range [%d, %d]",
version,
scyllaCDCMinSupportedVersion,
scyllaCDCMaxSupportedVersion,
)
}
}
return int64Token(upperQword)
}
func (p scyllaCDCPartitioner) ParseString(str string) Token {
return parseInt64Token(str)
}
func scyllaIsCdcTable(session *Session, keyspaceName, tableName string) (bool, error) {
if !strings.HasSuffix(tableName, scyllaCDCLogTableNameSuffix) {
// Not a CDC table, use the default partitioner
return false, nil
}
// Get the table metadata to see if it has the cdc partitioner set
keyspaceMeta, err := session.KeyspaceMetadata(keyspaceName)
if err != nil {
return false, err
}
tableMeta, ok := keyspaceMeta.Tables[tableName]
if !ok {
return false, ErrNoMetadata
}
return tableMeta.Options.Partitioner == scyllaCDCPartitionerFullName, nil
}

View File

@@ -0,0 +1,28 @@
package ascii
import (
"reflect"
)
func Marshal(value interface{}) ([]byte, error) {
switch v := value.(type) {
case nil:
return nil, nil
case string:
return EncString(v)
case *string:
return EncStringR(v)
case []byte:
return EncBytes(v)
case *[]byte:
return EncBytesR(v)
default:
// Custom types (type MyString string) can be serialized only via `reflect` package.
// Later, when generic-based serialization is introduced we can do that via generics.
rv := reflect.ValueOf(value)
if rv.Kind() != reflect.Ptr {
return EncReflect(rv)
}
return EncReflectR(rv)
}
}

View File

@@ -0,0 +1,61 @@
package ascii
import (
"fmt"
"reflect"
)
func EncString(v string) ([]byte, error) {
return encString(v), nil
}
func EncStringR(v *string) ([]byte, error) {
if v == nil {
return nil, nil
}
return encString(*v), nil
}
func EncBytes(v []byte) ([]byte, error) {
return v, nil
}
func EncBytesR(v *[]byte) ([]byte, error) {
if v == nil {
return nil, nil
}
return *v, nil
}
func EncReflect(v reflect.Value) ([]byte, error) {
switch v.Kind() {
case reflect.String:
return encString(v.String()), nil
case reflect.Slice:
if v.Type().Elem().Kind() != reflect.Uint8 {
return nil, fmt.Errorf("failed to marshal ascii: unsupported value type (%T)(%[1]v)", v.Interface())
}
return EncBytes(v.Bytes())
case reflect.Struct:
if v.Type().String() == "gocql.unsetColumn" {
return nil, nil
}
return nil, fmt.Errorf("failed to marshal ascii: unsupported value type (%T)(%[1]v)", v.Interface())
default:
return nil, fmt.Errorf("failed to marshal ascii: unsupported value type (%T)(%[1]v)", v.Interface())
}
}
func EncReflectR(v reflect.Value) ([]byte, error) {
if v.IsNil() {
return nil, nil
}
return EncReflect(v.Elem())
}
func encString(v string) []byte {
if v == "" {
return make([]byte, 0)
}
return []byte(v)
}

View File

@@ -0,0 +1,33 @@
package ascii
import (
"fmt"
"reflect"
)
func Unmarshal(data []byte, value interface{}) error {
switch v := value.(type) {
case nil:
return nil
case *string:
return DecString(data, v)
case **string:
return DecStringR(data, v)
case *[]byte:
return DecBytes(data, v)
case **[]byte:
return DecBytesR(data, v)
default:
// Custom types (type MyString string) can be deserialized only via `reflect` package.
// Later, when generic-based serialization is introduced we can do that via generics.
rv := reflect.ValueOf(value)
rt := rv.Type()
if rt.Kind() != reflect.Ptr {
return fmt.Errorf("failed to unmarshal ascii: unsupported value type (%T)(%[1]v)", v)
}
if rt.Elem().Kind() != reflect.Ptr {
return DecReflect(data, rv)
}
return DecReflectR(data, rv)
}
}

View File

@@ -0,0 +1,166 @@
package ascii
import (
"fmt"
"reflect"
)
func errInvalidData(p []byte) error {
for i := range p {
if p[i] > 127 {
return fmt.Errorf("failed to unmarshal ascii: invalid charester %s", string(p[i]))
}
}
return nil
}
func errNilReference(v interface{}) error {
return fmt.Errorf("failed to unmarshal ascii: can not unmarshal into nil reference(%T)(%[1]v)", v)
}
func DecString(p []byte, v *string) error {
if v == nil {
return errNilReference(v)
}
*v = decString(p)
return errInvalidData(p)
}
func DecStringR(p []byte, v **string) error {
if v == nil {
return errNilReference(v)
}
*v = decStringR(p)
return errInvalidData(p)
}
func DecBytes(p []byte, v *[]byte) error {
if v == nil {
return errNilReference(v)
}
*v = decBytes(p)
return errInvalidData(p)
}
func DecBytesR(p []byte, v **[]byte) error {
if v == nil {
return errNilReference(v)
}
*v = decBytesR(p)
return errInvalidData(p)
}
func DecReflect(p []byte, v reflect.Value) error {
if v.IsNil() {
return errNilReference(v)
}
switch v = v.Elem(); v.Kind() {
case reflect.String:
v.SetString(decString(p))
case reflect.Slice:
if v.Type().Elem().Kind() != reflect.Uint8 {
return fmt.Errorf("failed to marshal ascii: unsupported value type (%T)(%[1]v)", v.Interface())
}
v.SetBytes(decBytes(p))
default:
return fmt.Errorf("failed to unmarshal ascii: unsupported value type (%T)(%[1]v)", v.Interface())
}
return errInvalidData(p)
}
func DecReflectR(p []byte, v reflect.Value) error {
if v.IsNil() {
return errNilReference(v)
}
switch ev := v.Type().Elem().Elem(); ev.Kind() {
case reflect.String:
return decReflectStringR(p, v)
case reflect.Slice:
if ev.Elem().Kind() != reflect.Uint8 {
return fmt.Errorf("failed to marshal ascii: unsupported value type (%T)(%[1]v)", v.Interface())
}
return decReflectBytesR(p, v)
default:
return fmt.Errorf("failed to unmarshal ascii: unsupported value type (%T)(%[1]v)", v.Interface())
}
}
func decReflectStringR(p []byte, v reflect.Value) error {
if len(p) == 0 {
if p == nil {
v.Elem().Set(reflect.Zero(v.Elem().Type()))
} else {
v.Elem().Set(reflect.New(v.Type().Elem().Elem()))
}
return nil
}
val := reflect.New(v.Type().Elem().Elem())
val.Elem().SetString(string(p))
v.Elem().Set(val)
return errInvalidData(p)
}
func decReflectBytesR(p []byte, v reflect.Value) error {
if len(p) == 0 {
if p == nil {
v.Elem().Set(reflect.Zero(v.Elem().Type()))
} else {
val := reflect.New(v.Type().Elem().Elem())
val.Elem().SetBytes(make([]byte, 0))
v.Elem().Set(val)
}
return nil
}
tmp := make([]byte, len(p))
copy(tmp, p)
val := reflect.New(v.Type().Elem().Elem())
val.Elem().SetBytes(tmp)
v.Elem().Set(val)
return errInvalidData(p)
}
func decString(p []byte) string {
if len(p) == 0 {
return ""
}
return string(p)
}
func decStringR(p []byte) *string {
if len(p) == 0 {
if p == nil {
return nil
}
return new(string)
}
tmp := string(p)
return &tmp
}
func decBytes(p []byte) []byte {
if len(p) == 0 {
if p == nil {
return nil
}
return make([]byte, 0)
}
tmp := make([]byte, len(p))
copy(tmp, p)
return tmp
}
func decBytesR(p []byte) *[]byte {
if len(p) == 0 {
if p == nil {
return nil
}
tmp := make([]byte, 0)
return &tmp
}
tmp := make([]byte, len(p))
copy(tmp, p)
return &tmp
}

View File

@@ -0,0 +1,74 @@
package bigint
import (
"math/big"
"reflect"
)
func Marshal(value interface{}) ([]byte, error) {
switch v := value.(type) {
case nil:
return nil, nil
case int8:
return EncInt8(v)
case int16:
return EncInt16(v)
case int32:
return EncInt32(v)
case int64:
return EncInt64(v)
case int:
return EncInt(v)
case uint8:
return EncUint8(v)
case uint16:
return EncUint16(v)
case uint32:
return EncUint32(v)
case uint64:
return EncUint64(v)
case uint:
return EncUint(v)
case big.Int:
return EncBigInt(v)
case string:
return EncString(v)
case *int8:
return EncInt8R(v)
case *int16:
return EncInt16R(v)
case *int32:
return EncInt32R(v)
case *int64:
return EncInt64R(v)
case *int:
return EncIntR(v)
case *uint8:
return EncUint8R(v)
case *uint16:
return EncUint16R(v)
case *uint32:
return EncUint32R(v)
case *uint64:
return EncUint64R(v)
case *uint:
return EncUintR(v)
case *big.Int:
return EncBigIntR(v)
case *string:
return EncStringR(v)
default:
// Custom types (type MyInt int) can be serialized only via `reflect` package.
// Later, when generic-based serialization is introduced we can do that via generics.
rv := reflect.TypeOf(value)
if rv.Kind() != reflect.Ptr {
return EncReflect(reflect.ValueOf(v))
}
return EncReflectR(reflect.ValueOf(v))
}
}

View File

@@ -0,0 +1,206 @@
package bigint
import (
"fmt"
"math"
"math/big"
"reflect"
"strconv"
)
var (
maxBigInt = big.NewInt(math.MaxInt64)
minBigInt = big.NewInt(math.MinInt64)
)
func EncInt8(v int8) ([]byte, error) {
if v < 0 {
return []byte{255, 255, 255, 255, 255, 255, 255, byte(v)}, nil
}
return []byte{0, 0, 0, 0, 0, 0, 0, byte(v)}, nil
}
func EncInt8R(v *int8) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncInt8(*v)
}
func EncInt16(v int16) ([]byte, error) {
if v < 0 {
return []byte{255, 255, 255, 255, 255, 255, byte(v >> 8), byte(v)}, nil
}
return []byte{0, 0, 0, 0, 0, 0, byte(v >> 8), byte(v)}, nil
}
func EncInt16R(v *int16) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncInt16(*v)
}
func EncInt32(v int32) ([]byte, error) {
if v < 0 {
return []byte{255, 255, 255, 255, byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}, nil
}
return []byte{0, 0, 0, 0, byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}, nil
}
func EncInt32R(v *int32) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncInt32(*v)
}
func EncInt64(v int64) ([]byte, error) {
return encInt64(v), nil
}
func EncInt64R(v *int64) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncInt64(*v)
}
func EncInt(v int) ([]byte, error) {
return []byte{byte(v >> 56), byte(v >> 48), byte(v >> 40), byte(v >> 32), byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}, nil
}
func EncIntR(v *int) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncInt(*v)
}
func EncUint8(v uint8) ([]byte, error) {
return []byte{0, 0, 0, 0, 0, 0, 0, v}, nil
}
func EncUint8R(v *uint8) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncUint8(*v)
}
func EncUint16(v uint16) ([]byte, error) {
return []byte{0, 0, 0, 0, 0, 0, byte(v >> 8), byte(v)}, nil
}
func EncUint16R(v *uint16) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncUint16(*v)
}
func EncUint32(v uint32) ([]byte, error) {
return []byte{0, 0, 0, 0, byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}, nil
}
func EncUint32R(v *uint32) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncUint32(*v)
}
func EncUint64(v uint64) ([]byte, error) {
return []byte{byte(v >> 56), byte(v >> 48), byte(v >> 40), byte(v >> 32), byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}, nil
}
func EncUint64R(v *uint64) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncUint64(*v)
}
func EncUint(v uint) ([]byte, error) {
return []byte{byte(v >> 56), byte(v >> 48), byte(v >> 40), byte(v >> 32), byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}, nil
}
func EncUintR(v *uint) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncUint(*v)
}
func EncBigInt(v big.Int) ([]byte, error) {
if v.Cmp(maxBigInt) == 1 || v.Cmp(minBigInt) == -1 {
return nil, fmt.Errorf("failed to marshal bigint: value (%T)(%s) out of range", v, v.String())
}
return encInt64(v.Int64()), nil
}
func EncBigIntR(v *big.Int) ([]byte, error) {
if v == nil {
return nil, nil
}
if v.Cmp(maxBigInt) == 1 || v.Cmp(minBigInt) == -1 {
return nil, fmt.Errorf("failed to marshal bigint: value (%T)(%s) out of range", v, v.String())
}
return encInt64(v.Int64()), nil
}
func EncString(v string) ([]byte, error) {
if v == "" {
return nil, nil
}
n, err := strconv.ParseInt(v, 10, 64)
if err != nil {
return nil, fmt.Errorf("failed to marshal bigint: can not marshal (%T)(%[1]v) %s", v, err)
}
return encInt64(n), nil
}
func EncStringR(v *string) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncString(*v)
}
func EncReflect(v reflect.Value) ([]byte, error) {
switch v.Kind() {
case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8:
return EncInt64(v.Int())
case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8:
return EncUint64(v.Uint())
case reflect.String:
val := v.String()
if val == "" {
return nil, nil
}
n, err := strconv.ParseInt(val, 10, 64)
if err != nil {
return nil, fmt.Errorf("failed to marshal bigint: can not marshal (%T)(%[1]v) %s", v.Interface(), err)
}
return encInt64(n), nil
case reflect.Struct:
if v.Type().String() == "gocql.unsetColumn" {
return nil, nil
}
return nil, fmt.Errorf("failed to marshal bigint: unsupported value type (%T)(%[1]v)", v.Interface())
default:
return nil, fmt.Errorf("failed to marshal bigint: unsupported value type (%T)(%[1]v)", v.Interface())
}
}
func EncReflectR(v reflect.Value) ([]byte, error) {
if v.IsNil() {
return nil, nil
}
return EncReflect(v.Elem())
}
func encInt64(v int64) []byte {
return []byte{byte(v >> 56), byte(v >> 48), byte(v >> 40), byte(v >> 32), byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}
}

View File

@@ -0,0 +1,81 @@
package bigint
import (
"fmt"
"math/big"
"reflect"
)
func Unmarshal(data []byte, value interface{}) error {
switch v := value.(type) {
case nil:
return nil
case *int8:
return DecInt8(data, v)
case *int16:
return DecInt16(data, v)
case *int32:
return DecInt32(data, v)
case *int64:
return DecInt64(data, v)
case *int:
return DecInt(data, v)
case *uint8:
return DecUint8(data, v)
case *uint16:
return DecUint16(data, v)
case *uint32:
return DecUint32(data, v)
case *uint64:
return DecUint64(data, v)
case *uint:
return DecUint(data, v)
case *big.Int:
return DecBigInt(data, v)
case *string:
return DecString(data, v)
case **int8:
return DecInt8R(data, v)
case **int16:
return DecInt16R(data, v)
case **int32:
return DecInt32R(data, v)
case **int64:
return DecInt64R(data, v)
case **int:
return DecIntR(data, v)
case **uint8:
return DecUint8R(data, v)
case **uint16:
return DecUint16R(data, v)
case **uint32:
return DecUint32R(data, v)
case **uint64:
return DecUint64R(data, v)
case **uint:
return DecUintR(data, v)
case **big.Int:
return DecBigIntR(data, v)
case **string:
return DecStringR(data, v)
default:
// Custom types (type MyInt int) can be deserialized only via `reflect` package.
// Later, when generic-based serialization is introduced we can do that via generics.
rv := reflect.ValueOf(value)
rt := rv.Type()
if rt.Kind() != reflect.Ptr {
return fmt.Errorf("failed to unmarshal bigint: unsupported value type (%T)(%[1]v)", value)
}
if rt.Elem().Kind() != reflect.Ptr {
return DecReflect(data, rv)
}
return DecReflectR(data, rv)
}
}

View File

@@ -0,0 +1,841 @@
package bigint
import (
"fmt"
"math"
"math/big"
"reflect"
"strconv"
)
var errWrongDataLen = fmt.Errorf("failed to unmarshal bigint: the length of the data should be 0 or 8")
func errNilReference(v interface{}) error {
return fmt.Errorf("failed to unmarshal bigint: can not unmarshal into nil reference (%T)(%[1]v))", v)
}
func DecInt8(p []byte, v *int8) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
*v = 0
case 8:
val := decInt64(p)
if val > math.MaxInt8 || val < math.MinInt8 {
return fmt.Errorf("failed to unmarshal bigint: to unmarshal into int8, the data should be in the int8 range")
}
*v = int8(val)
default:
return errWrongDataLen
}
return nil
}
func DecInt8R(p []byte, v **int8) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
if p == nil {
*v = nil
} else {
*v = new(int8)
}
case 8:
val := decInt64(p)
if val > math.MaxInt8 || val < math.MinInt8 {
return fmt.Errorf("failed to unmarshal bigint: to unmarshal into int8, the data should be in the int8 range")
}
tmp := int8(val)
*v = &tmp
default:
return errWrongDataLen
}
return nil
}
func DecInt16(p []byte, v *int16) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
*v = 0
case 8:
val := decInt64(p)
if val > math.MaxInt16 || val < math.MinInt16 {
return fmt.Errorf("failed to unmarshal bigint: to unmarshal into int16, the data should be in the int16 range")
}
*v = int16(val)
default:
return errWrongDataLen
}
return nil
}
func DecInt16R(p []byte, v **int16) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
if p == nil {
*v = nil
} else {
*v = new(int16)
}
case 8:
val := decInt64(p)
if val > math.MaxInt16 || val < math.MinInt16 {
return fmt.Errorf("failed to unmarshal bigint: to unmarshal into int16, the data should be in the int16 range")
}
tmp := int16(val)
*v = &tmp
default:
return errWrongDataLen
}
return nil
}
func DecInt32(p []byte, v *int32) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
*v = 0
case 8:
val := decInt64(p)
if val > math.MaxInt32 || val < math.MinInt32 {
return fmt.Errorf("failed to unmarshal bigint: to unmarshal into int32, the data should be in the int32 range")
}
*v = int32(val)
default:
return errWrongDataLen
}
return nil
}
func DecInt32R(p []byte, v **int32) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
if p == nil {
*v = nil
} else {
*v = new(int32)
}
case 8:
val := decInt64(p)
if val > math.MaxInt32 || val < math.MinInt32 {
return fmt.Errorf("failed to unmarshal bigint: to unmarshal into int32, the data should be in the int32 range")
}
tmp := int32(val)
*v = &tmp
default:
return errWrongDataLen
}
return nil
}
func DecInt64(p []byte, v *int64) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
*v = 0
case 8:
*v = decInt64(p)
default:
return errWrongDataLen
}
return nil
}
func DecInt64R(p []byte, v **int64) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
if p == nil {
*v = nil
} else {
*v = new(int64)
}
case 8:
val := decInt64(p)
*v = &val
default:
return errWrongDataLen
}
return nil
}
func DecInt(p []byte, v *int) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
*v = 0
case 8:
*v = int(p[0])<<56 | int(p[1])<<48 | int(p[2])<<40 | int(p[3])<<32 | int(p[4])<<24 | int(p[5])<<16 | int(p[6])<<8 | int(p[7])
default:
return errWrongDataLen
}
return nil
}
func DecIntR(p []byte, v **int) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
if p == nil {
*v = nil
} else {
*v = new(int)
}
case 8:
val := int(p[0])<<56 | int(p[1])<<48 | int(p[2])<<40 | int(p[3])<<32 | int(p[4])<<24 | int(p[5])<<16 | int(p[6])<<8 | int(p[7])
*v = &val
default:
return errWrongDataLen
}
return nil
}
func DecUint8(p []byte, v *uint8) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
*v = 0
case 8:
if p[0] != 0 || p[1] != 0 || p[2] != 0 || p[3] != 0 || p[4] != 0 || p[5] != 0 || p[6] != 0 {
return fmt.Errorf("failed to unmarshal bigint: to unmarshal into uint8, the data should be in the uint8 range")
}
*v = p[7]
default:
return errWrongDataLen
}
return nil
}
func DecUint8R(p []byte, v **uint8) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
if p == nil {
*v = nil
} else {
*v = new(uint8)
}
case 8:
if p[0] != 0 || p[1] != 0 || p[2] != 0 || p[3] != 0 || p[4] != 0 || p[5] != 0 || p[6] != 0 {
return fmt.Errorf("failed to unmarshal bigint: to unmarshal into uint8, the data should be in the uint8 range")
}
val := p[7]
*v = &val
default:
return errWrongDataLen
}
return nil
}
func DecUint16(p []byte, v *uint16) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
*v = 0
case 8:
if p[0] != 0 || p[1] != 0 || p[2] != 0 || p[3] != 0 || p[4] != 0 || p[5] != 0 {
return fmt.Errorf("failed to unmarshal bigint: to unmarshal into uint16, the data should be in the uint16 range")
}
*v = uint16(p[6])<<8 | uint16(p[7])
default:
return errWrongDataLen
}
return nil
}
func DecUint16R(p []byte, v **uint16) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
if p == nil {
*v = nil
} else {
*v = new(uint16)
}
case 8:
if p[0] != 0 || p[1] != 0 || p[2] != 0 || p[3] != 0 || p[4] != 0 || p[5] != 0 {
return fmt.Errorf("failed to unmarshal bigint: to unmarshal into uint16, the data should be in the uint16 range")
}
val := uint16(p[6])<<8 | uint16(p[7])
*v = &val
default:
return errWrongDataLen
}
return nil
}
func DecUint32(p []byte, v *uint32) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
*v = 0
case 8:
if p[0] != 0 || p[1] != 0 || p[2] != 0 || p[3] != 0 {
return fmt.Errorf("failed to unmarshal bigint: to unmarshal into uint32, the data should be in the uint32 range")
}
*v = uint32(p[4])<<24 | uint32(p[5])<<16 | uint32(p[6])<<8 | uint32(p[7])
default:
return errWrongDataLen
}
return nil
}
func DecUint32R(p []byte, v **uint32) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
if p == nil {
*v = nil
} else {
*v = new(uint32)
}
case 8:
if p[0] != 0 || p[1] != 0 || p[2] != 0 || p[3] != 0 {
return fmt.Errorf("failed to unmarshal bigint: to unmarshal into uint32, the data should be in the uint32 range")
}
val := uint32(p[4])<<24 | uint32(p[5])<<16 | uint32(p[6])<<8 | uint32(p[7])
*v = &val
default:
return errWrongDataLen
}
return nil
}
func DecUint64(p []byte, v *uint64) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
*v = 0
case 8:
*v = decUint64(p)
default:
return errWrongDataLen
}
return nil
}
func DecUint64R(p []byte, v **uint64) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
if p == nil {
*v = nil
} else {
*v = new(uint64)
}
case 8:
val := decUint64(p)
*v = &val
default:
return errWrongDataLen
}
return nil
}
func DecUint(p []byte, v *uint) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
*v = 0
case 8:
*v = uint(p[0])<<56 | uint(p[1])<<48 | uint(p[2])<<40 | uint(p[3])<<32 | uint(p[4])<<24 | uint(p[5])<<16 | uint(p[6])<<8 | uint(p[7])
default:
return errWrongDataLen
}
return nil
}
func DecUintR(p []byte, v **uint) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
if p == nil {
*v = nil
} else {
*v = new(uint)
}
case 8:
val := uint(p[0])<<56 | uint(p[1])<<48 | uint(p[2])<<40 | uint(p[3])<<32 | uint(p[4])<<24 | uint(p[5])<<16 | uint(p[6])<<8 | uint(p[7])
*v = &val
default:
return errWrongDataLen
}
return nil
}
func DecString(p []byte, v *string) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
if p == nil {
*v = ""
} else {
*v = "0"
}
case 8:
*v = strconv.FormatInt(decInt64(p), 10)
default:
return errWrongDataLen
}
return nil
}
func DecStringR(p []byte, v **string) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
if p == nil {
*v = nil
} else {
val := "0"
*v = &val
}
case 8:
val := strconv.FormatInt(decInt64(p), 10)
*v = &val
default:
return errWrongDataLen
}
return nil
}
func DecBigInt(p []byte, v *big.Int) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
v.SetInt64(0)
case 8:
v.SetInt64(decInt64(p))
default:
return errWrongDataLen
}
return nil
}
func DecBigIntR(p []byte, v **big.Int) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
if p == nil {
*v = nil
} else {
*v = new(big.Int)
}
case 8:
*v = big.NewInt(decInt64(p))
default:
return errWrongDataLen
}
return nil
}
func DecReflect(p []byte, v reflect.Value) error {
if v.IsNil() {
return fmt.Errorf("failed to unmarshal bigint: can not unmarshal into nil reference (%T)(%[1]v))", v.Interface())
}
switch v = v.Elem(); v.Kind() {
case reflect.Int8:
return decReflectInt8(p, v)
case reflect.Int16:
return decReflectInt16(p, v)
case reflect.Int32:
return decReflectInt32(p, v)
case reflect.Int64, reflect.Int:
return decReflectInts(p, v)
case reflect.Uint8:
return decReflectUint8(p, v)
case reflect.Uint16:
return decReflectUint16(p, v)
case reflect.Uint32:
return decReflectUint32(p, v)
case reflect.Uint64, reflect.Uint:
return decReflectUints(p, v)
case reflect.String:
return decReflectString(p, v)
default:
return fmt.Errorf("failed to unmarshal bigint: unsupported value type (%T)(%[1]v)", v.Interface())
}
}
func decReflectInt8(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
v.SetInt(0)
case 8:
val := decInt64(p)
if val > math.MaxInt8 || val < math.MinInt8 {
return fmt.Errorf("failed to unmarshal bigint: to unmarshal into %T, the data should be in the int8 range", v.Interface())
}
v.SetInt(val)
default:
return errWrongDataLen
}
return nil
}
func decReflectInt16(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
v.SetInt(0)
case 8:
val := decInt64(p)
if val > math.MaxInt16 || val < math.MinInt16 {
return fmt.Errorf("failed to unmarshal bigint: to unmarshal into %T, the data should be in the int16 range", v.Interface())
}
v.SetInt(val)
default:
return errWrongDataLen
}
return nil
}
func decReflectInt32(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
v.SetInt(0)
case 8:
val := decInt64(p)
if val > math.MaxInt32 || val < math.MinInt32 {
return fmt.Errorf("failed to unmarshal bigint: to unmarshal into %T, the data should be in the int32 range", v.Interface())
}
v.SetInt(val)
default:
return errWrongDataLen
}
return nil
}
func decReflectInts(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
v.SetInt(0)
case 8:
v.SetInt(decInt64(p))
default:
return errWrongDataLen
}
return nil
}
func decReflectUint8(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
v.SetUint(0)
case 8:
if p[0] != 0 || p[1] != 0 || p[2] != 0 || p[3] != 0 || p[4] != 0 || p[5] != 0 || p[6] != 0 {
return fmt.Errorf("failed to unmarshal bigint: to unmarshal into %T, the data should be in the uint8 range", v.Interface())
}
v.SetUint(uint64(p[7]))
default:
return errWrongDataLen
}
return nil
}
func decReflectUint16(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
v.SetUint(0)
case 8:
if p[0] != 0 || p[1] != 0 || p[2] != 0 || p[3] != 0 || p[4] != 0 || p[5] != 0 {
return fmt.Errorf("failed to unmarshal bigint: to unmarshal into %T, the data should be in the uint16 range", v.Interface())
}
v.SetUint(uint64(p[6])<<8 | uint64(p[7]))
default:
return errWrongDataLen
}
return nil
}
func decReflectUint32(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
v.SetUint(0)
case 8:
if p[0] != 0 || p[1] != 0 || p[2] != 0 || p[3] != 0 {
return fmt.Errorf("failed to unmarshal bigint: to unmarshal into %T, the data should be in the uint32 range", v.Interface())
}
v.SetUint(uint64(p[4])<<24 | uint64(p[5])<<16 | uint64(p[6])<<8 | uint64(p[7]))
default:
return errWrongDataLen
}
return nil
}
func decReflectUints(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
v.SetUint(0)
case 8:
v.SetUint(decUint64(p))
default:
return errWrongDataLen
}
return nil
}
func decReflectString(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
if p == nil {
v.SetString("")
} else {
v.SetString("0")
}
case 8:
v.SetString(strconv.FormatInt(decInt64(p), 10))
default:
return errWrongDataLen
}
return nil
}
func DecReflectR(p []byte, v reflect.Value) error {
if v.IsNil() {
return fmt.Errorf("failed to unmarshal bigint: can not unmarshal into nil reference (%T)(%[1]v)", v.Interface())
}
switch v.Type().Elem().Elem().Kind() {
case reflect.Int8:
return decReflectInt8R(p, v)
case reflect.Int16:
return decReflectInt16R(p, v)
case reflect.Int32:
return decReflectInt32R(p, v)
case reflect.Int64, reflect.Int:
return decReflectIntsR(p, v)
case reflect.Uint8:
return decReflectUint8R(p, v)
case reflect.Uint16:
return decReflectUint16R(p, v)
case reflect.Uint32:
return decReflectUint32R(p, v)
case reflect.Uint64, reflect.Uint:
return decReflectUintsR(p, v)
case reflect.String:
return decReflectStringR(p, v)
default:
return fmt.Errorf("failed to unmarshal bigint: unsupported value type (%T)(%[1]v)", v.Interface())
}
}
func decReflectInt8R(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
v.Elem().Set(decReflectNullableR(p, v))
case 8:
val := decInt64(p)
if val > math.MaxInt8 || val < math.MinInt8 {
return fmt.Errorf("failed to unmarshal bigint: to unmarshal into %T, the data should be in the int8 range", v.Interface())
}
newVal := reflect.New(v.Type().Elem().Elem())
newVal.Elem().SetInt(val)
v.Elem().Set(newVal)
default:
return errWrongDataLen
}
return nil
}
func decReflectInt16R(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
v.Elem().Set(decReflectNullableR(p, v))
case 8:
val := decInt64(p)
if val > math.MaxInt16 || val < math.MinInt16 {
return fmt.Errorf("failed to unmarshal bigint: to unmarshal into %T, the data should be in the int16 range", v.Interface())
}
newVal := reflect.New(v.Type().Elem().Elem())
newVal.Elem().SetInt(val)
v.Elem().Set(newVal)
default:
return errWrongDataLen
}
return nil
}
func decReflectInt32R(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
v.Elem().Set(decReflectNullableR(p, v))
case 8:
val := decInt64(p)
if val > math.MaxInt32 || val < math.MinInt32 {
return fmt.Errorf("failed to unmarshal bigint: to unmarshal into %T, the data should be in the int32 range", v.Interface())
}
newVal := reflect.New(v.Type().Elem().Elem())
newVal.Elem().SetInt(val)
v.Elem().Set(newVal)
default:
return errWrongDataLen
}
return nil
}
func decReflectIntsR(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
v.Elem().Set(decReflectNullableR(p, v))
case 8:
val := reflect.New(v.Type().Elem().Elem())
val.Elem().SetInt(decInt64(p))
v.Elem().Set(val)
default:
return errWrongDataLen
}
return nil
}
func decReflectUint8R(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
v.Elem().Set(decReflectNullableR(p, v))
case 8:
newVal := reflect.New(v.Type().Elem().Elem())
if p[0] != 0 || p[1] != 0 || p[2] != 0 || p[3] != 0 || p[4] != 0 || p[5] != 0 || p[6] != 0 {
return fmt.Errorf("failed to unmarshal bigint: to unmarshal into %T, the data should be in the uint8 range", v.Interface())
}
newVal.Elem().SetUint(uint64(p[7]))
v.Elem().Set(newVal)
default:
return errWrongDataLen
}
return nil
}
func decReflectUint16R(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
v.Elem().Set(decReflectNullableR(p, v))
case 8:
newVal := reflect.New(v.Type().Elem().Elem())
if p[0] != 0 || p[1] != 0 || p[2] != 0 || p[3] != 0 || p[4] != 0 || p[5] != 0 {
return fmt.Errorf("failed to unmarshal bigint: to unmarshal into %T, the data should be in the uint16 range", v.Interface())
}
newVal.Elem().SetUint(uint64(p[6])<<8 | uint64(p[7]))
v.Elem().Set(newVal)
default:
return errWrongDataLen
}
return nil
}
func decReflectUint32R(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
v.Elem().Set(decReflectNullableR(p, v))
case 8:
newVal := reflect.New(v.Type().Elem().Elem())
if p[0] != 0 || p[1] != 0 || p[2] != 0 || p[3] != 0 {
return fmt.Errorf("failed to unmarshal bigint: to unmarshal into %T, the data should be in the uint32 range", v.Interface())
}
newVal.Elem().SetUint(uint64(p[4])<<24 | uint64(p[5])<<16 | uint64(p[6])<<8 | uint64(p[7]))
v.Elem().Set(newVal)
default:
return errWrongDataLen
}
return nil
}
func decReflectUintsR(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
v.Elem().Set(decReflectNullableR(p, v))
case 8:
val := reflect.New(v.Type().Elem().Elem())
val.Elem().SetUint(decUint64(p))
v.Elem().Set(val)
default:
return errWrongDataLen
}
return nil
}
func decReflectStringR(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
var val reflect.Value
if p == nil {
val = reflect.Zero(v.Type().Elem())
} else {
val = reflect.New(v.Type().Elem().Elem())
val.Elem().SetString("0")
}
v.Elem().Set(val)
case 8:
val := reflect.New(v.Type().Elem().Elem())
val.Elem().SetString(strconv.FormatInt(decInt64(p), 10))
v.Elem().Set(val)
default:
return errWrongDataLen
}
return nil
}
func decReflectNullableR(p []byte, v reflect.Value) reflect.Value {
if p == nil {
return reflect.Zero(v.Elem().Type())
}
return reflect.New(v.Type().Elem().Elem())
}
func decInt64(p []byte) int64 {
return int64(p[0])<<56 | int64(p[1])<<48 | int64(p[2])<<40 | int64(p[3])<<32 | int64(p[4])<<24 | int64(p[5])<<16 | int64(p[6])<<8 | int64(p[7])
}
func decUint64(p []byte) uint64 {
return uint64(p[0])<<56 | uint64(p[1])<<48 | uint64(p[2])<<40 | uint64(p[3])<<32 | uint64(p[4])<<24 | uint64(p[5])<<16 | uint64(p[6])<<8 | uint64(p[7])
}

View File

@@ -0,0 +1,28 @@
package blob
import (
"reflect"
)
func Marshal(value interface{}) ([]byte, error) {
switch v := value.(type) {
case nil:
return nil, nil
case string:
return EncString(v)
case *string:
return EncStringR(v)
case []byte:
return EncBytes(v)
case *[]byte:
return EncBytesR(v)
default:
// Custom types (type MyString string) can be serialized only via `reflect` package.
// Later, when generic-based serialization is introduced we can do that via generics.
rv := reflect.ValueOf(value)
if rv.Kind() != reflect.Ptr {
return EncReflect(rv)
}
return EncReflectR(rv)
}
}

View File

@@ -0,0 +1,61 @@
package blob
import (
"fmt"
"reflect"
)
func EncString(v string) ([]byte, error) {
return encString(v), nil
}
func EncStringR(v *string) ([]byte, error) {
if v == nil {
return nil, nil
}
return encString(*v), nil
}
func EncBytes(v []byte) ([]byte, error) {
return v, nil
}
func EncBytesR(v *[]byte) ([]byte, error) {
if v == nil {
return nil, nil
}
return *v, nil
}
func EncReflect(v reflect.Value) ([]byte, error) {
switch v.Kind() {
case reflect.String:
return encString(v.String()), nil
case reflect.Slice:
if v.Type().Elem().Kind() != reflect.Uint8 {
return nil, fmt.Errorf("failed to marshal blob: unsupported value type (%T)(%[1]v)", v.Interface())
}
return EncBytes(v.Bytes())
case reflect.Struct:
if v.Type().String() == "gocql.unsetColumn" {
return nil, nil
}
return nil, fmt.Errorf("failed to marshal blob: unsupported value type (%T)(%[1]v)", v.Interface())
default:
return nil, fmt.Errorf("failed to marshal blob: unsupported value type (%T)(%[1]v)", v.Interface())
}
}
func EncReflectR(v reflect.Value) ([]byte, error) {
if v.IsNil() {
return nil, nil
}
return EncReflect(v.Elem())
}
func encString(v string) []byte {
if v == "" {
return make([]byte, 0)
}
return []byte(v)
}

View File

@@ -0,0 +1,35 @@
package blob
import (
"fmt"
"reflect"
)
func Unmarshal(data []byte, value interface{}) error {
switch v := value.(type) {
case nil:
return nil
case *string:
return DecString(data, v)
case **string:
return DecStringR(data, v)
case *[]byte:
return DecBytes(data, v)
case **[]byte:
return DecBytesR(data, v)
case *interface{}:
return DecInterface(data, v)
default:
// Custom types (type MyString string) can be deserialized only via `reflect` package.
// Later, when generic-based serialization is introduced we can do that via generics.
rv := reflect.ValueOf(value)
rt := rv.Type()
if rt.Kind() != reflect.Ptr {
return fmt.Errorf("failed to unmarshal blob: unsupported value type (%T)(%[1]v)", v)
}
if rt.Elem().Kind() != reflect.Ptr {
return DecReflect(data, rv)
}
return DecReflectR(data, rv)
}
}

View File

@@ -0,0 +1,167 @@
package blob
import (
"fmt"
"reflect"
)
func errNilReference(v interface{}) error {
return fmt.Errorf("failed to unmarshal blob: can not unmarshal into nil reference(%T)(%[1]v)", v)
}
func DecString(p []byte, v *string) error {
if v == nil {
return errNilReference(v)
}
*v = decString(p)
return nil
}
func DecStringR(p []byte, v **string) error {
if v == nil {
return errNilReference(v)
}
*v = decStringR(p)
return nil
}
func DecBytes(p []byte, v *[]byte) error {
if v == nil {
return errNilReference(v)
}
*v = decBytes(p)
return nil
}
func DecBytesR(p []byte, v **[]byte) error {
if v == nil {
return errNilReference(v)
}
*v = decBytesR(p)
return nil
}
func DecInterface(p []byte, v *interface{}) error {
if v == nil {
return errNilReference(v)
}
*v = decBytes(p)
return nil
}
func DecReflect(p []byte, v reflect.Value) error {
if v.IsNil() {
return errNilReference(v)
}
switch v = v.Elem(); v.Kind() {
case reflect.String:
v.SetString(decString(p))
case reflect.Slice:
if v.Type().Elem().Kind() != reflect.Uint8 {
return fmt.Errorf("failed to marshal blob: unsupported value type (%T)(%[1]v)", v.Interface())
}
v.SetBytes(decBytes(p))
case reflect.Interface:
v.Set(reflect.ValueOf(decBytes(p)))
default:
return fmt.Errorf("failed to unmarshal blob: unsupported value type (%T)(%[1]v)", v.Interface())
}
return nil
}
func DecReflectR(p []byte, v reflect.Value) error {
if v.IsNil() {
return errNilReference(v)
}
switch ev := v.Type().Elem().Elem(); ev.Kind() {
case reflect.String:
return decReflectStringR(p, v)
case reflect.Slice:
if ev.Elem().Kind() != reflect.Uint8 {
return fmt.Errorf("failed to marshal blob: unsupported value type (%T)(%[1]v)", v.Interface())
}
return decReflectBytesR(p, v)
default:
return fmt.Errorf("failed to unmarshal blob: unsupported value type (%T)(%[1]v)", v.Interface())
}
}
func decReflectStringR(p []byte, v reflect.Value) error {
if len(p) == 0 {
if p == nil {
v.Elem().Set(reflect.Zero(v.Type().Elem()))
} else {
v.Elem().Set(reflect.New(v.Type().Elem().Elem()))
}
return nil
}
val := reflect.New(v.Type().Elem().Elem())
val.Elem().SetString(string(p))
v.Elem().Set(val)
return nil
}
func decReflectBytesR(p []byte, v reflect.Value) error {
if len(p) == 0 {
if p == nil {
v.Elem().Set(reflect.Zero(v.Elem().Type()))
} else {
val := reflect.New(v.Type().Elem().Elem())
val.Elem().SetBytes(make([]byte, 0))
v.Elem().Set(val)
}
return nil
}
tmp := make([]byte, len(p))
copy(tmp, p)
val := reflect.New(v.Type().Elem().Elem())
val.Elem().SetBytes(tmp)
v.Elem().Set(val)
return nil
}
func decString(p []byte) string {
if len(p) == 0 {
return ""
}
return string(p)
}
func decStringR(p []byte) *string {
if len(p) == 0 {
if p == nil {
return nil
}
return new(string)
}
tmp := string(p)
return &tmp
}
func decBytes(p []byte) []byte {
if len(p) == 0 {
if p == nil {
return nil
}
return make([]byte, 0)
}
tmp := make([]byte, len(p))
copy(tmp, p)
return tmp
}
func decBytesR(p []byte) *[]byte {
if len(p) == 0 {
if p == nil {
return nil
}
tmp := make([]byte, 0)
return &tmp
}
tmp := make([]byte, len(p))
copy(tmp, p)
return &tmp
}

View File

@@ -0,0 +1,24 @@
package boolean
import (
"reflect"
)
func Marshal(value interface{}) ([]byte, error) {
switch v := value.(type) {
case nil:
return nil, nil
case bool:
return EncBool(v)
case *bool:
return EncBoolR(v)
default:
// Custom types (type MyBool bool) can be serialized only via `reflect` package.
// Later, when generic-based serialization is introduced we can do that via generics.
rv := reflect.TypeOf(value)
if rv.Kind() != reflect.Ptr {
return EncReflect(reflect.ValueOf(v))
}
return EncReflectR(reflect.ValueOf(v))
}
}

View File

@@ -0,0 +1,45 @@
package boolean
import (
"fmt"
"reflect"
)
func EncBool(v bool) ([]byte, error) {
return encBool(v), nil
}
func EncBoolR(v *bool) ([]byte, error) {
if v == nil {
return nil, nil
}
return encBool(*v), nil
}
func EncReflect(v reflect.Value) ([]byte, error) {
switch v.Kind() {
case reflect.Bool:
return encBool(v.Bool()), nil
case reflect.Struct:
if v.Type().String() == "gocql.unsetColumn" {
return nil, nil
}
return nil, fmt.Errorf("failed to marshal boolean: unsupported value type (%T)(%[1]v)", v.Interface())
default:
return nil, fmt.Errorf("failed to marshal boolean: unsupported value type (%T)(%[1]v)", v.Interface())
}
}
func EncReflectR(v reflect.Value) ([]byte, error) {
if v.IsNil() {
return nil, nil
}
return EncReflect(v.Elem())
}
func encBool(v bool) []byte {
if v {
return []byte{1}
}
return []byte{0}
}

View File

@@ -0,0 +1,29 @@
package boolean
import (
"fmt"
"reflect"
)
func Unmarshal(data []byte, value interface{}) error {
switch v := value.(type) {
case nil:
return nil
case *bool:
return DecBool(data, v)
case **bool:
return DecBoolR(data, v)
default:
// Custom types (type MyBool bool) can be deserialized only via `reflect` package.
// Later, when generic-based serialization is introduced we can do that via generics.
rv := reflect.ValueOf(value)
rt := rv.Type()
if rt.Kind() != reflect.Ptr {
return fmt.Errorf("failed to unmarshal boolean: unsupported value type (%T)(%[1]v)", v)
}
if rt.Elem().Kind() != reflect.Ptr {
return DecReflect(data, rv)
}
return DecReflectR(data, rv)
}
}

View File

@@ -0,0 +1,108 @@
package boolean
import (
"fmt"
"reflect"
)
var errWrongDataLen = fmt.Errorf("failed to unmarshal boolean: the length of the data should be 0 or 1")
func errNilReference(v interface{}) error {
return fmt.Errorf("failed to unmarshal boolean: can not unmarshal into nil reference(%T)(%[1]v)", v)
}
func DecBool(p []byte, v *bool) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
*v = false
case 1:
*v = decBool(p)
default:
return errWrongDataLen
}
return nil
}
func DecBoolR(p []byte, v **bool) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
if p == nil {
*v = nil
} else {
*v = new(bool)
}
case 1:
val := decBool(p)
*v = &val
default:
return errWrongDataLen
}
return nil
}
func DecReflect(p []byte, v reflect.Value) error {
if v.IsNil() {
return errNilReference(v)
}
switch v = v.Elem(); v.Kind() {
case reflect.Bool:
return decReflectBool(p, v)
default:
return fmt.Errorf("failed to unmarshal boolean: unsupported value type (%T)(%[1]v)", v.Interface())
}
}
func DecReflectR(p []byte, v reflect.Value) error {
if v.IsNil() {
return errNilReference(v)
}
switch v.Type().Elem().Elem().Kind() {
case reflect.Bool:
return decReflectBoolR(p, v)
default:
return fmt.Errorf("failed to unmarshal boolean: unsupported value type (%T)(%[1]v)", v.Interface())
}
}
func decReflectBool(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
v.SetBool(false)
case 1:
v.SetBool(decBool(p))
default:
return errWrongDataLen
}
return nil
}
func decReflectBoolR(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
if p == nil {
v.Elem().Set(reflect.Zero(v.Type().Elem()))
} else {
val := reflect.New(v.Type().Elem().Elem())
v.Elem().Set(val)
}
case 1:
val := reflect.New(v.Type().Elem().Elem())
val.Elem().SetBool(decBool(p))
v.Elem().Set(val)
default:
return errWrongDataLen
}
return nil
}
func decBool(p []byte) bool {
return p[0] != 0
}

View File

@@ -0,0 +1,74 @@
package counter
import (
"math/big"
"reflect"
)
func Marshal(value interface{}) ([]byte, error) {
switch v := value.(type) {
case nil:
return nil, nil
case int8:
return EncInt8(v)
case int16:
return EncInt16(v)
case int32:
return EncInt32(v)
case int64:
return EncInt64(v)
case int:
return EncInt(v)
case uint8:
return EncUint8(v)
case uint16:
return EncUint16(v)
case uint32:
return EncUint32(v)
case uint64:
return EncUint64(v)
case uint:
return EncUint(v)
case big.Int:
return EncBigInt(v)
case string:
return EncString(v)
case *int8:
return EncInt8R(v)
case *int16:
return EncInt16R(v)
case *int32:
return EncInt32R(v)
case *int64:
return EncInt64R(v)
case *int:
return EncIntR(v)
case *uint8:
return EncUint8R(v)
case *uint16:
return EncUint16R(v)
case *uint32:
return EncUint32R(v)
case *uint64:
return EncUint64R(v)
case *uint:
return EncUintR(v)
case *big.Int:
return EncBigIntR(v)
case *string:
return EncStringR(v)
default:
// Custom types (type MyInt int) can be serialized only via `reflect` package.
// Later, when generic-based serialization is introduced we can do that via generics.
rv := reflect.TypeOf(value)
if rv.Kind() != reflect.Ptr {
return EncReflect(reflect.ValueOf(v))
}
return EncReflectR(reflect.ValueOf(v))
}
}

View File

@@ -0,0 +1,206 @@
package counter
import (
"fmt"
"math"
"math/big"
"reflect"
"strconv"
)
var (
maxBigInt = big.NewInt(math.MaxInt64)
minBigInt = big.NewInt(math.MinInt64)
)
func EncInt8(v int8) ([]byte, error) {
if v < 0 {
return []byte{255, 255, 255, 255, 255, 255, 255, byte(v)}, nil
}
return []byte{0, 0, 0, 0, 0, 0, 0, byte(v)}, nil
}
func EncInt8R(v *int8) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncInt8(*v)
}
func EncInt16(v int16) ([]byte, error) {
if v < 0 {
return []byte{255, 255, 255, 255, 255, 255, byte(v >> 8), byte(v)}, nil
}
return []byte{0, 0, 0, 0, 0, 0, byte(v >> 8), byte(v)}, nil
}
func EncInt16R(v *int16) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncInt16(*v)
}
func EncInt32(v int32) ([]byte, error) {
if v < 0 {
return []byte{255, 255, 255, 255, byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}, nil
}
return []byte{0, 0, 0, 0, byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}, nil
}
func EncInt32R(v *int32) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncInt32(*v)
}
func EncInt64(v int64) ([]byte, error) {
return encInt64(v), nil
}
func EncInt64R(v *int64) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncInt64(*v)
}
func EncInt(v int) ([]byte, error) {
return []byte{byte(v >> 56), byte(v >> 48), byte(v >> 40), byte(v >> 32), byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}, nil
}
func EncIntR(v *int) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncInt(*v)
}
func EncUint8(v uint8) ([]byte, error) {
return []byte{0, 0, 0, 0, 0, 0, 0, v}, nil
}
func EncUint8R(v *uint8) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncUint8(*v)
}
func EncUint16(v uint16) ([]byte, error) {
return []byte{0, 0, 0, 0, 0, 0, byte(v >> 8), byte(v)}, nil
}
func EncUint16R(v *uint16) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncUint16(*v)
}
func EncUint32(v uint32) ([]byte, error) {
return []byte{0, 0, 0, 0, byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}, nil
}
func EncUint32R(v *uint32) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncUint32(*v)
}
func EncUint64(v uint64) ([]byte, error) {
return []byte{byte(v >> 56), byte(v >> 48), byte(v >> 40), byte(v >> 32), byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}, nil
}
func EncUint64R(v *uint64) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncUint64(*v)
}
func EncUint(v uint) ([]byte, error) {
return []byte{byte(v >> 56), byte(v >> 48), byte(v >> 40), byte(v >> 32), byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}, nil
}
func EncUintR(v *uint) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncUint(*v)
}
func EncBigInt(v big.Int) ([]byte, error) {
if v.Cmp(maxBigInt) == 1 || v.Cmp(minBigInt) == -1 {
return nil, fmt.Errorf("failed to marshal counter: value (%T)(%s) out of range", v, v.String())
}
return encInt64(v.Int64()), nil
}
func EncBigIntR(v *big.Int) ([]byte, error) {
if v == nil {
return nil, nil
}
if v.Cmp(maxBigInt) == 1 || v.Cmp(minBigInt) == -1 {
return nil, fmt.Errorf("failed to marshal counter: value (%T)(%s) out of range", v, v.String())
}
return encInt64(v.Int64()), nil
}
func EncString(v string) ([]byte, error) {
if v == "" {
return nil, nil
}
n, err := strconv.ParseInt(v, 10, 64)
if err != nil {
return nil, fmt.Errorf("failed to marshal counter: can not marshal %#v %s", v, err)
}
return encInt64(n), nil
}
func EncStringR(v *string) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncString(*v)
}
func EncReflect(v reflect.Value) ([]byte, error) {
switch v.Kind() {
case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8:
return EncInt64(v.Int())
case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8:
return EncUint64(v.Uint())
case reflect.String:
val := v.String()
if val == "" {
return nil, nil
}
n, err := strconv.ParseInt(val, 10, 64)
if err != nil {
return nil, fmt.Errorf("failed to marshal counter: can not marshal (%T)(%[1]v) %s", v.Interface(), err)
}
return encInt64(n), nil
case reflect.Struct:
if v.Type().String() == "gocql.unsetColumn" {
return nil, nil
}
return nil, fmt.Errorf("failed to marshal counter: unsupported value type (%T)(%[1]v)", v.Interface())
default:
return nil, fmt.Errorf("failed to marshal counter: unsupported value type (%T)(%[1]v)", v.Interface())
}
}
func EncReflectR(v reflect.Value) ([]byte, error) {
if v.IsNil() {
return nil, nil
}
return EncReflect(v.Elem())
}
func encInt64(v int64) []byte {
return []byte{byte(v >> 56), byte(v >> 48), byte(v >> 40), byte(v >> 32), byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}
}

View File

@@ -0,0 +1,81 @@
package counter
import (
"fmt"
"math/big"
"reflect"
)
func Unmarshal(data []byte, value interface{}) error {
switch v := value.(type) {
case nil:
return nil
case *int8:
return DecInt8(data, v)
case *int16:
return DecInt16(data, v)
case *int32:
return DecInt32(data, v)
case *int64:
return DecInt64(data, v)
case *int:
return DecInt(data, v)
case *uint8:
return DecUint8(data, v)
case *uint16:
return DecUint16(data, v)
case *uint32:
return DecUint32(data, v)
case *uint64:
return DecUint64(data, v)
case *uint:
return DecUint(data, v)
case *big.Int:
return DecBigInt(data, v)
case *string:
return DecString(data, v)
case **int8:
return DecInt8R(data, v)
case **int16:
return DecInt16R(data, v)
case **int32:
return DecInt32R(data, v)
case **int64:
return DecInt64R(data, v)
case **int:
return DecIntR(data, v)
case **uint8:
return DecUint8R(data, v)
case **uint16:
return DecUint16R(data, v)
case **uint32:
return DecUint32R(data, v)
case **uint64:
return DecUint64R(data, v)
case **uint:
return DecUintR(data, v)
case **big.Int:
return DecBigIntR(data, v)
case **string:
return DecStringR(data, v)
default:
// Custom types (type MyInt int) can be deserialized only via `reflect` package.
// Later, when generic-based serialization is introduced we can do that via generics.
rv := reflect.ValueOf(value)
rt := rv.Type()
if rt.Kind() != reflect.Ptr {
return fmt.Errorf("failed to unmarshal counter: unsupported value type (%T)(%[1]v)", value)
}
if rt.Elem().Kind() != reflect.Ptr {
return DecReflect(data, rv)
}
return DecReflectR(data, rv)
}
}

View File

@@ -0,0 +1,841 @@
package counter
import (
"fmt"
"math"
"math/big"
"reflect"
"strconv"
)
var errWrongDataLen = fmt.Errorf("failed to unmarshal counter: the length of the data should be 0 or 8")
func errNilReference(v interface{}) error {
return fmt.Errorf("failed to unmarshal counter: can not unmarshal into nil reference (%T)(%[1]v))", v)
}
func DecInt8(p []byte, v *int8) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
*v = 0
case 8:
val := decInt64(p)
if val > math.MaxInt8 || val < math.MinInt8 {
return fmt.Errorf("failed to unmarshal counter: to unmarshal into int8, the data should be in the int8 range")
}
*v = int8(val)
default:
return errWrongDataLen
}
return nil
}
func DecInt8R(p []byte, v **int8) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
if p == nil {
*v = nil
} else {
*v = new(int8)
}
case 8:
val := decInt64(p)
if val > math.MaxInt8 || val < math.MinInt8 {
return fmt.Errorf("failed to unmarshal counter: to unmarshal into int8, the data should be in the int8 range")
}
tmp := int8(val)
*v = &tmp
default:
return errWrongDataLen
}
return nil
}
func DecInt16(p []byte, v *int16) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
*v = 0
case 8:
val := decInt64(p)
if val > math.MaxInt16 || val < math.MinInt16 {
return fmt.Errorf("failed to unmarshal counter: to unmarshal into int16, the data should be in the int16 range")
}
*v = int16(val)
default:
return errWrongDataLen
}
return nil
}
func DecInt16R(p []byte, v **int16) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
if p == nil {
*v = nil
} else {
*v = new(int16)
}
case 8:
val := decInt64(p)
if val > math.MaxInt16 || val < math.MinInt16 {
return fmt.Errorf("failed to unmarshal counter: to unmarshal into int16, the data should be in the int16 range")
}
tmp := int16(val)
*v = &tmp
default:
return errWrongDataLen
}
return nil
}
func DecInt32(p []byte, v *int32) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
*v = 0
case 8:
val := decInt64(p)
if val > math.MaxInt32 || val < math.MinInt32 {
return fmt.Errorf("failed to unmarshal counter: to unmarshal into int32, the data should be in the int32 range")
}
*v = int32(val)
default:
return errWrongDataLen
}
return nil
}
func DecInt32R(p []byte, v **int32) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
if p == nil {
*v = nil
} else {
*v = new(int32)
}
case 8:
val := decInt64(p)
if val > math.MaxInt32 || val < math.MinInt32 {
return fmt.Errorf("failed to unmarshal counter: to unmarshal into int32, the data should be in the int32 range")
}
tmp := int32(val)
*v = &tmp
default:
return errWrongDataLen
}
return nil
}
func DecInt64(p []byte, v *int64) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
*v = 0
case 8:
*v = decInt64(p)
default:
return errWrongDataLen
}
return nil
}
func DecInt64R(p []byte, v **int64) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
if p == nil {
*v = nil
} else {
*v = new(int64)
}
case 8:
val := decInt64(p)
*v = &val
default:
return errWrongDataLen
}
return nil
}
func DecInt(p []byte, v *int) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
*v = 0
case 8:
*v = int(p[0])<<56 | int(p[1])<<48 | int(p[2])<<40 | int(p[3])<<32 | int(p[4])<<24 | int(p[5])<<16 | int(p[6])<<8 | int(p[7])
default:
return errWrongDataLen
}
return nil
}
func DecIntR(p []byte, v **int) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
if p == nil {
*v = nil
} else {
*v = new(int)
}
case 8:
val := int(p[0])<<56 | int(p[1])<<48 | int(p[2])<<40 | int(p[3])<<32 | int(p[4])<<24 | int(p[5])<<16 | int(p[6])<<8 | int(p[7])
*v = &val
default:
return errWrongDataLen
}
return nil
}
func DecUint8(p []byte, v *uint8) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
*v = 0
case 8:
if p[0] != 0 || p[1] != 0 || p[2] != 0 || p[3] != 0 || p[4] != 0 || p[5] != 0 || p[6] != 0 {
return fmt.Errorf("failed to unmarshal counter: to unmarshal into uint8, the data should be in the uint8 range")
}
*v = p[7]
default:
return errWrongDataLen
}
return nil
}
func DecUint8R(p []byte, v **uint8) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
if p == nil {
*v = nil
} else {
*v = new(uint8)
}
case 8:
if p[0] != 0 || p[1] != 0 || p[2] != 0 || p[3] != 0 || p[4] != 0 || p[5] != 0 || p[6] != 0 {
return fmt.Errorf("failed to unmarshal counter: to unmarshal into uint8, the data should be in the uint8 range")
}
val := p[7]
*v = &val
default:
return errWrongDataLen
}
return nil
}
func DecUint16(p []byte, v *uint16) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
*v = 0
case 8:
if p[0] != 0 || p[1] != 0 || p[2] != 0 || p[3] != 0 || p[4] != 0 || p[5] != 0 {
return fmt.Errorf("failed to unmarshal counter: to unmarshal into uint16, the data should be in the uint16 range")
}
*v = uint16(p[6])<<8 | uint16(p[7])
default:
return errWrongDataLen
}
return nil
}
func DecUint16R(p []byte, v **uint16) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
if p == nil {
*v = nil
} else {
*v = new(uint16)
}
case 8:
if p[0] != 0 || p[1] != 0 || p[2] != 0 || p[3] != 0 || p[4] != 0 || p[5] != 0 {
return fmt.Errorf("failed to unmarshal counter: to unmarshal into uint16, the data should be in the uint16 range")
}
val := uint16(p[6])<<8 | uint16(p[7])
*v = &val
default:
return errWrongDataLen
}
return nil
}
func DecUint32(p []byte, v *uint32) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
*v = 0
case 8:
if p[0] != 0 || p[1] != 0 || p[2] != 0 || p[3] != 0 {
return fmt.Errorf("failed to unmarshal counter: to unmarshal into uint32, the data should be in the uint32 range")
}
*v = uint32(p[4])<<24 | uint32(p[5])<<16 | uint32(p[6])<<8 | uint32(p[7])
default:
return errWrongDataLen
}
return nil
}
func DecUint32R(p []byte, v **uint32) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
if p == nil {
*v = nil
} else {
*v = new(uint32)
}
case 8:
if p[0] != 0 || p[1] != 0 || p[2] != 0 || p[3] != 0 {
return fmt.Errorf("failed to unmarshal counter: to unmarshal into uint32, the data should be in the uint32 range")
}
val := uint32(p[4])<<24 | uint32(p[5])<<16 | uint32(p[6])<<8 | uint32(p[7])
*v = &val
default:
return errWrongDataLen
}
return nil
}
func DecUint64(p []byte, v *uint64) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
*v = 0
case 8:
*v = decUint64(p)
default:
return errWrongDataLen
}
return nil
}
func DecUint64R(p []byte, v **uint64) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
if p == nil {
*v = nil
} else {
*v = new(uint64)
}
case 8:
val := decUint64(p)
*v = &val
default:
return errWrongDataLen
}
return nil
}
func DecUint(p []byte, v *uint) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
*v = 0
case 8:
*v = uint(p[0])<<56 | uint(p[1])<<48 | uint(p[2])<<40 | uint(p[3])<<32 | uint(p[4])<<24 | uint(p[5])<<16 | uint(p[6])<<8 | uint(p[7])
default:
return errWrongDataLen
}
return nil
}
func DecUintR(p []byte, v **uint) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
if p == nil {
*v = nil
} else {
*v = new(uint)
}
case 8:
val := uint(p[0])<<56 | uint(p[1])<<48 | uint(p[2])<<40 | uint(p[3])<<32 | uint(p[4])<<24 | uint(p[5])<<16 | uint(p[6])<<8 | uint(p[7])
*v = &val
default:
return errWrongDataLen
}
return nil
}
func DecString(p []byte, v *string) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
if p == nil {
*v = ""
} else {
*v = "0"
}
case 8:
*v = strconv.FormatInt(decInt64(p), 10)
default:
return errWrongDataLen
}
return nil
}
func DecStringR(p []byte, v **string) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
if p == nil {
*v = nil
} else {
val := "0"
*v = &val
}
case 8:
val := strconv.FormatInt(decInt64(p), 10)
*v = &val
default:
return errWrongDataLen
}
return nil
}
func DecBigInt(p []byte, v *big.Int) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
v.SetInt64(0)
case 8:
v.SetInt64(decInt64(p))
default:
return errWrongDataLen
}
return nil
}
func DecBigIntR(p []byte, v **big.Int) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
if p == nil {
*v = nil
} else {
*v = new(big.Int)
}
case 8:
*v = big.NewInt(decInt64(p))
default:
return errWrongDataLen
}
return nil
}
func DecReflect(p []byte, v reflect.Value) error {
if v.IsNil() {
return fmt.Errorf("failed to unmarshal counter: can not unmarshal into nil reference (%T)(%[1]v)", v.Interface())
}
switch v = v.Elem(); v.Kind() {
case reflect.Int8:
return decReflectInt8(p, v)
case reflect.Int16:
return decReflectInt16(p, v)
case reflect.Int32:
return decReflectInt32(p, v)
case reflect.Int64, reflect.Int:
return decReflectInts(p, v)
case reflect.Uint8:
return decReflectUint8(p, v)
case reflect.Uint16:
return decReflectUint16(p, v)
case reflect.Uint32:
return decReflectUint32(p, v)
case reflect.Uint64, reflect.Uint:
return decReflectUints(p, v)
case reflect.String:
return decReflectString(p, v)
default:
return fmt.Errorf("failed to unmarshal counter: unsupported value type (%T)(%[1]v)", v.Interface())
}
}
func decReflectInt8(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
v.SetInt(0)
case 8:
val := decInt64(p)
if val > math.MaxInt8 || val < math.MinInt8 {
return fmt.Errorf("failed to unmarshal counter: to unmarshal into %T, the data should be in the int8 range", v.Interface())
}
v.SetInt(val)
default:
return errWrongDataLen
}
return nil
}
func decReflectInt16(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
v.SetInt(0)
case 8:
val := decInt64(p)
if val > math.MaxInt16 || val < math.MinInt16 {
return fmt.Errorf("failed to unmarshal counter: to unmarshal into %T, the data should be in the int16 range", v.Interface())
}
v.SetInt(val)
default:
return errWrongDataLen
}
return nil
}
func decReflectInt32(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
v.SetInt(0)
case 8:
val := decInt64(p)
if val > math.MaxInt32 || val < math.MinInt32 {
return fmt.Errorf("failed to unmarshal counter: to unmarshal into %T, the data should be in the int32 range", v.Interface())
}
v.SetInt(val)
default:
return errWrongDataLen
}
return nil
}
func decReflectInts(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
v.SetInt(0)
case 8:
v.SetInt(decInt64(p))
default:
return errWrongDataLen
}
return nil
}
func decReflectUint8(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
v.SetUint(0)
case 8:
if p[0] != 0 || p[1] != 0 || p[2] != 0 || p[3] != 0 || p[4] != 0 || p[5] != 0 || p[6] != 0 {
return fmt.Errorf("failed to unmarshal counter: to unmarshal into %T, the data should be in the uint8 range", v.Interface())
}
v.SetUint(uint64(p[7]))
default:
return errWrongDataLen
}
return nil
}
func decReflectUint16(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
v.SetUint(0)
case 8:
if p[0] != 0 || p[1] != 0 || p[2] != 0 || p[3] != 0 || p[4] != 0 || p[5] != 0 {
return fmt.Errorf("failed to unmarshal counter: to unmarshal into %T, the data should be in the uint16 range", v.Interface())
}
v.SetUint(uint64(p[6])<<8 | uint64(p[7]))
default:
return errWrongDataLen
}
return nil
}
func decReflectUint32(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
v.SetUint(0)
case 8:
if p[0] != 0 || p[1] != 0 || p[2] != 0 || p[3] != 0 {
return fmt.Errorf("failed to unmarshal counter: to unmarshal into %T, the data should be in the uint32 range", v.Interface())
}
v.SetUint(uint64(p[4])<<24 | uint64(p[5])<<16 | uint64(p[6])<<8 | uint64(p[7]))
default:
return errWrongDataLen
}
return nil
}
func decReflectUints(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
v.SetUint(0)
case 8:
v.SetUint(decUint64(p))
default:
return errWrongDataLen
}
return nil
}
func decReflectString(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
if p == nil {
v.SetString("")
} else {
v.SetString("0")
}
case 8:
v.SetString(strconv.FormatInt(decInt64(p), 10))
default:
return errWrongDataLen
}
return nil
}
func DecReflectR(p []byte, v reflect.Value) error {
if v.IsNil() {
return fmt.Errorf("failed to unmarshal counter: can not unmarshal into nil reference (%T)(%[1]v)", v.Interface())
}
switch v.Type().Elem().Elem().Kind() {
case reflect.Int8:
return decReflectInt8R(p, v)
case reflect.Int16:
return decReflectInt16R(p, v)
case reflect.Int32:
return decReflectInt32R(p, v)
case reflect.Int64, reflect.Int:
return decReflectIntsR(p, v)
case reflect.Uint8:
return decReflectUint8R(p, v)
case reflect.Uint16:
return decReflectUint16R(p, v)
case reflect.Uint32:
return decReflectUint32R(p, v)
case reflect.Uint64, reflect.Uint:
return decReflectUintsR(p, v)
case reflect.String:
return decReflectStringR(p, v)
default:
return fmt.Errorf("failed to unmarshal counter: unsupported value type (%T)(%[1]v)", v.Interface())
}
}
func decReflectInt8R(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
v.Elem().Set(decReflectNullableR(p, v))
case 8:
val := decInt64(p)
if val > math.MaxInt8 || val < math.MinInt8 {
return fmt.Errorf("failed to unmarshal counter: to unmarshal into %T, the data should be in the int8 range", v.Interface())
}
newVal := reflect.New(v.Type().Elem().Elem())
newVal.Elem().SetInt(val)
v.Elem().Set(newVal)
default:
return errWrongDataLen
}
return nil
}
func decReflectInt16R(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
v.Elem().Set(decReflectNullableR(p, v))
case 8:
val := decInt64(p)
if val > math.MaxInt16 || val < math.MinInt16 {
return fmt.Errorf("failed to unmarshal counter: to unmarshal into %T, the data should be in the int16 range", v.Interface())
}
newVal := reflect.New(v.Type().Elem().Elem())
newVal.Elem().SetInt(val)
v.Elem().Set(newVal)
default:
return errWrongDataLen
}
return nil
}
func decReflectInt32R(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
v.Elem().Set(decReflectNullableR(p, v))
case 8:
val := decInt64(p)
if val > math.MaxInt32 || val < math.MinInt32 {
return fmt.Errorf("failed to unmarshal counter: to unmarshal into %T, the data should be in the int32 range", v.Interface())
}
newVal := reflect.New(v.Type().Elem().Elem())
newVal.Elem().SetInt(val)
v.Elem().Set(newVal)
default:
return errWrongDataLen
}
return nil
}
func decReflectIntsR(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
v.Elem().Set(decReflectNullableR(p, v))
case 8:
val := reflect.New(v.Type().Elem().Elem())
val.Elem().SetInt(decInt64(p))
v.Elem().Set(val)
default:
return errWrongDataLen
}
return nil
}
func decReflectUint8R(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
v.Elem().Set(decReflectNullableR(p, v))
case 8:
newVal := reflect.New(v.Type().Elem().Elem())
if p[0] != 0 || p[1] != 0 || p[2] != 0 || p[3] != 0 || p[4] != 0 || p[5] != 0 || p[6] != 0 {
return fmt.Errorf("failed to unmarshal counter: to unmarshal into %T, the data should be in the uint8 range", v.Interface())
}
newVal.Elem().SetUint(uint64(p[7]))
v.Elem().Set(newVal)
default:
return errWrongDataLen
}
return nil
}
func decReflectUint16R(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
v.Elem().Set(decReflectNullableR(p, v))
case 8:
newVal := reflect.New(v.Type().Elem().Elem())
if p[0] != 0 || p[1] != 0 || p[2] != 0 || p[3] != 0 || p[4] != 0 || p[5] != 0 {
return fmt.Errorf("failed to unmarshal counter: to unmarshal into %T, the data should be in the uint16 range", v.Interface())
}
newVal.Elem().SetUint(uint64(p[6])<<8 | uint64(p[7]))
v.Elem().Set(newVal)
default:
return errWrongDataLen
}
return nil
}
func decReflectUint32R(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
v.Elem().Set(decReflectNullableR(p, v))
case 8:
newVal := reflect.New(v.Type().Elem().Elem())
if p[0] != 0 || p[1] != 0 || p[2] != 0 || p[3] != 0 {
return fmt.Errorf("failed to unmarshal counter: to unmarshal into %T, the data should be in the uint32 range", v.Interface())
}
newVal.Elem().SetUint(uint64(p[4])<<24 | uint64(p[5])<<16 | uint64(p[6])<<8 | uint64(p[7]))
v.Elem().Set(newVal)
default:
return errWrongDataLen
}
return nil
}
func decReflectUintsR(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
v.Elem().Set(decReflectNullableR(p, v))
case 8:
val := reflect.New(v.Type().Elem().Elem())
val.Elem().SetUint(decUint64(p))
v.Elem().Set(val)
default:
return errWrongDataLen
}
return nil
}
func decReflectStringR(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
var val reflect.Value
if p == nil {
val = reflect.Zero(v.Type().Elem())
} else {
val = reflect.New(v.Type().Elem().Elem())
val.Elem().SetString("0")
}
v.Elem().Set(val)
case 8:
val := reflect.New(v.Type().Elem().Elem())
val.Elem().SetString(strconv.FormatInt(decInt64(p), 10))
v.Elem().Set(val)
default:
return errWrongDataLen
}
return nil
}
func decReflectNullableR(p []byte, v reflect.Value) reflect.Value {
if p == nil {
return reflect.Zero(v.Elem().Type())
}
return reflect.New(v.Type().Elem().Elem())
}
func decInt64(p []byte) int64 {
return int64(p[0])<<56 | int64(p[1])<<48 | int64(p[2])<<40 | int64(p[3])<<32 | int64(p[4])<<24 | int64(p[5])<<16 | int64(p[6])<<8 | int64(p[7])
}
func decUint64(p []byte) uint64 {
return uint64(p[0])<<56 | uint64(p[1])<<48 | uint64(p[2])<<40 | uint64(p[3])<<32 | uint64(p[4])<<24 | uint64(p[5])<<16 | uint64(p[6])<<8 | uint64(p[7])
}

View File

@@ -0,0 +1,74 @@
package cqlint
import (
"math/big"
"reflect"
)
func Marshal(value interface{}) ([]byte, error) {
switch v := value.(type) {
case nil:
return nil, nil
case int8:
return EncInt8(v)
case int32:
return EncInt32(v)
case int16:
return EncInt16(v)
case int64:
return EncInt64(v)
case int:
return EncInt(v)
case uint8:
return EncUint8(v)
case uint16:
return EncUint16(v)
case uint32:
return EncUint32(v)
case uint64:
return EncUint64(v)
case uint:
return EncUint(v)
case big.Int:
return EncBigInt(v)
case string:
return EncString(v)
case *int8:
return EncInt8R(v)
case *int16:
return EncInt16R(v)
case *int32:
return EncInt32R(v)
case *int64:
return EncInt64R(v)
case *int:
return EncIntR(v)
case *uint8:
return EncUint8R(v)
case *uint16:
return EncUint16R(v)
case *uint32:
return EncUint32R(v)
case *uint64:
return EncUint64R(v)
case *uint:
return EncUintR(v)
case *big.Int:
return EncBigIntR(v)
case *string:
return EncStringR(v)
default:
// Custom types (type MyInt int) can be serialized only via `reflect` package.
// Later, when generic-based serialization is introduced we can do that via generics.
rv := reflect.TypeOf(value)
if rv.Kind() != reflect.Ptr {
return EncReflect(reflect.ValueOf(v))
}
return EncReflectR(reflect.ValueOf(v))
}
}

View File

@@ -0,0 +1,249 @@
package cqlint
import (
"fmt"
"math"
"math/big"
"reflect"
"strconv"
)
var (
maxBigInt = big.NewInt(math.MaxInt32)
minBigInt = big.NewInt(math.MinInt32)
)
func EncInt8(v int8) ([]byte, error) {
if v < 0 {
return []byte{255, 255, 255, byte(v)}, nil
}
return []byte{0, 0, 0, byte(v)}, nil
}
func EncInt8R(v *int8) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncInt8(*v)
}
func EncInt16(v int16) ([]byte, error) {
if v < 0 {
return []byte{255, 255, byte(v >> 8), byte(v)}, nil
}
return []byte{0, 0, byte(v >> 8), byte(v)}, nil
}
func EncInt16R(v *int16) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncInt16(*v)
}
func EncInt32(v int32) ([]byte, error) {
return []byte{byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}, nil
}
func EncInt32R(v *int32) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncInt32(*v)
}
func EncInt64(v int64) ([]byte, error) {
if v > math.MaxInt32 || v < math.MinInt32 {
return nil, fmt.Errorf("failed to marshal int: value %#v out of range", v)
}
return encInt64(v), nil
}
func EncInt64R(v *int64) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncInt64(*v)
}
func EncInt(v int) ([]byte, error) {
if v > math.MaxInt32 || v < math.MinInt32 {
return nil, fmt.Errorf("failed to marshal int: value %#v out of range", v)
}
return []byte{byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}, nil
}
func EncIntR(v *int) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncInt(*v)
}
func EncUint8(v uint8) ([]byte, error) {
return []byte{0, 0, 0, v}, nil
}
func EncUint8R(v *uint8) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncUint8(*v)
}
func EncUint16(v uint16) ([]byte, error) {
return []byte{0, 0, byte(v >> 8), byte(v)}, nil
}
func EncUint16R(v *uint16) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncUint16(*v)
}
func EncUint32(v uint32) ([]byte, error) {
return []byte{byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}, nil
}
func EncUint32R(v *uint32) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncUint32(*v)
}
func EncUint64(v uint64) ([]byte, error) {
if v > math.MaxUint32 {
return nil, fmt.Errorf("failed to marshal int: value %#v out of range", v)
}
return encUint64(v), nil
}
func EncUint64R(v *uint64) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncUint64(*v)
}
func EncUint(v uint) ([]byte, error) {
if v > math.MaxUint32 {
return nil, fmt.Errorf("failed to marshal int: value %#v out of range", v)
}
return []byte{byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}, nil
}
func EncUintR(v *uint) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncUint(*v)
}
func EncBigInt(v big.Int) ([]byte, error) {
if v.Cmp(maxBigInt) == 1 || v.Cmp(minBigInt) == -1 {
return nil, fmt.Errorf("failed to marshal int: value (%T)(%s) out of range", v, v.String())
}
return encInt64(v.Int64()), nil
}
func EncBigIntR(v *big.Int) ([]byte, error) {
if v == nil {
return nil, nil
}
if v.Cmp(maxBigInt) == 1 || v.Cmp(minBigInt) == -1 {
return nil, fmt.Errorf("failed to marshal int: value (%T)(%s) out of range", v, v.String())
}
return encInt64(v.Int64()), nil
}
func EncString(v string) ([]byte, error) {
if v == "" {
return nil, nil
}
n, err := strconv.ParseInt(v, 10, 32)
if err != nil {
return nil, fmt.Errorf("failed to marshal int: can not marshal (%T)(%[1]v) %s", v, err)
}
return encInt64(n), nil
}
func EncStringR(v *string) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncString(*v)
}
func EncReflect(v reflect.Value) ([]byte, error) {
switch v.Type().Kind() {
case reflect.Int8:
val := v.Int()
if val < 0 {
return []byte{255, 255, 255, byte(val)}, nil
}
return []byte{0, 0, 0, byte(val)}, nil
case reflect.Int16:
val := v.Int()
if val < 0 {
return []byte{255, 255, byte(val >> 8), byte(val)}, nil
}
return []byte{0, 0, byte(val >> 8), byte(val)}, nil
case reflect.Int32:
return encInt64(v.Int()), nil
case reflect.Int, reflect.Int64:
val := v.Int()
if val > math.MaxInt32 || val < math.MinInt32 {
return nil, fmt.Errorf("failed to marshal int: value (%T)(%[1]v) out of range", v.Interface())
}
return encInt64(val), nil
case reflect.Uint8:
return []byte{0, 0, 0, byte(v.Uint())}, nil
case reflect.Uint16:
val := v.Uint()
return []byte{0, 0, byte(val >> 8), byte(val)}, nil
case reflect.Uint32:
return encUint64(v.Uint()), nil
case reflect.Uint, reflect.Uint64:
val := v.Uint()
if val > math.MaxUint32 {
return nil, fmt.Errorf("failed to marshal int: value (%T)(%[1]v) out of range", v.Interface())
}
return encUint64(val), nil
case reflect.String:
val := v.String()
if val == "" {
return nil, nil
}
n, err := strconv.ParseInt(val, 10, 32)
if err != nil {
return nil, fmt.Errorf("failed to marshal int: can not marshal (%T)(%[1]v) %s", v.Interface(), err)
}
return encInt64(n), nil
case reflect.Struct:
if v.Type().String() == "gocql.unsetColumn" {
return nil, nil
}
return nil, fmt.Errorf("failed to marshal int: unsupported value type (%T)(%[1]v)", v.Interface())
default:
return nil, fmt.Errorf("failed to marshal int: unsupported value type (%T)(%[1]v)", v.Interface())
}
}
func EncReflectR(v reflect.Value) ([]byte, error) {
if v.IsNil() {
return nil, nil
}
return EncReflect(v.Elem())
}
func encInt64(v int64) []byte {
return []byte{byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}
}
func encUint64(v uint64) []byte {
return []byte{byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}
}

View File

@@ -0,0 +1,81 @@
package cqlint
import (
"fmt"
"math/big"
"reflect"
)
func Unmarshal(data []byte, value interface{}) error {
switch v := value.(type) {
case nil:
return nil
case *int8:
return DecInt8(data, v)
case *int16:
return DecInt16(data, v)
case *int32:
return DecInt32(data, v)
case *int64:
return DecInt64(data, v)
case *int:
return DecInt(data, v)
case *uint8:
return DecUint8(data, v)
case *uint16:
return DecUint16(data, v)
case *uint32:
return DecUint32(data, v)
case *uint64:
return DecUint64(data, v)
case *uint:
return DecUint(data, v)
case *big.Int:
return DecBigInt(data, v)
case *string:
return DecString(data, v)
case **int8:
return DecInt8R(data, v)
case **int16:
return DecInt16R(data, v)
case **int32:
return DecInt32R(data, v)
case **int64:
return DecInt64R(data, v)
case **int:
return DecIntR(data, v)
case **uint8:
return DecUint8R(data, v)
case **uint16:
return DecUint16R(data, v)
case **uint32:
return DecUint32R(data, v)
case **uint64:
return DecUint64R(data, v)
case **uint:
return DecUintR(data, v)
case **big.Int:
return DecBigIntR(data, v)
case **string:
return DecStringR(data, v)
default:
// Custom types (type MyInt int) can be deserialized only via `reflect` package.
// Later, when generic-based serialization is introduced we can do that via generics.
rv := reflect.ValueOf(value)
rt := rv.Type()
if rt.Kind() != reflect.Ptr {
return fmt.Errorf("failed to unmarshal int: unsupported value type (%T)(%[1]v)", value)
}
if rt.Elem().Kind() != reflect.Ptr {
return DecReflect(data, rv)
}
return DecReflectR(data, rv)
}
}

View File

@@ -0,0 +1,772 @@
package cqlint
import (
"fmt"
"math"
"math/big"
"reflect"
"strconv"
)
const (
negInt64 = int64(-1) << 32
negInt = int(-1) << 32
)
var errWrongDataLen = fmt.Errorf("failed to unmarshal int: the length of the data should be 0 or 4")
func errNilReference(v interface{}) error {
return fmt.Errorf("failed to unmarshal int: can not unmarshal into nil reference (%T)(%[1]v))", v)
}
func DecInt8(p []byte, v *int8) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
*v = 0
case 4:
val := decInt32(p)
if val > math.MaxInt8 || val < math.MinInt8 {
return fmt.Errorf("failed to unmarshal int: to unmarshal into int8, the data should be in the int8 range")
}
*v = int8(val)
default:
return errWrongDataLen
}
return nil
}
func DecInt8R(p []byte, v **int8) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
if p == nil {
*v = nil
} else {
*v = new(int8)
}
case 4:
val := decInt32(p)
if val > math.MaxInt8 || val < math.MinInt8 {
return fmt.Errorf("failed to unmarshal int: to unmarshal into int8, the data should be in the int8 range")
}
tmp := int8(val)
*v = &tmp
default:
return errWrongDataLen
}
return nil
}
func DecInt16(p []byte, v *int16) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
*v = 0
case 4:
val := decInt32(p)
if val > math.MaxInt16 || val < math.MinInt16 {
return fmt.Errorf("failed to unmarshal int: to unmarshal into int16, the data should be in the int16 range")
}
*v = int16(val)
default:
return errWrongDataLen
}
return nil
}
func DecInt16R(p []byte, v **int16) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
if p == nil {
*v = nil
} else {
*v = new(int16)
}
case 4:
val := decInt32(p)
if val > math.MaxInt16 || val < math.MinInt16 {
return fmt.Errorf("failed to unmarshal int: to unmarshal into int16, the data should be in the int16 range")
}
tmp := int16(val)
*v = &tmp
default:
return errWrongDataLen
}
return nil
}
func DecInt32(p []byte, v *int32) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
*v = 0
case 4:
*v = decInt32(p)
default:
return errWrongDataLen
}
return nil
}
func DecInt32R(p []byte, v **int32) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
if p == nil {
*v = nil
} else {
*v = new(int32)
}
case 4:
tmp := decInt32(p)
*v = &tmp
default:
return errWrongDataLen
}
return nil
}
func DecInt64(p []byte, v *int64) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
*v = 0
case 4:
*v = decInt64(p)
default:
return errWrongDataLen
}
return nil
}
func DecInt64R(p []byte, v **int64) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
if p == nil {
*v = nil
} else {
*v = new(int64)
}
case 4:
val := decInt64(p)
*v = &val
default:
return errWrongDataLen
}
return nil
}
func DecInt(p []byte, v *int) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
*v = 0
case 4:
*v = decInt(p)
default:
return errWrongDataLen
}
return nil
}
func DecIntR(p []byte, v **int) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
if p == nil {
*v = nil
} else {
*v = new(int)
}
case 4:
val := decInt(p)
*v = &val
default:
return errWrongDataLen
}
return nil
}
func DecUint8(p []byte, v *uint8) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
*v = 0
case 4:
if p[0] != 0 || p[1] != 0 || p[2] != 0 {
return fmt.Errorf("failed to unmarshal int: to unmarshal into uint8, the data should be in the uint8 range")
}
*v = p[3]
default:
return errWrongDataLen
}
return nil
}
func DecUint8R(p []byte, v **uint8) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
if p == nil {
*v = nil
} else {
*v = new(uint8)
}
case 4:
if p[0] != 0 || p[1] != 0 || p[2] != 0 {
return fmt.Errorf("failed to unmarshal int: to unmarshal into uint8, the data should be in the uint8 range")
}
val := p[3]
*v = &val
default:
return errWrongDataLen
}
return nil
}
func DecUint16(p []byte, v *uint16) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
*v = 0
case 4:
if p[0] != 0 || p[1] != 0 {
return fmt.Errorf("failed to unmarshal int: to unmarshal into uint16, the data should be in the uint16 range")
}
*v = uint16(p[2])<<8 | uint16(p[3])
default:
return errWrongDataLen
}
return nil
}
func DecUint16R(p []byte, v **uint16) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
if p == nil {
*v = nil
} else {
*v = new(uint16)
}
case 4:
if p[0] != 0 || p[1] != 0 {
return fmt.Errorf("failed to unmarshal int: to unmarshal into uint16, the data should be in the uint16 range")
}
val := uint16(p[2])<<8 | uint16(p[3])
*v = &val
default:
return errWrongDataLen
}
return nil
}
func DecUint32(p []byte, v *uint32) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
*v = 0
case 4:
*v = uint32(p[0])<<24 | uint32(p[1])<<16 | uint32(p[2])<<8 | uint32(p[3])
default:
return errWrongDataLen
}
return nil
}
func DecUint32R(p []byte, v **uint32) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
if p == nil {
*v = nil
} else {
*v = new(uint32)
}
case 4:
val := uint32(p[0])<<24 | uint32(p[1])<<16 | uint32(p[2])<<8 | uint32(p[3])
*v = &val
default:
return errWrongDataLen
}
return nil
}
func DecUint64(p []byte, v *uint64) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
*v = 0
case 4:
*v = decUint64(p)
default:
return errWrongDataLen
}
return nil
}
func DecUint64R(p []byte, v **uint64) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
if p == nil {
*v = nil
} else {
*v = new(uint64)
}
case 4:
val := decUint64(p)
*v = &val
default:
return errWrongDataLen
}
return nil
}
func DecUint(p []byte, v *uint) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
*v = 0
case 4:
*v = uint(p[0])<<24 | uint(p[1])<<16 | uint(p[2])<<8 | uint(p[3])
default:
return errWrongDataLen
}
return nil
}
func DecUintR(p []byte, v **uint) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
if p == nil {
*v = nil
} else {
*v = new(uint)
}
case 4:
val := uint(p[0])<<24 | uint(p[1])<<16 | uint(p[2])<<8 | uint(p[3])
*v = &val
default:
return errWrongDataLen
}
return nil
}
func DecString(p []byte, v *string) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
if p == nil {
*v = ""
} else {
*v = "0"
}
case 4:
*v = strconv.FormatInt(decInt64(p), 10)
default:
return errWrongDataLen
}
return nil
}
func DecStringR(p []byte, v **string) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
if p == nil {
*v = nil
} else {
val := "0"
*v = &val
}
case 4:
val := strconv.FormatInt(decInt64(p), 10)
*v = &val
default:
return errWrongDataLen
}
return nil
}
func DecBigInt(p []byte, v *big.Int) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
v.SetInt64(0)
case 4:
v.SetInt64(decInt64(p))
default:
return errWrongDataLen
}
return nil
}
func DecBigIntR(p []byte, v **big.Int) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
if p == nil {
*v = nil
} else {
*v = big.NewInt(0)
}
case 4:
*v = big.NewInt(decInt64(p))
default:
return errWrongDataLen
}
return nil
}
func DecReflect(p []byte, v reflect.Value) error {
if v.IsNil() {
return fmt.Errorf("failed to unmarshal int: can not unmarshal into nil reference (%T)(%[1]v)", v.Interface())
}
switch v = v.Elem(); v.Kind() {
case reflect.Int8:
return decReflectInt8(p, v)
case reflect.Int16:
return decReflectInt16(p, v)
case reflect.Int32, reflect.Int64, reflect.Int:
return decReflectInts(p, v)
case reflect.Uint8:
return decReflectUint8(p, v)
case reflect.Uint16:
return decReflectUint16(p, v)
case reflect.Uint32, reflect.Uint64, reflect.Uint:
return decReflectUints(p, v)
case reflect.String:
return decReflectString(p, v)
default:
return fmt.Errorf("failed to unmarshal int: unsupported value type (%T)(%[1]v)", v.Interface())
}
}
func DecReflectR(p []byte, v reflect.Value) error {
if v.IsNil() {
return fmt.Errorf("failed to unmarshal int: can not unmarshal into nil reference (%T)(%[1]v)", v.Interface())
}
switch v.Type().Elem().Elem().Kind() {
case reflect.Int8:
return decReflectInt8R(p, v)
case reflect.Int16:
return decReflectInt16R(p, v)
case reflect.Int32, reflect.Int64, reflect.Int:
return decReflectIntsR(p, v)
case reflect.Uint8:
return decReflectUint8R(p, v)
case reflect.Uint16:
return decReflectUint16R(p, v)
case reflect.Uint32, reflect.Uint64, reflect.Uint:
return decReflectUintsR(p, v)
case reflect.String:
return decReflectStringR(p, v)
default:
return fmt.Errorf("failed to unmarshal int: unsupported value type (%T)(%[1]v)", v.Interface())
}
}
func decReflectInt8(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
v.SetInt(0)
case 4:
val := decInt64(p)
if val > math.MaxInt8 || val < math.MinInt8 {
return fmt.Errorf("failed to unmarshal int: to unmarshal into (%T), the data should be in the int8 range", v.Interface())
}
v.SetInt(val)
default:
return errWrongDataLen
}
return nil
}
func decReflectInt16(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
v.SetInt(0)
case 4:
val := decInt64(p)
if val > math.MaxInt16 || val < math.MinInt16 {
return fmt.Errorf("failed to unmarshal int: to unmarshal into (%T), the data should be in the int16 range", v.Interface())
}
v.SetInt(val)
default:
return errWrongDataLen
}
return nil
}
func decReflectInts(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
v.SetInt(0)
case 4:
v.SetInt(decInt64(p))
default:
return errWrongDataLen
}
return nil
}
func decReflectUint8(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
v.SetUint(0)
case 4:
if p[0] != 0 || p[1] != 0 || p[2] != 0 {
return fmt.Errorf("failed to unmarshal int: to unmarshal into (%T), the data should be in the uint8 range", v.Interface())
}
v.SetUint(uint64(p[3]))
default:
return errWrongDataLen
}
return nil
}
func decReflectUint16(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
v.SetUint(0)
case 4:
if p[0] != 0 || p[1] != 0 {
return fmt.Errorf("failed to unmarshal int: to unmarshal into (%T), the data should be in the uint16 range", v.Interface())
}
v.SetUint(uint64(p[2])<<8 | uint64(p[3]))
default:
return errWrongDataLen
}
return nil
}
func decReflectUints(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
v.SetUint(0)
case 4:
v.SetUint(decUint64(p))
default:
return errWrongDataLen
}
return nil
}
func decReflectString(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
if p == nil {
v.SetString("")
} else {
v.SetString("0")
}
case 4:
v.SetString(strconv.FormatInt(decInt64(p), 10))
default:
return errWrongDataLen
}
return nil
}
func decReflectNullableR(p []byte, v reflect.Value) reflect.Value {
if p == nil {
return reflect.Zero(v.Elem().Type())
}
return reflect.New(v.Type().Elem().Elem())
}
func decReflectInt8R(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
v.Elem().Set(decReflectNullableR(p, v))
case 4:
val := decInt64(p)
if val > math.MaxInt8 || val < math.MinInt8 {
return fmt.Errorf("failed to unmarshal int: to unmarshal into (%T), the data should be in the int8 range", v.Interface())
}
newVal := reflect.New(v.Type().Elem().Elem())
newVal.Elem().SetInt(val)
v.Elem().Set(newVal)
default:
return errWrongDataLen
}
return nil
}
func decReflectInt16R(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
v.Elem().Set(decReflectNullableR(p, v))
case 4:
val := decInt64(p)
if val > math.MaxInt16 || val < math.MinInt16 {
return fmt.Errorf("failed to unmarshal int: to unmarshal into (%T), the data should be in the int16 range", v.Interface())
}
newVal := reflect.New(v.Type().Elem().Elem())
newVal.Elem().SetInt(val)
v.Elem().Set(newVal)
default:
return errWrongDataLen
}
return nil
}
func decReflectIntsR(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
v.Elem().Set(decReflectNullableR(p, v))
case 4:
newVal := reflect.New(v.Type().Elem().Elem())
newVal.Elem().SetInt(decInt64(p))
v.Elem().Set(newVal)
default:
return errWrongDataLen
}
return nil
}
func decReflectUint8R(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
v.Elem().Set(decReflectNullableR(p, v))
case 4:
if p[0] != 0 || p[1] != 0 || p[2] != 0 {
return fmt.Errorf("failed to unmarshal int: to unmarshal into (%T), the data should be in the uint8 range", v.Interface())
}
newVal := reflect.New(v.Type().Elem().Elem())
newVal.Elem().SetUint(uint64(p[3]))
v.Elem().Set(newVal)
default:
return errWrongDataLen
}
return nil
}
func decReflectUint16R(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
v.Elem().Set(decReflectNullableR(p, v))
case 4:
if p[0] != 0 || p[1] != 0 {
return fmt.Errorf("failed to unmarshal int: to unmarshal into (%T), the data should be in the uint16 range", v.Interface())
}
newVal := reflect.New(v.Type().Elem().Elem())
newVal.Elem().SetUint(uint64(p[2])<<8 | uint64(p[3]))
v.Elem().Set(newVal)
default:
return errWrongDataLen
}
return nil
}
func decReflectUintsR(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
v.Elem().Set(decReflectNullableR(p, v))
case 4:
newVal := reflect.New(v.Type().Elem().Elem())
newVal.Elem().SetUint(decUint64(p))
v.Elem().Set(newVal)
default:
return errWrongDataLen
}
return nil
}
func decReflectStringR(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
var val reflect.Value
if p == nil {
val = reflect.Zero(v.Type().Elem())
} else {
val = reflect.New(v.Type().Elem().Elem())
val.Elem().SetString("0")
}
v.Elem().Set(val)
case 4:
newVal := reflect.New(v.Type().Elem().Elem())
newVal.Elem().SetString(strconv.FormatInt(decInt64(p), 10))
v.Elem().Set(newVal)
default:
return errWrongDataLen
}
return nil
}
func decInt32(p []byte) int32 {
return int32(p[0])<<24 | int32(p[1])<<16 | int32(p[2])<<8 | int32(p[3])
}
func decInt64(p []byte) int64 {
if p[0] > math.MaxInt8 {
return negInt64 | int64(p[0])<<24 | int64(p[1])<<16 | int64(p[2])<<8 | int64(p[3])
}
return int64(p[0])<<24 | int64(p[1])<<16 | int64(p[2])<<8 | int64(p[3])
}
func decInt(p []byte) int {
if p[0] > math.MaxInt8 {
return negInt | int(p[0])<<24 | int(p[1])<<16 | int(p[2])<<8 | int(p[3])
}
return int(p[0])<<24 | int(p[1])<<16 | int(p[2])<<8 | int(p[3])
}
func decUint64(p []byte) uint64 {
return uint64(p[0])<<24 | uint64(p[1])<<16 | uint64(p[2])<<8 | uint64(p[3])
}

View File

@@ -0,0 +1,30 @@
package cqltime
import (
"reflect"
"time"
)
func Marshal(value interface{}) ([]byte, error) {
switch v := value.(type) {
case nil:
return nil, nil
case int64:
return EncInt64(v)
case *int64:
return EncInt64R(v)
case time.Duration:
return EncDuration(v)
case *time.Duration:
return EncDurationR(v)
default:
// Custom types (type MyTime int64) can be serialized only via `reflect` package.
// Later, when generic-based serialization is introduced we can do that via generics.
rv := reflect.TypeOf(value)
if rv.Kind() != reflect.Ptr {
return EncReflect(reflect.ValueOf(v))
}
return EncReflectR(reflect.ValueOf(v))
}
}

View File

@@ -0,0 +1,76 @@
package cqltime
import (
"fmt"
"reflect"
"time"
)
const (
maxValInt64 int64 = 86399999999999
minValInt64 int64 = 0
maxValDur time.Duration = 86399999999999
minValDur time.Duration = 0
)
var (
errOutRangeInt64 = fmt.Errorf("failed to marshal time: the (int64) should be in the range 0 to 86399999999999")
errOutRangeDur = fmt.Errorf("failed to marshal time: the (time.Duration) should be in the range 0 to 86399999999999")
)
func EncInt64(v int64) ([]byte, error) {
if v > maxValInt64 || v < minValInt64 {
return nil, errOutRangeInt64
}
return encInt64(v), nil
}
func EncInt64R(v *int64) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncInt64(*v)
}
func EncDuration(v time.Duration) ([]byte, error) {
if v > maxValDur || v < minValDur {
return nil, errOutRangeDur
}
return []byte{byte(v >> 56), byte(v >> 48), byte(v >> 40), byte(v >> 32), byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}, nil
}
func EncDurationR(v *time.Duration) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncDuration(*v)
}
func EncReflect(v reflect.Value) ([]byte, error) {
switch v.Kind() {
case reflect.Int64:
val := v.Int()
if val > maxValInt64 || val < minValInt64 {
return nil, fmt.Errorf("failed to marshal time: the (%T) should be in the range 0 to 86399999999999", v.Interface())
}
return encInt64(val), nil
case reflect.Struct:
if v.Type().String() == "gocql.unsetColumn" {
return nil, nil
}
return nil, fmt.Errorf("failed to marshal time: unsupported value type (%T)(%[1]v)", v.Interface())
default:
return nil, fmt.Errorf("failed to marshal time: unsupported value type (%T)(%[1]v)", v.Interface())
}
}
func EncReflectR(v reflect.Value) ([]byte, error) {
if v.IsNil() {
return nil, nil
}
return EncReflect(v.Elem())
}
func encInt64(v int64) []byte {
return []byte{byte(v >> 56), byte(v >> 48), byte(v >> 40), byte(v >> 32), byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}
}

View File

@@ -0,0 +1,36 @@
package cqltime
import (
"fmt"
"reflect"
"time"
)
func Unmarshal(data []byte, value interface{}) error {
switch v := value.(type) {
case nil:
return nil
case *int64:
return DecInt64(data, v)
case **int64:
return DecInt64R(data, v)
case *time.Duration:
return DecDuration(data, v)
case **time.Duration:
return DecDurationR(data, v)
default:
// Custom types (type MyTime int64) can be deserialized only via `reflect` package.
// Later, when generic-based serialization is introduced we can do that via generics.
rv := reflect.ValueOf(value)
rt := rv.Type()
if rt.Kind() != reflect.Ptr {
return fmt.Errorf("failed to unmarshal time: unsupported value type (%T)(%[1]v)", value)
}
if rt.Elem().Kind() != reflect.Ptr {
return DecReflect(data, rv)
}
return DecReflectR(data, rv)
}
}

View File

@@ -0,0 +1,171 @@
package cqltime
import (
"fmt"
"reflect"
"time"
)
var (
errWrongDataLen = fmt.Errorf("failed to unmarshal time: the length of the data should be 0 or 8")
errDataOutRangeInt64 = fmt.Errorf("failed to unmarshal time: (int64) the data should be in the range 0 to 86399999999999")
errDataOutRangeDur = fmt.Errorf("failed to unmarshal time: (time.Duration) the data should be in the range 0 to 86399999999999")
)
func errNilReference(v interface{}) error {
return fmt.Errorf("failed to unmarshal time: can not unmarshal into nil reference (%T)(%[1]v))", v)
}
func DecInt64(p []byte, v *int64) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
*v = 0
case 8:
*v = decInt64(p)
if *v > maxValInt64 || *v < minValInt64 {
return errDataOutRangeInt64
}
default:
return errWrongDataLen
}
return nil
}
func DecInt64R(p []byte, v **int64) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
if p == nil {
*v = nil
} else {
*v = new(int64)
}
case 8:
val := decInt64(p)
if val > maxValInt64 || val < minValInt64 {
return errDataOutRangeInt64
}
*v = &val
default:
return errWrongDataLen
}
return nil
}
func DecDuration(p []byte, v *time.Duration) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
*v = 0
case 8:
*v = decDur(p)
if *v > maxValDur || *v < minValDur {
return errDataOutRangeDur
}
default:
return errWrongDataLen
}
return nil
}
func DecDurationR(p []byte, v **time.Duration) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
if p == nil {
*v = nil
} else {
*v = new(time.Duration)
}
case 8:
val := decDur(p)
if val > maxValDur || val < minValDur {
return errDataOutRangeDur
}
*v = &val
default:
return errWrongDataLen
}
return nil
}
func DecReflect(p []byte, v reflect.Value) error {
if v.IsNil() {
return fmt.Errorf("failed to unmarshal time: can not unmarshal into nil reference (%T)(%[1]v))", v.Interface())
}
switch v = v.Elem(); v.Kind() {
case reflect.Int64, reflect.Int:
return decReflectInt64(p, v)
default:
return fmt.Errorf("failed to unmarshal time: unsupported value type (%T)(%[1]v)", v.Interface())
}
}
func decReflectInt64(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
v.SetInt(0)
case 8:
val := decInt64(p)
if val > maxValInt64 || val < minValInt64 {
return fmt.Errorf("failed to unmarshal time: (%T) the data should be in the range 0 to 86399999999999", v.Interface())
}
v.SetInt(val)
default:
return errWrongDataLen
}
return nil
}
func DecReflectR(p []byte, v reflect.Value) error {
if v.IsNil() {
return fmt.Errorf("failed to unmarshal time: can not unmarshal into nil reference (%T)(%[1]v)", v.Interface())
}
switch v.Type().Elem().Elem().Kind() {
case reflect.Int64, reflect.Int:
return decReflectIntsR(p, v)
default:
return fmt.Errorf("failed to unmarshal time: unsupported value type (%T)(%[1]v)", v.Interface())
}
}
func decReflectIntsR(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
if p == nil {
v.Elem().Set(reflect.Zero(v.Elem().Type()))
} else {
v.Elem().Set(reflect.New(v.Type().Elem().Elem()))
}
case 8:
vv := decInt64(p)
if vv > maxValInt64 || vv < minValInt64 {
return fmt.Errorf("failed to unmarshal time: (%T) the data should be in the range 0 to 86399999999999", v.Interface())
}
val := reflect.New(v.Type().Elem().Elem())
val.Elem().SetInt(vv)
v.Elem().Set(val)
default:
return errWrongDataLen
}
return nil
}
func decInt64(p []byte) int64 {
return int64(p[0])<<56 | int64(p[1])<<48 | int64(p[2])<<40 | int64(p[3])<<32 | int64(p[4])<<24 | int64(p[5])<<16 | int64(p[6])<<8 | int64(p[7])
}
func decDur(p []byte) time.Duration {
return time.Duration(p[0])<<56 | time.Duration(p[1])<<48 | time.Duration(p[2])<<40 | time.Duration(p[3])<<32 | time.Duration(p[4])<<24 | time.Duration(p[5])<<16 | time.Duration(p[6])<<8 | time.Duration(p[7])
}

View File

@@ -0,0 +1,42 @@
package date
import (
"reflect"
"time"
)
func Marshal(value interface{}) ([]byte, error) {
switch v := value.(type) {
case nil:
return nil, nil
case int32:
return EncInt32(v)
case int64:
return EncInt64(v)
case uint32:
return EncUint32(v)
case string:
return EncString(v)
case time.Time:
return EncTime(v)
case *int32:
return EncInt32R(v)
case *int64:
return EncInt64R(v)
case *uint32:
return EncUint32R(v)
case *string:
return EncStringR(v)
case *time.Time:
return EncTimeR(v)
default:
// Custom types (type MyDate uint32) can be serialized only via `reflect` package.
// Later, when generic-based serialization is introduced we can do that via generics.
rv := reflect.TypeOf(value)
if rv.Kind() != reflect.Ptr {
return EncReflect(reflect.ValueOf(v))
}
return EncReflectR(reflect.ValueOf(v))
}
}

View File

@@ -0,0 +1,223 @@
package date
import (
"fmt"
"reflect"
"strconv"
"strings"
"time"
)
const (
millisecondsInADay int64 = 24 * 60 * 60 * 1000
centerEpoch int64 = 1 << 31
maxYear int = 5881580
minYear int = -5877641
maxMilliseconds int64 = 185542587100800000
minMilliseconds int64 = -185542587187200000
)
var (
maxDate = time.Date(5881580, 07, 11, 0, 0, 0, 0, time.UTC)
minDate = time.Date(-5877641, 06, 23, 0, 0, 0, 0, time.UTC)
)
func errWrongStringFormat(v interface{}) error {
return fmt.Errorf(`failed to marshal date: the (%T)(%[1]v) should have fromat "2006-01-02"`, v)
}
func EncInt32(v int32) ([]byte, error) {
return encInt32(v), nil
}
func EncInt32R(v *int32) ([]byte, error) {
if v == nil {
return nil, nil
}
return encInt32(*v), nil
}
func EncInt64(v int64) ([]byte, error) {
if v > maxMilliseconds || v < minMilliseconds {
return nil, fmt.Errorf("failed to marshal date: the (int64)(%v) value out of range", v)
}
return encInt64(days(v)), nil
}
func EncInt64R(v *int64) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncInt64(*v)
}
func EncUint32(v uint32) ([]byte, error) {
return encUint32(v), nil
}
func EncUint32R(v *uint32) ([]byte, error) {
if v == nil {
return nil, nil
}
return encUint32(*v), nil
}
func EncTime(v time.Time) ([]byte, error) {
if v.After(maxDate) || v.Before(minDate) {
return nil, fmt.Errorf("failed to marshal date: the (%T)(%s) value should be in the range from -5877641-06-23 to 5881580-07-11", v, v.Format("2006-01-02"))
}
return encTime(v), nil
}
func EncTimeR(v *time.Time) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncTime(*v)
}
func EncString(v string) ([]byte, error) {
if v == "" {
return nil, nil
}
var err error
var y, m, d int
var t time.Time
switch ps := strings.Split(v, "-"); len(ps) {
case 3:
if y, err = strconv.Atoi(ps[0]); err != nil {
return nil, errWrongStringFormat(v)
}
if m, err = strconv.Atoi(ps[1]); err != nil {
return nil, errWrongStringFormat(v)
}
if d, err = strconv.Atoi(ps[2]); err != nil {
return nil, errWrongStringFormat(v)
}
case 4:
if y, err = strconv.Atoi(ps[1]); err != nil || ps[0] != "" {
return nil, errWrongStringFormat(v)
}
y = -y
if m, err = strconv.Atoi(ps[2]); err != nil {
return nil, errWrongStringFormat(v)
}
if d, err = strconv.Atoi(ps[3]); err != nil {
return nil, errWrongStringFormat(v)
}
default:
return nil, errWrongStringFormat(v)
}
if y > maxYear || y < minYear {
return nil, fmt.Errorf("failed to marshal date: the (%T)(%[1]v) value should be in the range from -5877641-06-23 to 5881580-07-11", v)
}
t = time.Date(y, time.Month(m), d, 0, 0, 0, 0, time.UTC)
if t.After(maxDate) || t.Before(minDate) {
return nil, fmt.Errorf("failed to marshal date: the (%T)(%[1]v) value should be in the range from -5877641-06-23 to 5881580-07-11", v)
}
return encTime(t), nil
}
func EncStringR(v *string) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncString(*v)
}
func EncReflect(v reflect.Value) ([]byte, error) {
switch v.Kind() {
case reflect.Int32:
return encInt64(v.Int()), nil
case reflect.Int64:
val := v.Int()
if val > maxMilliseconds || val < minMilliseconds {
return nil, fmt.Errorf("failed to marshal date: the value (%T)(%[1]v) out of range", v.Interface())
}
return encInt64(days(val)), nil
case reflect.Uint32:
val := v.Uint()
return []byte{byte(val >> 24), byte(val >> 16), byte(val >> 8), byte(val)}, nil
case reflect.String:
return encReflectString(v)
case reflect.Struct:
if v.Type().String() == "gocql.unsetColumn" {
return nil, nil
}
return nil, fmt.Errorf("failed to marshal date: unsupported value type (%T)(%[1]v)", v.Interface())
default:
return nil, fmt.Errorf("failed to marshal date: unsupported value type (%T)(%[1]v)", v.Interface())
}
}
func EncReflectR(v reflect.Value) ([]byte, error) {
if v.IsNil() {
return nil, nil
}
return EncReflect(v.Elem())
}
func encReflectString(v reflect.Value) ([]byte, error) {
val := v.String()
if val == "" {
return nil, nil
}
var err error
var y, m, d int
var t time.Time
ps := strings.Split(val, "-")
switch len(ps) {
case 3:
if y, err = strconv.Atoi(ps[0]); err != nil {
return nil, errWrongStringFormat(v.Interface())
}
if m, err = strconv.Atoi(ps[1]); err != nil {
return nil, errWrongStringFormat(v.Interface())
}
if d, err = strconv.Atoi(ps[2]); err != nil {
return nil, errWrongStringFormat(v.Interface())
}
case 4:
if y, err = strconv.Atoi(ps[1]); err != nil {
return nil, errWrongStringFormat(v.Interface())
}
y = -y
if m, err = strconv.Atoi(ps[2]); err != nil {
return nil, errWrongStringFormat(v.Interface())
}
if d, err = strconv.Atoi(ps[3]); err != nil {
return nil, errWrongStringFormat(v.Interface())
}
default:
return nil, errWrongStringFormat(v.Interface())
}
if y > maxYear || y < minYear {
return nil, fmt.Errorf("failed to marshal date: the (%T)(%[1]v) value should be in the range from -5877641-06-23 to 5881580-07-11", v.Interface())
}
t = time.Date(y, time.Month(m), d, 0, 0, 0, 0, time.UTC)
if t.After(maxDate) || t.Before(minDate) {
return nil, fmt.Errorf("failed to marshal date: the (%T)(%[1]v) value should be in the range from -5877641-06-23 to 5881580-07-11", v.Interface())
}
return encTime(t), nil
}
func encInt64(v int64) []byte {
return []byte{byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}
}
func encInt32(v int32) []byte {
return []byte{byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}
}
func encUint32(v uint32) []byte {
return []byte{byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}
}
func encTime(v time.Time) []byte {
d := days(v.UnixMilli())
return []byte{byte(d >> 24), byte(d >> 16), byte(d >> 8), byte(d)}
}
func days(v int64) int64 {
return v/millisecondsInADay + centerEpoch
}

View File

@@ -0,0 +1,49 @@
package date
import (
"fmt"
"reflect"
"time"
)
func Unmarshal(data []byte, value interface{}) error {
switch v := value.(type) {
case nil:
return nil
case *int32:
return DecInt32(data, v)
case *int64:
return DecInt64(data, v)
case *uint32:
return DecUint32(data, v)
case *string:
return DecString(data, v)
case *time.Time:
return DecTime(data, v)
case **int32:
return DecInt32R(data, v)
case **int64:
return DecInt64R(data, v)
case **uint32:
return DecUint32R(data, v)
case **string:
return DecStringR(data, v)
case **time.Time:
return DecTimeR(data, v)
default:
// Custom types (type MyDate uint32) can be deserialized only via `reflect` package.
// Later, when generic-based serialization is introduced we can do that via generics.
rv := reflect.ValueOf(value)
rt := rv.Type()
if rt.Kind() != reflect.Ptr {
return fmt.Errorf("failed to unmarshal date: unsupported value type (%T)(%[1]v)", value)
}
if rt.Elem().Kind() != reflect.Ptr {
return DecReflect(data, rv)
}
return DecReflectR(data, rv)
}
}

View File

@@ -0,0 +1,401 @@
package date
import (
"fmt"
"math"
"reflect"
"time"
)
const (
negInt64 = int64(-1) << 32
zeroDate = "-5877641-06-23"
zeroMS int64 = -185542587187200000
)
var errWrongDataLen = fmt.Errorf("failed to unmarshal date: the length of the data should be 0 or 4")
func errNilReference(v interface{}) error {
return fmt.Errorf("failed to unmarshal date: can not unmarshal into nil reference (%T)(%[1]v))", v)
}
func DecInt32(p []byte, v *int32) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
*v = 0
case 4:
*v = decInt32(p)
default:
return errWrongDataLen
}
return nil
}
func DecInt32R(p []byte, v **int32) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
if p == nil {
*v = nil
} else {
*v = new(int32)
}
case 4:
val := decInt32(p)
*v = &val
default:
return errWrongDataLen
}
return nil
}
func DecInt64(p []byte, v *int64) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
*v = zeroMS
case 4:
*v = decMilliseconds(p)
default:
return errWrongDataLen
}
return nil
}
func DecInt64R(p []byte, v **int64) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
if p == nil {
*v = nil
} else {
val := zeroMS
*v = &val
}
case 4:
val := decMilliseconds(p)
*v = &val
default:
return errWrongDataLen
}
return nil
}
func DecUint32(p []byte, v *uint32) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
*v = 0
case 4:
*v = decUint32(p)
default:
return errWrongDataLen
}
return nil
}
func DecUint32R(p []byte, v **uint32) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
if p == nil {
*v = nil
} else {
*v = new(uint32)
}
case 4:
val := decUint32(p)
*v = &val
default:
return errWrongDataLen
}
return nil
}
func DecString(p []byte, v *string) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
if p == nil {
*v = ""
} else {
*v = zeroDate
}
case 4:
*v = decString(p)
default:
return errWrongDataLen
}
return nil
}
func DecStringR(p []byte, v **string) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
if p == nil {
*v = nil
} else {
val := zeroDate
*v = &val
}
case 4:
val := decString(p)
*v = &val
default:
return errWrongDataLen
}
return nil
}
func DecTime(p []byte, v *time.Time) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
*v = minDate
case 4:
*v = decTime(p)
default:
return errWrongDataLen
}
return nil
}
func DecTimeR(p []byte, v **time.Time) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
if p == nil {
*v = nil
} else {
val := minDate
*v = &val
}
case 4:
val := decTime(p)
*v = &val
default:
return errWrongDataLen
}
return nil
}
func DecReflect(p []byte, v reflect.Value) error {
if v.IsNil() {
return fmt.Errorf("failed to unmarshal date: can not unmarshal into nil reference (%T)(%[1]v))", v.Interface())
}
switch v = v.Elem(); v.Kind() {
case reflect.Int32:
return decReflectInt32(p, v)
case reflect.Int64:
return decReflectInt64(p, v)
case reflect.Uint32:
return decReflectUint32(p, v)
case reflect.String:
return decReflectString(p, v)
default:
return fmt.Errorf("failed to unmarshal date: unsupported value type (%T)(%[1]v)", v.Interface())
}
}
func decReflectInt32(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
v.SetInt(0)
case 4:
v.SetInt(decInt64(p))
default:
return errWrongDataLen
}
return nil
}
func decReflectInt64(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
v.SetInt(zeroMS)
case 4:
v.SetInt(decMilliseconds(p))
default:
return errWrongDataLen
}
return nil
}
func decReflectUint32(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
v.SetUint(0)
case 4:
v.SetUint(decUint64(p))
default:
return errWrongDataLen
}
return nil
}
func decReflectString(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
if p == nil {
v.SetString("")
} else {
v.SetString(zeroDate)
}
case 4:
v.SetString(decString(p))
default:
return errWrongDataLen
}
return nil
}
func DecReflectR(p []byte, v reflect.Value) error {
if v.IsNil() {
return fmt.Errorf("failed to unmarshal date: can not unmarshal into nil reference (%T)(%[1]v)", v.Interface())
}
switch v.Type().Elem().Elem().Kind() {
case reflect.Int32:
return decReflectInt32R(p, v)
case reflect.Int64:
return decReflectInt64R(p, v)
case reflect.Uint32:
return decReflectUint32R(p, v)
case reflect.String:
return decReflectStringR(p, v)
default:
return fmt.Errorf("failed to unmarshal date: unsupported value type (%T)(%[1]v)", v.Interface())
}
}
func decReflectInt32R(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
v.Elem().Set(decReflectNullableR(p, v))
case 4:
newVal := reflect.New(v.Type().Elem().Elem())
newVal.Elem().SetInt(decInt64(p))
v.Elem().Set(newVal)
default:
return errWrongDataLen
}
return nil
}
func decReflectInt64R(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
var val reflect.Value
if p == nil {
val = reflect.Zero(v.Type().Elem())
} else {
val = reflect.New(v.Type().Elem().Elem())
val.Elem().SetInt(zeroMS)
v.Elem().Set(val)
}
v.Elem().Set(val)
case 4:
val := reflect.New(v.Type().Elem().Elem())
val.Elem().SetInt(decMilliseconds(p))
v.Elem().Set(val)
default:
return errWrongDataLen
}
return nil
}
func decReflectUint32R(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
v.Elem().Set(decReflectNullableR(p, v))
case 4:
newVal := reflect.New(v.Type().Elem().Elem())
newVal.Elem().SetUint(decUint64(p))
v.Elem().Set(newVal)
default:
return errWrongDataLen
}
return nil
}
func decReflectStringR(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
var val reflect.Value
if p == nil {
val = reflect.Zero(v.Type().Elem())
} else {
val = reflect.New(v.Type().Elem().Elem())
val.Elem().SetString(zeroDate)
}
v.Elem().Set(val)
case 4:
val := reflect.New(v.Type().Elem().Elem())
val.Elem().SetString(decString(p))
v.Elem().Set(val)
default:
return errWrongDataLen
}
return nil
}
func decReflectNullableR(p []byte, v reflect.Value) reflect.Value {
if p == nil {
return reflect.Zero(v.Elem().Type())
}
return reflect.New(v.Type().Elem().Elem())
}
func decInt32(p []byte) int32 {
return int32(p[0])<<24 | int32(p[1])<<16 | int32(p[2])<<8 | int32(p[3])
}
func decInt64(p []byte) int64 {
if p[0] > math.MaxInt8 {
return negInt64 | int64(p[0])<<24 | int64(p[1])<<16 | int64(p[2])<<8 | int64(p[3])
}
return int64(p[0])<<24 | int64(p[1])<<16 | int64(p[2])<<8 | int64(p[3])
}
func decMilliseconds(p []byte) int64 {
return (int64(p[0])<<24 | int64(p[1])<<16 | int64(p[2])<<8 | int64(p[3]) - centerEpoch) * millisecondsInADay
}
func decUint32(p []byte) uint32 {
return uint32(p[0])<<24 | uint32(p[1])<<16 | uint32(p[2])<<8 | uint32(p[3])
}
func decUint64(p []byte) uint64 {
return uint64(p[0])<<24 | uint64(p[1])<<16 | uint64(p[2])<<8 | uint64(p[3])
}
func decString(p []byte) string {
return decTime(p).Format("2006-01-02")
}
func decTime(p []byte) time.Time {
return time.UnixMilli(decMilliseconds(p)).UTC()
}

View File

@@ -0,0 +1,29 @@
package decimal
import (
"gopkg.in/inf.v0"
"reflect"
)
func Marshal(value interface{}) ([]byte, error) {
switch v := value.(type) {
case nil:
return nil, nil
case inf.Dec:
return EncInfDec(v)
case *inf.Dec:
return EncInfDecR(v)
case string:
return EncString(v)
case *string:
return EncStringR(v)
default:
// Custom types (type MyString string) can be serialized only via `reflect` package.
// Later, when generic-based serialization is introduced we can do that via generics.
rv := reflect.TypeOf(value)
if rv.Kind() != reflect.Ptr {
return EncReflect(reflect.ValueOf(v))
}
return EncReflectR(reflect.ValueOf(v))
}
}

View File

@@ -0,0 +1,141 @@
package decimal
import (
"fmt"
"gopkg.in/inf.v0"
"math/big"
"reflect"
"strconv"
"strings"
"github.com/gocql/gocql/serialization/varint"
)
func EncInfDec(v inf.Dec) ([]byte, error) {
sign := v.Sign()
if sign == 0 {
return []byte{0, 0, 0, 0, 0}, nil
}
return append(encScale(v.Scale()), varint.EncBigIntRS(v.UnscaledBig())...), nil
}
func EncInfDecR(v *inf.Dec) ([]byte, error) {
if v == nil {
return nil, nil
}
return encInfDecR(v), nil
}
// EncString encodes decimal string which should contains `scale` and `unscaled` strings separated by `;`.
func EncString(v string) ([]byte, error) {
if v == "" {
return nil, nil
}
vs := strings.Split(v, ";")
if len(vs) != 2 {
return nil, fmt.Errorf("failed to marshal decimal: invalid decimal string %s", v)
}
scale, err := strconv.ParseInt(vs[0], 10, 32)
if err != nil {
return nil, fmt.Errorf("failed to marshal decimal: invalid decimal scale string %s", vs[0])
}
unscaleData, err := encUnscaledString(vs[1])
if err != nil {
return nil, err
}
return append(encScale64(scale), unscaleData...), nil
}
func EncStringR(v *string) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncString(*v)
}
func EncReflect(v reflect.Value) ([]byte, error) {
switch v.Type().Kind() {
case reflect.String:
return encReflectString(v)
case reflect.Struct:
if v.Type().String() == "gocql.unsetColumn" {
return nil, nil
}
return nil, fmt.Errorf("failed to marshal decimal: unsupported value type (%T)(%[1]v)", v.Interface())
default:
return nil, fmt.Errorf("failed to marshal decimal: unsupported value type (%T)(%[1]v)", v.Interface())
}
}
func EncReflectR(v reflect.Value) ([]byte, error) {
if v.IsNil() {
return nil, nil
}
return EncReflect(v.Elem())
}
func encReflectString(v reflect.Value) ([]byte, error) {
val := v.String()
if val == "" {
return nil, nil
}
vs := strings.Split(val, ";")
if len(vs) != 2 {
return nil, fmt.Errorf("failed to marshal decimal: invalid decimal string (%T)(%[1]v)", v.Interface())
}
scale, err := strconv.ParseInt(vs[0], 10, 32)
if err != nil {
return nil, fmt.Errorf("failed to marshal decimal: invalid decimal scale string (%T)(%s)", v.Interface(), vs[0])
}
unscaledData, err := encUnscaledString(vs[1])
if err != nil {
return nil, err
}
return append(encScale64(scale), unscaledData...), nil
}
func encInfDecR(v *inf.Dec) []byte {
sign := v.Sign()
if sign == 0 {
return []byte{0, 0, 0, 0, 0}
}
return append(encScale(v.Scale()), varint.EncBigIntRS(v.UnscaledBig())...)
}
func encScale(v inf.Scale) []byte {
return []byte{byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}
}
func encScale64(v int64) []byte {
return []byte{byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}
}
func encUnscaledString(v string) ([]byte, error) {
switch {
case len(v) == 0:
return nil, nil
case len(v) <= 18:
n, err := strconv.ParseInt(v, 10, 64)
if err != nil {
return nil, fmt.Errorf("failed to marshal decimal: invalid unscaled string %s, %s", v, err)
}
return varint.EncInt64Ext(n), nil
case len(v) <= 20:
n, err := strconv.ParseInt(v, 10, 64)
if err == nil {
return varint.EncInt64Ext(n), nil
}
t, ok := new(big.Int).SetString(v, 10)
if !ok {
return nil, fmt.Errorf("failed to marshal decimal: invalid unscaled string %s", v)
}
return varint.EncBigIntRS(t), nil
default:
t, ok := new(big.Int).SetString(v, 10)
if !ok {
return nil, fmt.Errorf("failed to marshal decimal: invalid unscaled string %s", v)
}
return varint.EncBigIntRS(t), nil
}
}

View File

@@ -0,0 +1,34 @@
package decimal
import (
"fmt"
"gopkg.in/inf.v0"
"reflect"
)
func Unmarshal(data []byte, value interface{}) error {
switch v := value.(type) {
case nil:
return nil
case *inf.Dec:
return DecInfDec(data, v)
case **inf.Dec:
return DecInfDecR(data, v)
case *string:
return DecString(data, v)
case **string:
return DecStringR(data, v)
default:
// Custom types (type MyString string) can be deserialized only via `reflect` package.
// Later, when generic-based serialization is introduced we can do that via generics.
rv := reflect.ValueOf(value)
rt := rv.Type()
if rt.Kind() != reflect.Ptr {
return fmt.Errorf("failed to unmarshal decimal: unsupported value type (%T)(%#[1]v)", value)
}
if rt.Elem().Kind() != reflect.Ptr {
return DecReflect(data, rv)
}
return DecReflectR(data, rv)
}
}

View File

@@ -0,0 +1,80 @@
package decimal
import (
"gopkg.in/inf.v0"
)
const (
neg8 = int64(-1) << 8
neg16 = int64(-1) << 16
neg24 = int64(-1) << 24
neg32 = int64(-1) << 32
neg40 = int64(-1) << 40
neg48 = int64(-1) << 48
neg56 = int64(-1) << 56
neg32Int = int(-1) << 32
)
func decScale(p []byte) inf.Scale {
return inf.Scale(p[0])<<24 | inf.Scale(p[1])<<16 | inf.Scale(p[2])<<8 | inf.Scale(p[3])
}
func decScaleInt64(p []byte) int64 {
if p[0] > 127 {
return neg32 | int64(p[0])<<24 | int64(p[1])<<16 | int64(p[2])<<8 | int64(p[3])
}
return int64(p[0])<<24 | int64(p[1])<<16 | int64(p[2])<<8 | int64(p[3])
}
func dec1toInt64(p []byte) int64 {
if p[4] > 127 {
return neg8 | int64(p[4])
}
return int64(p[4])
}
func dec2toInt64(p []byte) int64 {
if p[4] > 127 {
return neg16 | int64(p[4])<<8 | int64(p[5])
}
return int64(p[4])<<8 | int64(p[5])
}
func dec3toInt64(p []byte) int64 {
if p[4] > 127 {
return neg24 | int64(p[4])<<16 | int64(p[5])<<8 | int64(p[6])
}
return int64(p[4])<<16 | int64(p[5])<<8 | int64(p[6])
}
func dec4toInt64(p []byte) int64 {
if p[4] > 127 {
return neg32 | int64(p[4])<<24 | int64(p[5])<<16 | int64(p[6])<<8 | int64(p[7])
}
return int64(p[4])<<24 | int64(p[5])<<16 | int64(p[6])<<8 | int64(p[7])
}
func dec5toInt64(p []byte) int64 {
if p[4] > 127 {
return neg40 | int64(p[4])<<32 | int64(p[5])<<24 | int64(p[6])<<16 | int64(p[7])<<8 | int64(p[8])
}
return int64(p[4])<<32 | int64(p[5])<<24 | int64(p[6])<<16 | int64(p[7])<<8 | int64(p[8])
}
func dec6toInt64(p []byte) int64 {
if p[4] > 127 {
return neg48 | int64(p[4])<<40 | int64(p[5])<<32 | int64(p[6])<<24 | int64(p[7])<<16 | int64(p[8])<<8 | int64(p[9])
}
return int64(p[4])<<40 | int64(p[5])<<32 | int64(p[6])<<24 | int64(p[7])<<16 | int64(p[8])<<8 | int64(p[9])
}
func dec7toInt64(p []byte) int64 {
if p[4] > 127 {
return neg56 | int64(p[4])<<48 | int64(p[5])<<40 | int64(p[6])<<32 | int64(p[7])<<24 | int64(p[8])<<16 | int64(p[9])<<8 | int64(p[10])
}
return int64(p[4])<<48 | int64(p[5])<<40 | int64(p[6])<<32 | int64(p[7])<<24 | int64(p[8])<<16 | int64(p[9])<<8 | int64(p[10])
}
func dec8toInt64(p []byte) int64 {
return int64(p[4])<<56 | int64(p[5])<<48 | int64(p[6])<<40 | int64(p[7])<<32 | int64(p[8])<<24 | int64(p[9])<<16 | int64(p[10])<<8 | int64(p[11])
}

View File

@@ -0,0 +1,323 @@
package decimal
import (
"fmt"
"gopkg.in/inf.v0"
"reflect"
"strconv"
"github.com/gocql/gocql/serialization/varint"
)
var errWrongDataLen = fmt.Errorf("failed to unmarshal decimal: the length of the data should be 0 or more than 5")
func errBrokenData(p []byte) error {
if p[4] == 0 && p[5] <= 127 || p[4] == 255 && p[5] > 127 {
return fmt.Errorf("failed to unmarshal decimal: the data is broken")
}
return nil
}
func errNilReference(v interface{}) error {
return fmt.Errorf("failed to unmarshal decimal: can not unmarshal into nil reference(%T)(%[1]v)", v)
}
func DecInfDec(p []byte, v *inf.Dec) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
v.SetScale(0).SetUnscaled(0)
return nil
case 1, 2, 3, 4:
return errWrongDataLen
case 5:
v.SetScale(decScale(p)).SetUnscaled(dec1toInt64(p))
return nil
case 6:
v.SetScale(decScale(p)).SetUnscaled(dec2toInt64(p))
case 7:
v.SetScale(decScale(p)).SetUnscaled(dec3toInt64(p))
case 8:
v.SetScale(decScale(p)).SetUnscaled(dec4toInt64(p))
case 9:
v.SetScale(decScale(p)).SetUnscaled(dec5toInt64(p))
case 10:
v.SetScale(decScale(p)).SetUnscaled(dec6toInt64(p))
case 11:
v.SetScale(decScale(p)).SetUnscaled(dec7toInt64(p))
case 12:
v.SetScale(decScale(p)).SetUnscaled(dec8toInt64(p))
default:
v.SetScale(decScale(p)).SetUnscaledBig(varint.Dec2BigInt(p[4:]))
}
return errBrokenData(p)
}
func DecInfDecR(p []byte, v **inf.Dec) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
if p == nil {
*v = nil
} else {
*v = inf.NewDec(0, 0)
}
return nil
case 1, 2, 3, 4:
return errWrongDataLen
case 5:
*v = inf.NewDec(dec1toInt64(p), decScale(p))
return nil
case 6:
*v = inf.NewDec(dec2toInt64(p), decScale(p))
case 7:
*v = inf.NewDec(dec3toInt64(p), decScale(p))
case 8:
*v = inf.NewDec(dec4toInt64(p), decScale(p))
case 9:
*v = inf.NewDec(dec5toInt64(p), decScale(p))
case 10:
*v = inf.NewDec(dec6toInt64(p), decScale(p))
case 11:
*v = inf.NewDec(dec7toInt64(p), decScale(p))
case 12:
*v = inf.NewDec(dec8toInt64(p), decScale(p))
default:
*v = inf.NewDecBig(varint.Dec2BigInt(p[4:]), decScale(p))
}
return errBrokenData(p)
}
func DecString(p []byte, v *string) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
if p == nil {
*v = ""
} else {
*v = "0;0"
}
return nil
case 1, 2, 3, 4:
return errWrongDataLen
case 5:
*v = decString5(p)
return nil
case 6:
*v = decString6(p)
case 7:
*v = decString7(p)
case 8:
*v = decString8(p)
case 9:
*v = decString9(p)
case 10:
*v = decString10(p)
case 11:
*v = decString11(p)
case 12:
*v = decString12(p)
default:
*v = decString(p)
}
return errBrokenData(p)
}
func DecStringR(p []byte, v **string) error {
if v == nil {
return errNilReference(v)
}
switch len(p) {
case 0:
if p == nil {
*v = nil
} else {
tmp := "0;0"
*v = &tmp
}
return nil
case 1, 2, 3, 4:
return errWrongDataLen
case 5:
tmp := decString5(p)
*v = &tmp
return nil
case 6:
tmp := decString6(p)
*v = &tmp
case 7:
tmp := decString7(p)
*v = &tmp
case 8:
tmp := decString8(p)
*v = &tmp
case 9:
tmp := decString9(p)
*v = &tmp
case 10:
tmp := decString10(p)
*v = &tmp
case 11:
tmp := decString11(p)
*v = &tmp
case 12:
tmp := decString12(p)
*v = &tmp
default:
tmp := decString(p)
*v = &tmp
}
return errBrokenData(p)
}
func DecReflect(p []byte, v reflect.Value) error {
if v.IsNil() {
return fmt.Errorf("failed to unmarshal decimal: can not unmarshal into nil reference (%T)(%#[1]v)", v.Interface())
}
switch v = v.Elem(); v.Kind() {
case reflect.String:
return decReflectString(p, v)
default:
return fmt.Errorf("failed to unmarshal decimal: unsupported value type (%T)(%#[1]v)", v.Interface())
}
}
func DecReflectR(p []byte, v reflect.Value) error {
if v.IsNil() {
return fmt.Errorf("failed to unmarshal decimal: can not unmarshal into nil reference (%T)(%[1]v)", v.Interface())
}
switch v.Type().Elem().Elem().Kind() {
case reflect.String:
return decReflectStringR(p, v)
default:
return fmt.Errorf("failed to unmarshal decimal: unsupported value type (%T)(%[1]v)", v.Interface())
}
}
func decReflectString(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
if p == nil {
v.SetString("")
} else {
v.SetString("0;0")
}
return nil
case 1, 2, 3, 4:
return errWrongDataLen
case 5:
v.SetString(decString5(p))
return nil
case 6:
v.SetString(decString6(p))
case 7:
v.SetString(decString7(p))
case 8:
v.SetString(decString8(p))
case 9:
v.SetString(decString9(p))
case 10:
v.SetString(decString10(p))
case 11:
v.SetString(decString11(p))
case 12:
v.SetString(decString12(p))
default:
v.SetString(decString(p))
}
return errBrokenData(p)
}
func decReflectStringR(p []byte, v reflect.Value) error {
switch len(p) {
case 0:
var val reflect.Value
if p == nil {
val = reflect.Zero(v.Type().Elem())
} else {
val = reflect.New(v.Type().Elem().Elem())
val.Elem().SetString("0;0")
}
v.Elem().Set(val)
return nil
case 1, 2, 3, 4:
return errWrongDataLen
case 5:
newVal := reflect.New(v.Type().Elem().Elem())
newVal.Elem().SetString(decString5(p))
v.Elem().Set(newVal)
return nil
case 6:
newVal := reflect.New(v.Type().Elem().Elem())
newVal.Elem().SetString(decString6(p))
v.Elem().Set(newVal)
case 7:
newVal := reflect.New(v.Type().Elem().Elem())
newVal.Elem().SetString(decString7(p))
v.Elem().Set(newVal)
case 8:
newVal := reflect.New(v.Type().Elem().Elem())
newVal.Elem().SetString(decString8(p))
v.Elem().Set(newVal)
case 9:
newVal := reflect.New(v.Type().Elem().Elem())
newVal.Elem().SetString(decString9(p))
v.Elem().Set(newVal)
case 10:
newVal := reflect.New(v.Type().Elem().Elem())
newVal.Elem().SetString(decString10(p))
v.Elem().Set(newVal)
case 11:
newVal := reflect.New(v.Type().Elem().Elem())
newVal.Elem().SetString(decString11(p))
v.Elem().Set(newVal)
case 12:
newVal := reflect.New(v.Type().Elem().Elem())
newVal.Elem().SetString(decString12(p))
v.Elem().Set(newVal)
default:
newVal := reflect.New(v.Type().Elem().Elem())
newVal.Elem().SetString(decString(p))
v.Elem().Set(newVal)
}
return errBrokenData(p)
}
func decString5(p []byte) string {
return strconv.FormatInt(decScaleInt64(p), 10) + ";" + strconv.FormatInt(dec1toInt64(p), 10)
}
func decString6(p []byte) string {
return strconv.FormatInt(decScaleInt64(p), 10) + ";" + strconv.FormatInt(dec2toInt64(p), 10)
}
func decString7(p []byte) string {
return strconv.FormatInt(decScaleInt64(p), 10) + ";" + strconv.FormatInt(dec3toInt64(p), 10)
}
func decString8(p []byte) string {
return strconv.FormatInt(decScaleInt64(p), 10) + ";" + strconv.FormatInt(dec4toInt64(p), 10)
}
func decString9(p []byte) string {
return strconv.FormatInt(decScaleInt64(p), 10) + ";" + strconv.FormatInt(dec5toInt64(p), 10)
}
func decString10(p []byte) string {
return strconv.FormatInt(decScaleInt64(p), 10) + ";" + strconv.FormatInt(dec6toInt64(p), 10)
}
func decString11(p []byte) string {
return strconv.FormatInt(decScaleInt64(p), 10) + ";" + strconv.FormatInt(dec7toInt64(p), 10)
}
func decString12(p []byte) string {
return strconv.FormatInt(decScaleInt64(p), 10) + ";" + strconv.FormatInt(dec8toInt64(p), 10)
}
func decString(p []byte) string {
return strconv.FormatInt(decScaleInt64(p), 10) + ";" + varint.Dec2BigInt(p[4:]).String()
}

View File

@@ -0,0 +1,24 @@
package double
import (
"reflect"
)
func Marshal(value interface{}) ([]byte, error) {
switch v := value.(type) {
case nil:
return nil, nil
case float64:
return EncFloat64(v)
case *float64:
return EncFloat64R(v)
default:
// Custom types (type MyFloat float64) can be serialized only via `reflect` package.
// Later, when generic-based serialization is introduced we can do that via generics.
rv := reflect.TypeOf(value)
if rv.Kind() != reflect.Ptr {
return EncReflect(reflect.ValueOf(v))
}
return EncReflectR(reflect.ValueOf(v))
}
}

View File

@@ -0,0 +1,59 @@
package double
import (
"fmt"
"reflect"
"unsafe"
)
func EncFloat64(v float64) ([]byte, error) {
return encFloat64(v), nil
}
func EncFloat64R(v *float64) ([]byte, error) {
if v == nil {
return nil, nil
}
return encFloat64R(v), nil
}
func EncReflect(v reflect.Value) ([]byte, error) {
switch v.Kind() {
case reflect.Float64:
return encFloat64(v.Float()), nil
case reflect.Struct:
if v.Type().String() == "gocql.unsetColumn" {
return nil, nil
}
return nil, fmt.Errorf("failed to marshal double: unsupported value type (%T)(%[1]v)", v.Interface())
default:
return nil, fmt.Errorf("failed to marshal double: unsupported value type (%T)(%[1]v)", v.Interface())
}
}
func EncReflectR(v reflect.Value) ([]byte, error) {
if v.IsNil() {
return nil, nil
}
return EncReflect(v.Elem())
}
func encFloat64(v float64) []byte {
return encUint64(floatToUint(v))
}
func encFloat64R(v *float64) []byte {
return encUint64(floatToUintR(v))
}
func encUint64(v uint64) []byte {
return []byte{byte(v >> 56), byte(v >> 48), byte(v >> 40), byte(v >> 32), byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}
}
func floatToUint(v float64) uint64 {
return *(*uint64)(unsafe.Pointer(&v))
}
func floatToUintR(v *float64) uint64 {
return *(*uint64)(unsafe.Pointer(v))
}

Some files were not shown because too many files have changed in this diff Show More