#include <limits.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <netcdf.h>
#include <udunits.h>
#include <uderrmsg.h>

#ifndef UNITS_NAME
#define UNITS_NAME	"units"	/* netCDF name of units attribute */
#endif

#ifndef FILL_NAME
#define FILL_NAME	"_FillValue"
#endif

/*
 * Initialize the udunits library by reading in its units table.
 *
 * The table is read at most once per process: the first call records the
 * status returned by utInit(), and every later call returns that same status
 * without retrying.  Returns 0 on success, non-zero on failure.
 *
 * If the environment variable UDUNITS_TABLE is set, its value is passed to
 * utInit() as the path of the units table; otherwise utInit() receives NULL.
 * NOTE(review): presumably utInit(NULL) falls back to the table compiled into
 * the installed udunits library -- confirm against the udunits documentation.
 */
static int
init_udunits(void)
{
    static int first = 1;
    static int status = 0;	/* utInit() result from the first call */

    if (first) {
	(void) uderrmode(UD_VERBOSE); /* turn off error fatality */
	status = utInit(getenv("UDUNITS_TABLE"));
	first = 0;
    }
    return status;
}

/*
  This routine calculates the slope and intercept necessary to convert from 
  the units of the given netCDF variable to the specified units.  If units
  is blank, the netCDF units are returned and the slope is set to one and 
  the intercept to 0.
  */

int
nc_units(ncid, varid, units, slope, intercept)

int ncid, varid;
char *units;
double *slope, *intercept;
{
    int ulen;
    char nc_units[100];
    utUnit funits, tunits;

    *slope = 1.0; *intercept = 0.0;
    if (units == NULL) return(0);
    if (ncattinq(ncid, varid, UNITS_NAME, NULL, &ulen) == -1 ||
	ncattget(ncid, varid, UNITS_NAME, nc_units) == -1) {
	fprintf(stderr, "Cannot get data units\n");
	return(1);
    }
    nc_units[ulen] = '\0';
/*
  Prepare for possible units conversion.
  */
    if (units[0]) {
/*
  Prepare udunits.
  */
	if (init_udunits()) {
	    fprintf(stderr, "Cannot initialize udunits library\n");
	    return(1);
	}
	if (utScan(units, &tunits) == 0 
	    && utScan(nc_units, &funits) == 0) 
	    utConvert(&funits, &tunits, slope, intercept);
    }
    else if (units) {
/*
  No units specified - return units in the netCDF file.
  */
	strcpy(units, nc_units);
    }
    return(0);
}

/*
  nc_float: read a hyperslab of a netCDF variable with ncvarget and return
  it as floats.

    ncid, varid  netCDF file id and variable id.
    corn, edge   start indices and edge lengths of the hyperslab, one entry
                 per variable dimension (as for ncvarget).
    data         output array; must hold the product of the edge lengths.
    slope,
    intercept    linear units conversion (e.g. from nc_units): every element
                 not equal to `missing` becomes slope * value + intercept.
    missing      stored in place of elements equal to the variable's
                 _FillValue.  The fill value is only used when that attribute
                 exists and is a single value of the variable's own type.

  Returns 0 on success, 1 on failure (a message is printed to stderr).

  The old-style (K&R) definition is deliberate: it makes the float parameter
  `missing` arrive as a promoted double, so rewriting it as a prototype would
  break callers that call this function without a prototype in scope.

  NOTE(review): NC_LONG data is handled as C `long`.  Some netCDF versions
  exchange NC_LONG as 32-bit `nclong`/`int`, which differs from `long` on
  LP64 platforms -- confirm against the installed netcdf.h.
*/

int
nc_float(ncid, varid, corn, edge, data, slope, intercept, missing)

int ncid, varid;
long *corn, *edge;
float *data, missing;
double slope, intercept;

{
    int ndims, i, attlen, have_fill = 0;
    nc_type datatype, atttype;
    size_t count = 1, elsize, k;
    ncvoid *dd;
    union {			/* one value of any netCDF external type */
	char c;
	signed char b;
	short s;
	long l;
	float f;
	double d;
    } fill;

    if (ncvarinq(ncid, varid, NULL, &datatype, &ndims, NULL, NULL) == -1) {
	fprintf(stderr, "Cannot get data type\n");
	return(1);
    }
    for (i = 0; i < ndims; ++i) count *= (size_t) edge[i];

    switch (datatype) {
    case NC_CHAR:	elsize = sizeof(char);		break;
    case NC_BYTE:	elsize = sizeof(signed char);	break;
    case NC_SHORT:	elsize = sizeof(short);		break;
    case NC_LONG:	elsize = sizeof(long);		break;
    case NC_FLOAT:	elsize = sizeof(float);		break;
    case NC_DOUBLE:	elsize = sizeof(double);	break;
    default:
	fprintf(stderr, "Unsupported data type\n");
	return(1);
    }

/*
  Float data is read straight into the caller's array; any other type goes
  through a temporary buffer in the file's type.
  */
    if (datatype == NC_FLOAT)
	dd = (ncvoid *) data;
    else if (count > ((size_t) -1) / elsize ||
	     (dd = malloc(count ? count * elsize : 1)) == NULL) {
	fprintf(stderr, "out of memory\n");
	return(1);
    }

    if (ncvarget(ncid, varid, corn, edge, dd) == -1) {
	if (dd != (ncvoid *) data) free(dd);
	fprintf(stderr, "ncvarget failed\n");
	return(1);
    }

/*
  Handle missing data (if present).  Only a single fill value of the
  variable's own type is accepted; reading anything else into `fill` could
  overflow it or yield meaningless bytes.
  */
    if (ncattinq(ncid, varid, FILL_NAME, &atttype, &attlen) != -1 &&
	atttype == datatype && attlen == 1 &&
	ncattget(ncid, varid, FILL_NAME, (ncvoid *) &fill) != -1)
	have_fill = 1;

    switch (datatype) {
    case NC_CHAR: {
	char *src = (char *) dd;
	for (k = 0; k < count; ++k)
	    data[k] = (have_fill && src[k] == fill.c) ? missing : src[k];
	break;
    }
    case NC_BYTE: {
	/* netCDF bytes are signed, whatever the signedness of plain char. */
	signed char *src = (signed char *) dd;
	for (k = 0; k < count; ++k)
	    data[k] = (have_fill && src[k] == fill.b) ? missing : src[k];
	break;
    }
    case NC_SHORT: {
	short *src = (short *) dd;
	for (k = 0; k < count; ++k)
	    data[k] = (have_fill && src[k] == fill.s) ? missing : src[k];
	break;
    }
    case NC_LONG: {
	long *src = (long *) dd;
	for (k = 0; k < count; ++k)
	    data[k] = (have_fill && src[k] == fill.l) ? missing : src[k];
	break;
    }
    case NC_FLOAT:		/* data already in place; just swap fill values */
	if (have_fill)
	    for (k = 0; k < count; ++k)
		if (data[k] == fill.f) data[k] = missing;
	break;
    case NC_DOUBLE: {
	double *src = (double *) dd;
	for (k = 0; k < count; ++k)
	    data[k] = (have_fill && src[k] == fill.d) ? missing : src[k];
	break;
    }
    default:
	break;
    }

    if (dd != (ncvoid *) data) free(dd);

/*
  Apply the units conversion to every non-missing element.
  */
    if (slope != 1.0 || intercept != 0.0)
	for (k = 0; k < count; ++k)
	    if (data[k] != missing)
		data[k] = slope * data[k] + intercept;
    return(0);
}
/*
  This routine is an interface to ncvarput that convert the data to
  the netCDF from floating point with the given missing value and units.
*/

int
float_nc(ncid, varid, corn, edge, data, slope, intercept, missing)

int ncid, varid;
long *corn, *edge;
float *data, missing;
double slope, intercept;

{
    int ndims, i, prod=1, rangerr=0, check_miss=0;
    nc_type datatype;
    ncvoid *dd;
    char *c, *cf;
    short *s, *sf;
    long *l, *lf;
    float *f, *ff, *datac;
    double *d, miss;
    
    if (ncvarinq(ncid, varid, NULL, &datatype, &ndims, NULL, NULL) == -1) {
	fprintf(stderr, "Cannot get data type\n");
	return(1);
    }
    for (i=0; i < ndims; ++i) prod *= edge[i];
/*
  Convert the units.
  */
    if (slope == 1.0 && intercept == 0.0) datac = data;
    else {
	datac = (float *) malloc(prod * sizeof(float));
	if (slope != 1.0) {
	    if (intercept != 0.) {
		for(i=0; i<prod; i++) {
		    if (data[i] == missing) datac[i] = data[i];
		    else datac[i] = slope * data[i] + intercept;
		}
	    }
	    else
		for(i=0; i<prod; i++) {
		    if (data[i] == missing) datac[i] = data[i];
		    else datac[i] = data[i] * slope;
	    }
	}
	else {
	    for (i=0; i < prod; i++) {
		if (data[i] == missing) datac[i] = data[i];
		else datac[i] = data[i] + intercept;
	    }
	}
    }
	
/*
  Create temporary array and store
  */
    if (ncattget(ncid, varid, FILL_NAME, &miss) == -1) check_miss++;
    switch (datatype) {

    case NC_CHAR: 
    case NC_BYTE:
	cf = (char *) &miss;
	if ((c=(char *) malloc(prod*sizeof(char)))==NULL) {
	    fprintf(stderr, "out of memory\n");
	    return(1);
	}
	for (i=0; i < prod; ++i) {
	    if (check_miss && datac[i] == missing) c[i] = *cf;
	    else if (data[i] < CHAR_MIN || data[i] > CHAR_MAX) {
		rangerr = 1;
		c[i] = *cf;
	    }
	    else c[i] = nint(datac[i]);
	}
	dd = (ncvoid *) c;
	break;
    case NC_SHORT:
	sf = (short *) &miss;
	if ((s=(short *) malloc(prod*sizeof(short)))==NULL) {
	    fprintf(stderr, "out of memory\n");
	    return(1);
	}
	for (i=0; i < prod; ++i) {
	    if (check_miss && datac[i] == missing) s[i] = *sf;
	    else if (data[i] < SHRT_MIN || data[i] > SHRT_MAX) {
		rangerr = 1;
		s[i] = *sf;
	    }
	    else s[i] = nint(datac[i]);
	}
	dd = (ncvoid *) s;
	break;
    case NC_LONG:
	lf = (long *) &miss;
	if ((l=(long *) malloc(prod*sizeof(long)))==NULL) {
	    fprintf(stderr, "out of memory\n");
	    return(1);
	}
	for (i=0; i < prod; ++i) {
	    if (check_miss && datac[i] == missing) l[i] = *lf;
	    else if (data[i] < LONG_MIN || data[i] > LONG_MAX) {
		rangerr = 1;
		l[i] = *lf;
	    }
	    else l[i] = nint(datac[i]);
	}
	dd = (ncvoid *) l;
	break;
    case NC_FLOAT:
	ff = (float *) &miss;
	if (!check_miss || missing == *ff) {
	    dd = (ncvoid *) datac;
	    break;
	}
	if ((f=(float *) malloc(prod*sizeof(float)))==NULL) {
	    fprintf(stderr, "out of memory\n");
	    return(1);
	}
	for (i=0; i < prod; ++i) {
	    if (datac[i] == missing) f[i] = *ff;
	    else f[i] = datac[i];
	}
	dd = (ncvoid *) f;
	break;
    case NC_DOUBLE:
	if ((d=(double *) malloc(prod*sizeof(double)))==NULL) {
	    fprintf(stderr, "out of memory\n");
	    return(1);
	}
	for (i=0; i < prod; ++i) {
	    if (check_miss && datac[i] == missing) d[i] = miss;
	    else d[i] = datac[i];
	}
	dd = (ncvoid *) d;
	break;
    }
    if (data!=datac && (ncvoid *)datac!=dd) free(datac);
    if (rangerr)
	fprintf(stderr, "Range error - out of bounds data set to missing\n");
    i = ncvarput(ncid, varid, corn, edge, dd);
    if (dd != (ncvoid *) data) free(dd);
    if (i == -1) {
	fprintf(stderr, "ncvarput failed\n");
	return(1);
    }
    return(0);
}

