flex_array: allow 0 length elements

flex_arrays are supposed to be a replacement for:
kmalloc(num_elements * sizeof(element))

If kmalloc is given 0 num_elements or a 0 size element it will happily return
ZERO_SIZE_PTR.  Which looks like a valid allocation, but which will explode if
something actually try to use it.  The current flex_array code will return an
equivalent result if num_elements is 0, but will fail to work if
sizeof(element) is 0.  This patch allows allocation to work even for 0 size
elements.  It will cause flex_arrays to explode though if they are used.
Imitating the kmalloc behavior.

Based-on-patch-by: Steffen Klassert <steffen.klassert@secunet.com>
Signed-off-by: Eric Paris <eparis@redhat.com>
Acked-by: Dave Hansen <dave@linux.vnet.ibm.com>
This commit is contained in:
Eric Paris 2011-04-28 15:55:52 -04:00
parent 150cdf6ec0
commit a8d05c81fb

View file

@ -88,8 +88,11 @@ struct flex_array *flex_array_alloc(int element_size, unsigned int total,
gfp_t flags) gfp_t flags)
{ {
struct flex_array *ret; struct flex_array *ret;
int max_size = FLEX_ARRAY_NR_BASE_PTRS * int max_size = 0;
FLEX_ARRAY_ELEMENTS_PER_PART(element_size);
if (element_size)
max_size = FLEX_ARRAY_NR_BASE_PTRS *
FLEX_ARRAY_ELEMENTS_PER_PART(element_size);
/* max_size will end up 0 if element_size > PAGE_SIZE */ /* max_size will end up 0 if element_size > PAGE_SIZE */
if (total > max_size) if (total > max_size)
@ -183,15 +186,18 @@ __fa_get_part(struct flex_array *fa, int part_nr, gfp_t flags)
int flex_array_put(struct flex_array *fa, unsigned int element_nr, void *src, int flex_array_put(struct flex_array *fa, unsigned int element_nr, void *src,
gfp_t flags) gfp_t flags)
{ {
int part_nr = fa_element_to_part_nr(fa, element_nr); int part_nr;
struct flex_array_part *part; struct flex_array_part *part;
void *dst; void *dst;
if (element_nr >= fa->total_nr_elements) if (element_nr >= fa->total_nr_elements)
return -ENOSPC; return -ENOSPC;
if (!fa->element_size)
return 0;
if (elements_fit_in_base(fa)) if (elements_fit_in_base(fa))
part = (struct flex_array_part *)&fa->parts[0]; part = (struct flex_array_part *)&fa->parts[0];
else { else {
part_nr = fa_element_to_part_nr(fa, element_nr);
part = __fa_get_part(fa, part_nr, flags); part = __fa_get_part(fa, part_nr, flags);
if (!part) if (!part)
return -ENOMEM; return -ENOMEM;
@ -211,15 +217,18 @@ EXPORT_SYMBOL(flex_array_put);
*/ */
int flex_array_clear(struct flex_array *fa, unsigned int element_nr) int flex_array_clear(struct flex_array *fa, unsigned int element_nr)
{ {
int part_nr = fa_element_to_part_nr(fa, element_nr); int part_nr;
struct flex_array_part *part; struct flex_array_part *part;
void *dst; void *dst;
if (element_nr >= fa->total_nr_elements) if (element_nr >= fa->total_nr_elements)
return -ENOSPC; return -ENOSPC;
if (!fa->element_size)
return 0;
if (elements_fit_in_base(fa)) if (elements_fit_in_base(fa))
part = (struct flex_array_part *)&fa->parts[0]; part = (struct flex_array_part *)&fa->parts[0];
else { else {
part_nr = fa_element_to_part_nr(fa, element_nr);
part = fa->parts[part_nr]; part = fa->parts[part_nr];
if (!part) if (!part)
return -EINVAL; return -EINVAL;
@ -264,6 +273,8 @@ int flex_array_prealloc(struct flex_array *fa, unsigned int start,
if (end >= fa->total_nr_elements) if (end >= fa->total_nr_elements)
return -ENOSPC; return -ENOSPC;
if (!fa->element_size)
return 0;
if (elements_fit_in_base(fa)) if (elements_fit_in_base(fa))
return 0; return 0;
start_part = fa_element_to_part_nr(fa, start); start_part = fa_element_to_part_nr(fa, start);
@ -291,14 +302,17 @@ EXPORT_SYMBOL(flex_array_prealloc);
*/ */
void *flex_array_get(struct flex_array *fa, unsigned int element_nr) void *flex_array_get(struct flex_array *fa, unsigned int element_nr)
{ {
int part_nr = fa_element_to_part_nr(fa, element_nr); int part_nr;
struct flex_array_part *part; struct flex_array_part *part;
if (!fa->element_size)
return NULL;
if (element_nr >= fa->total_nr_elements) if (element_nr >= fa->total_nr_elements)
return NULL; return NULL;
if (elements_fit_in_base(fa)) if (elements_fit_in_base(fa))
part = (struct flex_array_part *)&fa->parts[0]; part = (struct flex_array_part *)&fa->parts[0];
else { else {
part_nr = fa_element_to_part_nr(fa, element_nr);
part = fa->parts[part_nr]; part = fa->parts[part_nr];
if (!part) if (!part)
return NULL; return NULL;
@ -353,7 +367,7 @@ int flex_array_shrink(struct flex_array *fa)
int part_nr; int part_nr;
int ret = 0; int ret = 0;
if (!fa->total_nr_elements) if (!fa->total_nr_elements || !fa->element_size)
return 0; return 0;
if (elements_fit_in_base(fa)) if (elements_fit_in_base(fa))
return ret; return ret;