RPC: clean up the RPCSEC_GSS kerberos and spkm3 context import functions Signed-off-by: Trond Myklebust --- include/linux/sunrpc/gss_api.h | 10 +- net/sunrpc/auth_gss/auth_gss.c | 2 net/sunrpc/auth_gss/gss_krb5_mech.c | 105 +++++++++++++++----------- net/sunrpc/auth_gss/gss_mech_switch.c | 6 - net/sunrpc/auth_gss/gss_spkm3_mech.c | 133 +++++++++++++++++----------------- net/sunrpc/auth_gss/svcauth_gss.c | 5 - 6 files changed, 141 insertions(+), 120 deletions(-) Index: linux-2.6.11-rc2/include/linux/sunrpc/gss_api.h =================================================================== --- linux-2.6.11-rc2.orig/include/linux/sunrpc/gss_api.h +++ linux-2.6.11-rc2/include/linux/sunrpc/gss_api.h @@ -33,8 +33,9 @@ struct gss_ctx { /* gss-api prototypes; note that these are somewhat simplified versions of * the prototypes specified in RFC 2744. */ -u32 gss_import_sec_context( - struct xdr_netobj *input_token, +int gss_import_sec_context( + const void* input_token, + size_t bufsize, struct gss_api_mech *mech, struct gss_ctx **ctx_id); u32 gss_get_mic( @@ -80,8 +81,9 @@ struct gss_api_mech { /* and must provide the following operations: */ struct gss_api_ops { - u32 (*gss_import_sec_context)( - struct xdr_netobj *input_token, + int (*gss_import_sec_context)( + const void *input_token, + size_t bufsize, struct gss_ctx *ctx_id); u32 (*gss_get_mic)( struct gss_ctx *ctx_id, Index: linux-2.6.11-rc2/net/sunrpc/auth_gss/gss_spkm3_mech.c =================================================================== --- linux-2.6.11-rc2.orig/net/sunrpc/auth_gss/gss_spkm3_mech.c +++ linux-2.6.11-rc2/net/sunrpc/auth_gss/gss_spkm3_mech.c @@ -49,52 +49,51 @@ # define RPCDBG_FACILITY RPCDBG_AUTH #endif -static inline int -get_bytes(char **ptr, const char *end, void *res, int len) +static const void * +simple_get_bytes(const void *p, const void *end, void *res, int len) { - char *p, *q; - p = *ptr; - q = p + len; - if (q > end || q < p) - return -1; + const void *q = (const void *)((const char *)p + len); + if (unlikely(q > end || q < p)) + return ERR_PTR(-EFAULT); memcpy(res, p, len); - *ptr = q; - return 0; + return q; } -static inline int -get_netobj(char **ptr, const char *end, struct xdr_netobj *res) +static const void * +simple_get_netobj(const void *p, const void *end, struct xdr_netobj *res) { - char *p, *q; - p = *ptr; - if (get_bytes(&p, end, &res->len, sizeof(res->len))) - return -1; - q = p + res->len; - if(res->len == 0) - goto out_nocopy; - if (q > end || q < p) - return -1; - if (!(res->data = kmalloc(res->len, GFP_KERNEL))) - return -1; - memcpy(res->data, p, res->len); -out_nocopy: - *ptr = q; - return 0; + const void *q; + unsigned int len; + p = simple_get_bytes(p, end, &len, sizeof(len)); + if (IS_ERR(p)) + return p; + res->len = len; + if (len == 0) { + res->data = NULL; + return p; + } + q = (const void *)((const char *)p + len); + if (unlikely(q > end || q < p)) + return ERR_PTR(-EFAULT); + res->data = kmalloc(len, GFP_KERNEL); + if (unlikely(res->data == NULL)) + return ERR_PTR(-ENOMEM); + memcpy(res->data, p, len); + return q; } -static inline int -get_key(char **p, char *end, struct crypto_tfm **res, int *resalg) +static inline const void * +get_key(const void *p, const void *end, struct crypto_tfm **res, int *resalg) { - struct xdr_netobj key = { - .len = 0, - .data = NULL, - }; + struct xdr_netobj key = { 0 }; int alg_mode,setkey = 0; char *alg_name; - if (get_bytes(p, end, resalg, sizeof(int))) + p = simple_get_bytes(p, end, resalg, sizeof(*resalg)); + if (IS_ERR(p)) goto out_err; - if ((get_netobj(p, end, &key))) + p = simple_get_netobj(p, end, &key); + if (IS_ERR(p)) goto out_err; switch (*resalg) { @@ -111,10 +110,6 @@ get_key(char **p, char *end, struct cryp alg_mode = 0; setkey = 0; break; - case NID_cast5_cbc: - dprintk("RPC: SPKM3 get_key: case cast5_cbc, UNSUPPORTED \n"); - goto out_err; - break; default: dprintk("RPC: SPKM3 get_key: unsupported algorithm %d", *resalg); goto out_err_free_key; @@ -128,69 +123,81 @@ get_key(char **p, char *end, struct cryp if(key.len > 0) kfree(key.data); - return 0; + return p; out_err_free_tfm: crypto_free_tfm(*res); out_err_free_key: if(key.len > 0) kfree(key.data); + p = ERR_PTR(-EINVAL); out_err: - return -1; + return p; } -static u32 -gss_import_sec_context_spkm3(struct xdr_netobj *inbuf, +static int +gss_import_sec_context_spkm3(const void *p, size_t len, struct gss_ctx *ctx_id) { - char *p = inbuf->data; - char *end = inbuf->data + inbuf->len; + const void *end = (const void *)((const char *)p + len); struct spkm3_ctx *ctx; if (!(ctx = kmalloc(sizeof(*ctx), GFP_KERNEL))) goto out_err; memset(ctx, 0, sizeof(*ctx)); - if (get_netobj(&p, end, &ctx->ctx_id)) + p = simple_get_netobj(p, end, &ctx->ctx_id); + if (IS_ERR(p)) goto out_err_free_ctx; - if (get_bytes(&p, end, &ctx->qop, sizeof(ctx->qop))) + p = simple_get_bytes(p, end, &ctx->qop, sizeof(ctx->qop)); + if (IS_ERR(p)) goto out_err_free_ctx_id; - if (get_netobj(&p, end, &ctx->mech_used)) + p = simple_get_netobj(p, end, &ctx->mech_used); + if (IS_ERR(p)) goto out_err_free_mech; - if (get_bytes(&p, end, &ctx->ret_flags, sizeof(ctx->ret_flags))) + p = simple_get_bytes(p, end, &ctx->ret_flags, sizeof(ctx->ret_flags)); + if (IS_ERR(p)) goto out_err_free_mech; - if (get_bytes(&p, end, &ctx->req_flags, sizeof(ctx->req_flags))) + p = simple_get_bytes(p, end, &ctx->req_flags, sizeof(ctx->req_flags)); + if (IS_ERR(p)) goto out_err_free_mech; - if (get_netobj(&p, end, &ctx->share_key)) + p = simple_get_netobj(p, end, &ctx->share_key); + if (IS_ERR(p)) goto out_err_free_s_key; - if (get_key(&p, end, &ctx->derived_conf_key, &ctx->conf_alg)) { - dprintk("RPC: SPKM3 confidentiality key will be NULL\n"); - } - - if (get_key(&p, end, &ctx->derived_integ_key, &ctx->intg_alg)) { - dprintk("RPC: SPKM3 integrity key will be NULL\n"); - } - - if (get_bytes(&p, end, &ctx->owf_alg, sizeof(ctx->owf_alg))) + p = get_key(p, end, &ctx->derived_conf_key, &ctx->conf_alg); + if (IS_ERR(p)) goto out_err_free_s_key; - if (get_bytes(&p, end, &ctx->owf_alg, sizeof(ctx->owf_alg))) - goto out_err_free_s_key; + p = get_key(p, end, &ctx->derived_integ_key, &ctx->intg_alg); + if (IS_ERR(p)) + goto out_err_free_key1; + + p = simple_get_bytes(p, end, &ctx->keyestb_alg, sizeof(ctx->keyestb_alg)); + if (IS_ERR(p)) + goto out_err_free_key2; + + p = simple_get_bytes(p, end, &ctx->owf_alg, sizeof(ctx->owf_alg)); + if (IS_ERR(p)) + goto out_err_free_key2; if (p != end) - goto out_err_free_s_key; + goto out_err_free_key2; ctx_id->internal_ctx_id = ctx; dprintk("Succesfully imported new spkm context.\n"); return 0; +out_err_free_key2: + crypto_free_tfm(ctx->derived_integ_key); +out_err_free_key1: + crypto_free_tfm(ctx->derived_conf_key); out_err_free_s_key: kfree(ctx->share_key.data); out_err_free_mech: @@ -200,7 +207,7 @@ out_err_free_ctx_id: out_err_free_ctx: kfree(ctx); out_err: - return GSS_S_FAILURE; + return PTR_ERR(p); } static void Index: linux-2.6.11-rc2/net/sunrpc/auth_gss/gss_krb5_mech.c =================================================================== --- linux-2.6.11-rc2.orig/net/sunrpc/auth_gss/gss_krb5_mech.c +++ linux-2.6.11-rc2/net/sunrpc/auth_gss/gss_krb5_mech.c @@ -48,46 +48,48 @@ # define RPCDBG_FACILITY RPCDBG_AUTH #endif -static inline int -get_bytes(char **ptr, const char *end, void *res, int len) +static const void * +simple_get_bytes(const void *p, const void *end, void *res, int len) { - char *p, *q; - p = *ptr; - q = p + len; - if (q > end || q < p) - return -1; + const void *q = (const void *)((const char *)p + len); + if (unlikely(q > end || q < p)) + return ERR_PTR(-EFAULT); memcpy(res, p, len); - *ptr = q; - return 0; + return q; } -static inline int -get_netobj(char **ptr, const char *end, struct xdr_netobj *res) +static const void * +simple_get_netobj(const void *p, const void *end, struct xdr_netobj *res) { - char *p, *q; - p = *ptr; - if (get_bytes(&p, end, &res->len, sizeof(res->len))) - return -1; - q = p + res->len; - if (q > end || q < p) - return -1; - if (!(res->data = kmalloc(res->len, GFP_KERNEL))) - return -1; - memcpy(res->data, p, res->len); - *ptr = q; - return 0; + const void *q; + unsigned int len; + + p = simple_get_bytes(p, end, &len, sizeof(len)); + if (IS_ERR(p)) + return p; + q = (const void *)((const char *)p + len); + if (unlikely(q > end || q < p)) + return ERR_PTR(-EFAULT); + res->data = kmalloc(len, GFP_KERNEL); + if (unlikely(res->data == NULL)) + return ERR_PTR(-ENOMEM); + memcpy(res->data, p, len); + res->len = len; + return q; } -static inline int -get_key(char **p, char *end, struct crypto_tfm **res) +static inline const void * +get_key(const void *p, const void *end, struct crypto_tfm **res) { struct xdr_netobj key; int alg, alg_mode; char *alg_name; - if (get_bytes(p, end, &alg, sizeof(alg))) + p = simple_get_bytes(p, end, &alg, sizeof(alg)); + if (IS_ERR(p)) goto out_err; - if ((get_netobj(p, end, &key))) + p = simple_get_netobj(p, end, &key); + if (IS_ERR(p)) goto out_err; switch (alg) { @@ -105,50 +107,63 @@ get_key(char **p, char *end, struct cryp goto out_err_free_tfm; kfree(key.data); - return 0; + return p; out_err_free_tfm: crypto_free_tfm(*res); out_err_free_key: kfree(key.data); + p = ERR_PTR(-EINVAL); out_err: - return -1; + return p; } -static u32 -gss_import_sec_context_kerberos(struct xdr_netobj *inbuf, +static int +gss_import_sec_context_kerberos(const void *p, + size_t len, struct gss_ctx *ctx_id) { - char *p = inbuf->data; - char *end = inbuf->data + inbuf->len; + const void *end = (const void *)((const char *)p + len); struct krb5_ctx *ctx; if (!(ctx = kmalloc(sizeof(*ctx), GFP_KERNEL))) goto out_err; memset(ctx, 0, sizeof(*ctx)); - if (get_bytes(&p, end, &ctx->initiate, sizeof(ctx->initiate))) + p = simple_get_bytes(p, end, &ctx->initiate, sizeof(ctx->initiate)); + if (IS_ERR(p)) goto out_err_free_ctx; - if (get_bytes(&p, end, &ctx->seed_init, sizeof(ctx->seed_init))) + p = simple_get_bytes(p, end, &ctx->seed_init, sizeof(ctx->seed_init)); + if (IS_ERR(p)) goto out_err_free_ctx; - if (get_bytes(&p, end, ctx->seed, sizeof(ctx->seed))) + p = simple_get_bytes(p, end, ctx->seed, sizeof(ctx->seed)); + if (IS_ERR(p)) goto out_err_free_ctx; - if (get_bytes(&p, end, &ctx->signalg, sizeof(ctx->signalg))) + p = simple_get_bytes(p, end, &ctx->signalg, sizeof(ctx->signalg)); + if (IS_ERR(p)) goto out_err_free_ctx; - if (get_bytes(&p, end, &ctx->sealalg, sizeof(ctx->sealalg))) + p = simple_get_bytes(p, end, &ctx->sealalg, sizeof(ctx->sealalg)); + if (IS_ERR(p)) goto out_err_free_ctx; - if (get_bytes(&p, end, &ctx->endtime, sizeof(ctx->endtime))) + p = simple_get_bytes(p, end, &ctx->endtime, sizeof(ctx->endtime)); + if (IS_ERR(p)) goto out_err_free_ctx; - if (get_bytes(&p, end, &ctx->seq_send, sizeof(ctx->seq_send))) + p = simple_get_bytes(p, end, &ctx->seq_send, sizeof(ctx->seq_send)); + if (IS_ERR(p)) goto out_err_free_ctx; - if (get_netobj(&p, end, &ctx->mech_used)) + p = simple_get_netobj(p, end, &ctx->mech_used); + if (IS_ERR(p)) goto out_err_free_ctx; - if (get_key(&p, end, &ctx->enc)) + p = get_key(p, end, &ctx->enc); + if (IS_ERR(p)) goto out_err_free_mech; - if (get_key(&p, end, &ctx->seq)) + p = get_key(p, end, &ctx->seq); + if (IS_ERR(p)) goto out_err_free_key1; - if (p != end) + if (p != end) { + p = ERR_PTR(-EFAULT); goto out_err_free_key2; + } ctx_id->internal_ctx_id = ctx; dprintk("RPC: Succesfully imported new context.\n"); @@ -163,7 +178,7 @@ out_err_free_mech: out_err_free_ctx: kfree(ctx); out_err: - return GSS_S_FAILURE; + return PTR_ERR(p); } static void Index: linux-2.6.11-rc2/net/sunrpc/auth_gss/gss_mech_switch.c =================================================================== --- linux-2.6.11-rc2.orig/net/sunrpc/auth_gss/gss_mech_switch.c +++ linux-2.6.11-rc2/net/sunrpc/auth_gss/gss_mech_switch.c @@ -233,8 +233,8 @@ EXPORT_SYMBOL(gss_mech_put); /* The mech could probably be determined from the token instead, but it's just * as easy for now to pass it in. */ -u32 -gss_import_sec_context(struct xdr_netobj *input_token, +int +gss_import_sec_context(const void *input_token, size_t bufsize, struct gss_api_mech *mech, struct gss_ctx **ctx_id) { @@ -244,7 +244,7 @@ gss_import_sec_context(struct xdr_netobj (*ctx_id)->mech_type = gss_mech_get(mech); return mech->gm_ops - ->gss_import_sec_context(input_token, *ctx_id); + ->gss_import_sec_context(input_token, bufsize, *ctx_id); } /* gss_get_mic: compute a mic over message and return mic_token. */ Index: linux-2.6.11-rc2/net/sunrpc/auth_gss/auth_gss.c =================================================================== --- linux-2.6.11-rc2.orig/net/sunrpc/auth_gss/auth_gss.c +++ linux-2.6.11-rc2/net/sunrpc/auth_gss/auth_gss.c @@ -272,7 +272,7 @@ gss_parse_init_downcall(struct gss_api_m goto err_free_wire_ctx; if (p != end) goto err_free_wire_ctx; - if (gss_import_sec_context(&tmp_buf, gm, &ctx->gc_gss_ctx)) + if (gss_import_sec_context(tmp_buf.data, tmp_buf.len, gm, &ctx->gc_gss_ctx)) goto err_free_wire_ctx; *gc = ctx; return 0; Index: linux-2.6.11-rc2/net/sunrpc/auth_gss/svcauth_gss.c =================================================================== --- linux-2.6.11-rc2.orig/net/sunrpc/auth_gss/svcauth_gss.c +++ linux-2.6.11-rc2/net/sunrpc/auth_gss/svcauth_gss.c @@ -381,7 +381,6 @@ static int rsc_parse(struct cache_detail else { int N, i; struct gss_api_mech *gm; - struct xdr_netobj tmp_buf; /* gid */ if (get_int(&mesg, &rsci.cred.cr_gid)) @@ -420,9 +419,7 @@ static int rsc_parse(struct cache_detail gss_mech_put(gm); goto out; } - tmp_buf.len = len; - tmp_buf.data = buf; - if (gss_import_sec_context(&tmp_buf, gm, &rsci.mechctx)) { + if (gss_import_sec_context(buf, len, gm, &rsci.mechctx)) { gss_mech_put(gm); goto out; }