RPC: Initialize the GSS context upon RPC credential creation.

gss_create_cred() now performs an RPCSEC_GSS upcall and waits for its
result, so that a new credential already carries a GSS context instead
of waiting for gss_refresh() to trigger the upcall.

 - The old gss_upcall() is split into gss_alloc_msg(), gss_add_msg() and
   gss_setup_upcall(). These are shared by gss_refresh_upcall(), which
   puts the rpc_task to sleep as before, and the new gss_create_upcall(),
   which sleeps interruptibly on an ordinary wait queue.
 - struct gss_upcall_msg carries both an rpc_wait_queue for sleeping
   rpc_tasks and a wait_queue_head_t for synchronous waiters; both are
   woken when the message is unhashed.
 - struct gss_auth records the rpc_clnt it was created for.
 - rpcauth_refreshcred() only overwrites task->tk_status when the
   refresh method returns an error.

If rpc_queue_upcall() fails, gss_setup_upcall() stores the error in the
message before unhashing it, so that a process that found the message
in the meantime wakes up with that error instead of going back to
sleep, and drops its own reference so that the message is freed.

Signed-off-by: Trond Myklebust
---
 auth.c              |    7 +
 auth_gss/auth_gss.c |  222 ++++++++++++++++++++++++++++++++++------------------
 2 files changed, 156 insertions(+), 73 deletions(-)

Index: linux-2.6.11/net/sunrpc/auth_gss/auth_gss.c
===================================================================
--- linux-2.6.11.orig/net/sunrpc/auth_gss/auth_gss.c
+++ linux-2.6.11/net/sunrpc/auth_gss/auth_gss.c
@@ -87,6 +87,7 @@ struct gss_auth {
 	struct gss_api_mech *mech;
 	enum rpc_gss_svc service;
 	struct list_head upcalls;
+	struct rpc_clnt *client;
 	struct dentry *dentry;
 	char path[48];
 	spinlock_t lock;
@@ -296,7 +297,8 @@ struct gss_upcall_msg {
 	struct rpc_pipe_msg msg;
 	struct list_head list;
 	struct gss_auth *auth;
-	struct rpc_wait_queue waitq;
+	struct rpc_wait_queue rpc_waitqueue;
+	wait_queue_head_t waitqueue;
 	struct gss_cl_ctx *ctx;
 };
 
@@ -326,16 +328,34 @@ __gss_find_upcall(struct gss_auth *gss_a
 	return NULL;
 }
 
+/* Try to add an upcall to gss_auth's list of pending upcalls.
+ * If an upcall owned by our uid already exists, then we return a reference
+ * to that upcall instead of adding the new upcall.
+ */
+static inline struct gss_upcall_msg *
+gss_add_msg(struct gss_auth *gss_auth, struct gss_upcall_msg *gss_msg)
+{
+	struct gss_upcall_msg *old;
+
+	spin_lock(&gss_auth->lock);
+	old = __gss_find_upcall(gss_auth, gss_msg->uid);
+	if (old == NULL) {
+		atomic_inc(&gss_msg->count);
+		list_add(&gss_msg->list, &gss_auth->upcalls);
+	} else
+		gss_msg = old;
+	spin_unlock(&gss_auth->lock);
+	return gss_msg;
+}
+
 static void
 __gss_unhash_msg(struct gss_upcall_msg *gss_msg)
 {
 	if (list_empty(&gss_msg->list))
 		return;
 	list_del_init(&gss_msg->list);
-	if (gss_msg->msg.errno < 0)
-		rpc_wake_up_status(&gss_msg->waitq, gss_msg->msg.errno);
-	else
-		rpc_wake_up(&gss_msg->waitq);
+	rpc_wake_up_status(&gss_msg->rpc_waitqueue, gss_msg->msg.errno);
+	wake_up_all(&gss_msg->waitqueue);
 	atomic_dec(&gss_msg->count);
 }
 
@@ -361,81 +381,138 @@ gss_upcall_callback(struct rpc_task *tas
 		gss_cred_set_ctx(task->tk_msg.rpc_cred, gss_get_ctx(gss_msg->ctx));
 	else
 		task->tk_status = gss_msg->msg.errno;
+	spin_lock(&gss_msg->auth->lock);
 	gss_cred->gc_upcall = NULL;
+	rpc_wake_up_status(&gss_msg->rpc_waitqueue, gss_msg->msg.errno);
+	spin_unlock(&gss_msg->auth->lock);
 	gss_release_msg(gss_msg);
 }
 
-static int
-gss_upcall(struct rpc_clnt *clnt, struct rpc_task *task, struct rpc_cred *cred)
+static inline struct gss_upcall_msg *
+gss_alloc_msg(struct gss_auth *gss_auth, uid_t uid)
+{
+	struct gss_upcall_msg *gss_msg;
+
+	gss_msg = kmalloc(sizeof(*gss_msg), GFP_KERNEL);
+	if (gss_msg != NULL) {
+		memset(gss_msg, 0, sizeof(*gss_msg));
+		INIT_LIST_HEAD(&gss_msg->list);
+		rpc_init_wait_queue(&gss_msg->rpc_waitqueue, "RPCSEC_GSS upcall waitq");
+		init_waitqueue_head(&gss_msg->waitqueue);
+		atomic_set(&gss_msg->count, 1);
+		gss_msg->msg.data = &gss_msg->uid;
+		gss_msg->msg.len = sizeof(gss_msg->uid);
+		gss_msg->uid = uid;
+		gss_msg->auth = gss_auth;
+	}
+	return gss_msg;
+}
+
+/* Queue an upcall for cred's uid, or join one that is already pending.
+ * Returns a reference to the upcall message, which the caller must drop
+ * with gss_release_msg(), or an ERR_PTR() if the upcall could not be queued.
+ */
+static struct gss_upcall_msg *
+gss_setup_upcall(struct rpc_clnt *clnt, struct gss_auth *gss_auth, struct rpc_cred *cred)
 {
-	struct gss_auth *gss_auth = container_of(clnt->cl_auth,
+	struct gss_upcall_msg *gss_new, *gss_msg;
+
+	gss_new = gss_alloc_msg(gss_auth, cred->cr_uid);
+	if (gss_new == NULL)
+		return ERR_PTR(-ENOMEM);
+	gss_msg = gss_add_msg(gss_auth, gss_new);
+	if (gss_msg == gss_new) {
+		int res = rpc_queue_upcall(gss_auth->dentry->d_inode, &gss_new->msg);
+		if (res) {
+			/* Make anyone who found the message see the failure */
+			gss_new->msg.errno = res;
+			gss_unhash_msg(gss_new);
+			gss_release_msg(gss_new);
+			gss_msg = ERR_PTR(res);
+		}
+	} else
+		gss_release_msg(gss_new);
+	return gss_msg;
+}
+
+static inline int
+gss_refresh_upcall(struct rpc_task *task)
+{
+	struct rpc_cred *cred = task->tk_msg.rpc_cred;
+	struct gss_auth *gss_auth = container_of(task->tk_client->cl_auth,
 			struct gss_auth, rpc_auth);
 	struct gss_cred *gss_cred = container_of(cred,
 			struct gss_cred, gc_base);
-	struct gss_upcall_msg *gss_msg, *gss_new = NULL;
-	struct rpc_pipe_msg *msg;
-	struct dentry *dentry = gss_auth->dentry;
-	uid_t uid = cred->cr_uid;
-	int res = 0;
-
-	dprintk("RPC: %4u gss_upcall for uid %u\n", task->tk_pid, uid);
+	struct gss_upcall_msg *gss_msg;
+	int err = 0;
 
-retry:
-	spin_lock(&gss_auth->lock);
-	gss_msg = __gss_find_upcall(gss_auth, uid);
-	if (gss_msg)
-		goto out_sleep;
-	if (gss_new == NULL) {
-		spin_unlock(&gss_auth->lock);
-		gss_new = kmalloc(sizeof(*gss_new), GFP_KERNEL);
-		if (!gss_new) {
-			dprintk("RPC: %4u gss_upcall -ENOMEM\n", task->tk_pid);
-			return -ENOMEM;
-		}
-		goto retry;
+	dprintk("RPC: %4u gss_refresh_upcall for uid %u\n", task->tk_pid, cred->cr_uid);
+	gss_msg = gss_setup_upcall(task->tk_client, gss_auth, cred);
+	if (IS_ERR(gss_msg)) {
+		err = PTR_ERR(gss_msg);
+		goto out;
 	}
-	gss_msg = gss_new;
-	memset(gss_new, 0, sizeof(*gss_new));
-	INIT_LIST_HEAD(&gss_new->list);
-	rpc_init_wait_queue(&gss_new->waitq, "RPCSEC_GSS upcall waitq");
-	atomic_set(&gss_new->count, 2);
-	msg = &gss_new->msg;
-	msg->data = &gss_new->uid;
-	msg->len = sizeof(gss_new->uid);
-	gss_new->uid = uid;
-	gss_new->auth = gss_auth;
-	list_add(&gss_new->list, &gss_auth->upcalls);
-	gss_new = NULL;
-	/* Has someone updated the credential behind our back? */
-	if (!gss_cred_is_uptodate_ctx(cred)) {
-		/* No, so do upcall and sleep */
+	spin_lock(&gss_auth->lock);
+	if (gss_cred->gc_upcall != NULL)
+		rpc_sleep_on(&gss_cred->gc_upcall->rpc_waitqueue, task, NULL, NULL);
+	else if (gss_msg->ctx == NULL && gss_msg->msg.errno >= 0) {
 		task->tk_timeout = 0;
-		/* gss_upcall_callback will release the reference to gss_msg */
 		gss_cred->gc_upcall = gss_msg;
-		rpc_sleep_on(&gss_msg->waitq, task, gss_upcall_callback, NULL);
-		spin_unlock(&gss_auth->lock);
-		res = rpc_queue_upcall(dentry->d_inode, msg);
-		if (res)
-			gss_unhash_msg(gss_msg);
-	} else {
-		/* Yes, so cancel upcall */
-		__gss_unhash_msg(gss_msg);
+		/* gss_upcall_callback will release the reference to gss_upcall_msg */
+		atomic_inc(&gss_msg->count);
+		rpc_sleep_on(&gss_msg->rpc_waitqueue, task, gss_upcall_callback, NULL);
+	} else
+		err = gss_msg->msg.errno;
+	spin_unlock(&gss_auth->lock);
+	gss_release_msg(gss_msg);
+out:
+	dprintk("RPC: %4u gss_refresh_upcall for uid %u result %d\n", task->tk_pid,
+			cred->cr_uid, err);
+	return err;
+}
+
+/* Synchronously obtain a GSS context for a new credential: queue (or join)
+ * an upcall for its uid, then sleep interruptibly until the upcall message
+ * carries a context or an error, or until a signal is pending.
+ */
+static inline int
+gss_create_upcall(struct gss_auth *gss_auth, struct gss_cred *gss_cred)
+{
+	struct rpc_cred *cred = &gss_cred->gc_base;
+	struct gss_upcall_msg *gss_msg;
+	DEFINE_WAIT(wait);
+	int err = 0;
+
+	dprintk("RPC: gss_create_upcall for uid %u\n", cred->cr_uid);
+	gss_msg = gss_setup_upcall(gss_auth->client, gss_auth, cred);
+	if (IS_ERR(gss_msg)) {
+		err = PTR_ERR(gss_msg);
+		goto out;
+	}
+	for (;;) {
+		prepare_to_wait(&gss_msg->waitqueue, &wait, TASK_INTERRUPTIBLE);
+		spin_lock(&gss_auth->lock);
+		if (gss_msg->ctx != NULL || gss_msg->msg.errno < 0) {
+			spin_unlock(&gss_auth->lock);
+			break;
+		}
 		spin_unlock(&gss_auth->lock);
-		gss_release_msg(gss_msg);
+		if (signalled()) {
+			err = -ERESTARTSYS;
+			goto out_intr;
+		}
+		schedule();
 	}
-	dprintk("RPC: %4u gss_upcall for uid %u result %d\n", task->tk_pid,
-		uid, res);
-	return res;
-out_sleep:
-	task->tk_timeout = 0;
-	/* gss_upcall_callback will release the reference to gss_msg */
-	gss_cred->gc_upcall = gss_msg;
-	rpc_sleep_on(&gss_msg->waitq, task, gss_upcall_callback, NULL);
-	spin_unlock(&gss_auth->lock);
-	dprintk("RPC: %4u gss_upcall sleeping\n", task->tk_pid);
-	if (gss_new)
-		kfree(gss_new);
-	return 0;
+	if (gss_msg->ctx)
+		gss_cred_set_ctx(cred, gss_get_ctx(gss_msg->ctx));
+	else
+		err = gss_msg->msg.errno;
+out_intr:
+	finish_wait(&gss_msg->waitqueue, &wait);
+	gss_release_msg(gss_msg);
+out:
+	dprintk("RPC: gss_create_upcall for uid %u result %d\n", cred->cr_uid, err);
+	return err;
 }
 
 static ssize_t
@@ -604,6 +681,7 @@ gss_create(struct rpc_clnt *clnt, rpc_au
 		return NULL;
 	if (!(gss_auth = kmalloc(sizeof(*gss_auth), GFP_KERNEL)))
 		goto out_dec;
+	gss_auth->client = clnt;
 	gss_auth->mech = gss_mech_get_by_pseudoflavor(flavor);
 	if (!gss_auth->mech) {
 		printk(KERN_WARNING "%s: Pseudoflavor %d not found!",
@@ -701,6 +779,7 @@ gss_create_cred(struct rpc_auth *auth, s
 {
 	struct gss_auth *gss_auth = container_of(auth, struct gss_auth, rpc_auth);
 	struct gss_cred	*cred = NULL;
+	int err = -ENOMEM;
 
 	dprintk("RPC: gss_create_cred for uid %d, flavor %d\n",
 		acred->uid, auth->au_flavor);
@@ -718,11 +797,14 @@ gss_create_cred(struct rpc_auth *auth, s
 	cred->gc_flags = 0;
 	cred->gc_base.cr_ops = &gss_credops;
 	cred->gc_service = gss_auth->service;
+	err = gss_create_upcall(gss_auth, cred);
+	if (err < 0)
+		goto out_err;
 
 	return &cred->gc_base;
 
 out_err:
-	dprintk("RPC: gss_create_cred failed\n");
+	dprintk("RPC: gss_create_cred failed with error %d\n", err);
 	if (cred) gss_destroy_cred(&cred->gc_base);
 	return NULL;
 }
@@ -808,11 +890,9 @@ out_put_ctx:
 static int
 gss_refresh(struct rpc_task *task)
 {
-	struct rpc_clnt *clnt = task->tk_client;
-	struct rpc_cred *cred = task->tk_msg.rpc_cred;
 
-	if (!gss_cred_is_uptodate_ctx(cred))
-		return gss_upcall(clnt, task, cred);
+	if (!gss_cred_is_uptodate_ctx(task->tk_msg.rpc_cred))
+		return gss_refresh_upcall(task);
 	return 0;
 }
 
Index: linux-2.6.11/net/sunrpc/auth.c
===================================================================
--- linux-2.6.11.orig/net/sunrpc/auth.c
+++ linux-2.6.11/net/sunrpc/auth.c
@@ -365,11 +365,14 @@ rpcauth_refreshcred(struct rpc_task *tas
 {
 	struct rpc_auth *auth = task->tk_auth;
 	struct rpc_cred *cred = task->tk_msg.rpc_cred;
+	int err;
 
 	dprintk("RPC: %4d refreshing %s cred %p\n",
 		task->tk_pid, auth->au_ops->au_name, cred);
-	task->tk_status = cred->cr_ops->crrefresh(task);
-	return task->tk_status;
+	err = cred->cr_ops->crrefresh(task);
+	if (err < 0)
+		task->tk_status = err;
+	return err;
 }
 
 void