public class GeneralizedLinearRegression extends Predictor<Vector,GeneralizedLinearRegression,GeneralizedLinearRegressionModel> implements DefaultParamsWritable, Logging
Fit a Generalized Linear Model (see Generalized linear model (Wikipedia)) specified by giving a symbolic description of the linear predictor (link function) and a description of the error distribution (family). It supports "gaussian", "binomial", "poisson", "gamma" and "tweedie" as the family. Valid link functions for each family are listed below; the first link function of each family is the default one.
- "gaussian" : "identity", "log", "inverse"
- "binomial" : "logit", "probit", "cloglog"
- "poisson" : "log", "identity", "sqrt"
- "gamma" : "inverse", "identity", "log"
- "tweedie" : power link function specified through "linkPower". The default link power in the tweedie family is 1 - variancePower.
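The family and link list can be encoded as a simple lookup. A minimal sketch in plain Java, assuming the lists exactly as given above; the class `GlmFamilyLinks` and its methods are hypothetical illustrations, not part of Spark's API. "tweedie" is deliberately absent because its link is chosen through linkPower rather than by name.

```java
import java.util.List;
import java.util.Map;

/** Hypothetical lookup mirroring the family/link list; not Spark's implementation. */
final class GlmFamilyLinks {
    // The first entry of each list is the default link for that family.
    static final Map<String, List<String>> VALID_LINKS = Map.of(
        "gaussian", List.of("identity", "log", "inverse"),
        "binomial", List.of("logit", "probit", "cloglog"),
        "poisson", List.of("log", "identity", "sqrt"),
        "gamma", List.of("inverse", "identity", "log"));

    /** Returns the default link for a family that uses named links (not "tweedie"). */
    static String defaultLink(String family) {
        List<String> links = VALID_LINKS.get(family);
        if (links == null) {
            throw new IllegalArgumentException("No named links for family: " + family);
        }
        return links.get(0);
    }

    /** True if the family/link pair appears in the list. */
    static boolean isValid(String family, String link) {
        return VALID_LINKS.getOrDefault(family, List.of()).contains(link);
    }

    public static void main(String[] args) {
        System.out.println(defaultLink("gamma"));        // inverse
        System.out.println(isValid("binomial", "log"));  // false
    }
}
```

With the real estimator the same choice is expressed through the documented setters, e.g. `new GeneralizedLinearRegression().setFamily("poisson").setLink("sqrt")`.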
Modifier and Type | Class | Description |
---|---|---|
static class | GeneralizedLinearRegression.Binomial$ | Binomial exponential family distribution. |
static class | GeneralizedLinearRegression.CLogLog$ | |
static class | GeneralizedLinearRegression.Family$ | |
static class | GeneralizedLinearRegression.FamilyAndLink$ | |
static class | GeneralizedLinearRegression.Gamma$ | Gamma exponential family distribution. |
static class | GeneralizedLinearRegression.Gaussian$ | Gaussian exponential family distribution. |
static class | GeneralizedLinearRegression.Identity$ | |
static class | GeneralizedLinearRegression.Inverse$ | |
static class | GeneralizedLinearRegression.Link$ | |
static class | GeneralizedLinearRegression.Log$ | |
static class | GeneralizedLinearRegression.Logit$ | |
static class | GeneralizedLinearRegression.Poisson$ | Poisson exponential family distribution. |
static class | GeneralizedLinearRegression.Probit$ | |
static class | GeneralizedLinearRegression.Sqrt$ | |
static class | GeneralizedLinearRegression.Tweedie$ | |
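The link classes listed above (Identity$, Log$, Inverse$, Logit$, Probit$, CLogLog$, Sqrt$) correspond to standard GLM link functions g, which map the mean μ to the linear predictor η. The sketch below is a hypothetical `SimpleLink` enum, not Spark's classes, showing each formula with its inverse; Probit is left out because the Java standard library has no normal CDF to invert it with.

```java
import java.util.function.DoubleUnaryOperator;

/** Hypothetical link/inverse-link pairs illustrating the math; not Spark's classes. */
enum SimpleLink {
    IDENTITY(mu -> mu, eta -> eta),
    LOG(Math::log, Math::exp),
    INVERSE(mu -> 1.0 / mu, eta -> 1.0 / eta),
    LOGIT(mu -> Math.log(mu / (1.0 - mu)), eta -> 1.0 / (1.0 + Math.exp(-eta))),
    CLOGLOG(mu -> Math.log(-Math.log(1.0 - mu)), eta -> 1.0 - Math.exp(-Math.exp(eta))),
    SQRT(Math::sqrt, eta -> eta * eta);

    private final DoubleUnaryOperator g;     // link: mean -> linear predictor
    private final DoubleUnaryOperator gInv;  // inverse link: linear predictor -> mean

    SimpleLink(DoubleUnaryOperator g, DoubleUnaryOperator gInv) {
        this.g = g;
        this.gInv = gInv;
    }

    double link(double mu) { return g.applyAsDouble(mu); }

    double unlink(double eta) { return gInv.applyAsDouble(eta); }

    public static void main(String[] args) {
        // Round trip: unlink(link(mu)) should recover mu for every link.
        for (SimpleLink l : values()) {
            System.out.println(l + ": " + l.unlink(l.link(0.3)));
        }
    }
}
```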
Constructor and Description |
---|
GeneralizedLinearRegression() |
GeneralizedLinearRegression(String uid) |
Modifier and Type | Method | Description |
---|---|---|
static Params | clear(Param<?> param) | |
GeneralizedLinearRegression | copy(ParamMap extra) | Creates a copy of this instance with the same UID and some extra params. |
static String | explainParam(Param<?> param) | |
static String | explainParams() | |
static ParamMap | extractParamMap() | |
static ParamMap | extractParamMap(ParamMap extra) | |
static Param<String> | family() | |
Param<String> | family() | Param for the name of family which is a description of the error distribution to be used in the model. |
static Param<String> | featuresCol() | |
Param<String> | featuresCol() | Param for features column name. |
static M | fit(Dataset<?> dataset) | |
static M | fit(Dataset<?> dataset, ParamMap paramMap) | |
static scala.collection.Seq<M> | fit(Dataset<?> dataset, ParamMap[] paramMaps) | |
static M | fit(Dataset<?> dataset, ParamPair<?> firstParamPair, ParamPair<?>... otherParamPairs) | |
static M | fit(Dataset<?> dataset, ParamPair<?> firstParamPair, scala.collection.Seq<ParamPair<?>> otherParamPairs) | |
static BooleanParam | fitIntercept() | |
static <T> scala.Option<T> | get(Param<T> param) | |
static <T> scala.Option<T> | getDefault(Param<T> param) | |
static String | getFamily() | |
String | getFamily() | |
static String | getFeaturesCol() | |
String | getFeaturesCol() | |
static boolean | getFitIntercept() | |
static String | getLabelCol() | |
String | getLabelCol() | |
static String | getLink() | |
String | getLink() | |
static double | getLinkPower() | |
double | getLinkPower() | |
static String | getLinkPredictionCol() | |
String | getLinkPredictionCol() | |
static int | getMaxIter() | |
static <T> T | getOrDefault(Param<T> param) | |
static Param<Object> | getParam(String paramName) | |
static String | getPredictionCol() | |
String | getPredictionCol() | |
static double | getRegParam() | |
static String | getSolver() | |
static double | getTol() | |
static double | getVariancePower() | |
double | getVariancePower() | |
static String | getWeightCol() | |
static <T> boolean | hasDefault(Param<T> param) | |
boolean | hasLinkPredictionCol() | Checks whether we should output link prediction. |
static boolean | hasParam(String paramName) | |
static boolean | isDefined(Param<?> param) | |
static boolean | isSet(Param<?> param) | |
static Param<String> | labelCol() | |
Param<String> | labelCol() | Param for label column name. |
static Param<String> | link() | |
Param<String> | link() | Param for the name of link function which provides the relationship between the linear predictor and the mean of the distribution function. |
static DoubleParam | linkPower() | |
DoubleParam | linkPower() | Param for the index in the power link function. |
static Param<String> | linkPredictionCol() | |
Param<String> | linkPredictionCol() | Param for link prediction (linear predictor) column name. |
static GeneralizedLinearRegression | load(String path) | |
static IntParam | maxIter() | |
static Param<?>[] | params() | |
static Param<String> | predictionCol() | |
Param<String> | predictionCol() | Param for prediction column name. |
static DoubleParam | regParam() | |
static void | save(String path) | |
static <T> Params | set(Param<T> param, T value) | |
GeneralizedLinearRegression | setFamily(String value) | Sets the value of param family. |
static Learner | setFeaturesCol(String value) | |
GeneralizedLinearRegression | setFitIntercept(boolean value) | Sets if we should fit the intercept. |
static Learner | setLabelCol(String value) | |
GeneralizedLinearRegression | setLink(String value) | Sets the value of param link. |
GeneralizedLinearRegression | setLinkPower(double value) | Sets the value of param linkPower. |
GeneralizedLinearRegression | setLinkPredictionCol(String value) | Sets the link prediction (linear predictor) column name. |
GeneralizedLinearRegression | setMaxIter(int value) | Sets the maximum number of iterations (applicable for solver "irls"). |
static Learner | setPredictionCol(String value) | |
GeneralizedLinearRegression | setRegParam(double value) | Sets the regularization parameter for L2 regularization. |
GeneralizedLinearRegression | setSolver(String value) | Sets the solver algorithm used for optimization. |
GeneralizedLinearRegression | setTol(double value) | Sets the convergence tolerance of iterations. |
GeneralizedLinearRegression | setVariancePower(double value) | Sets the value of param variancePower. |
GeneralizedLinearRegression | setWeightCol(String value) | Sets the value of param weightCol. |
static Param<String> | solver() | |
static DoubleParam | tol() | |
static String | toString() | |
static StructType | transformSchema(StructType schema) | |
String | uid() | An immutable unique ID for the object and its derivatives. |
static StructType | validateAndTransformSchema(StructType schema, boolean fitting, DataType featuresDataType) | |
StructType | validateAndTransformSchema(StructType schema, boolean fitting, DataType featuresDataType) | |
StructType | validateAndTransformSchema(StructType schema, boolean fitting, DataType featuresDataType) | Validates and transforms the input schema with the provided param map. |
static DoubleParam | variancePower() | |
DoubleParam | variancePower() | Param for the power in the variance function of the Tweedie distribution which provides the relationship between the variance and mean of the distribution. |
static Param<String> | weightCol() | |
static MLWriter | write() | |
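Methods inherited from class org.apache.spark.ml.Predictor
fit, setFeaturesCol, setLabelCol, setPredictionCol, transformSchema
Methods inherited from class java.lang.Object
equals, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
Methods inherited from interface org.apache.spark.internal.Logging
initializeLogging, initializeLogIfNecessary, isTraceEnabled, log_, log, logDebug, logDebug, logError, logError, logInfo, logInfo, logName, logTrace, logTrace, logWarning, logWarning
Methods inherited from interface org.apache.spark.ml.util.DefaultParamsWritable
write
Methods inherited from interface org.apache.spark.ml.util.MLWritable
save
Methods inherited from interface org.apache.spark.ml.param.Params
clear, copyValues, defaultCopy, defaultParamMap, explainParam, explainParams, extractParamMap, extractParamMap, get, getDefault, getOrDefault, getParam, hasDefault, hasParam, isDefined, isSet, paramMap, params, set, set, set, setDefault, setDefault, shouldOwn
Methods inherited from interface org.apache.spark.ml.util.Identifiable
toString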
public GeneralizedLinearRegression(String uid)
public GeneralizedLinearRegression()
public static GeneralizedLinearRegression load(String path)
public static String toString()
public static Param<?>[] params()
public static String explainParam(Param<?> param)
public static String explainParams()
public static final boolean isSet(Param<?> param)
public static final boolean isDefined(Param<?> param)
public static boolean hasParam(String paramName)
public static Param<Object> getParam(String paramName)
public static final <T> scala.Option<T> get(Param<T> param)
public static final <T> T getOrDefault(Param<T> param)
public static final <T> scala.Option<T> getDefault(Param<T> param)
public static final <T> boolean hasDefault(Param<T> param)
public static final ParamMap extractParamMap()
public static M fit(Dataset<?> dataset, ParamPair<?> firstParamPair, scala.collection.Seq<ParamPair<?>> otherParamPairs)
public static M fit(Dataset<?> dataset, ParamPair<?> firstParamPair, ParamPair<?>... otherParamPairs)
public static final Param<String> labelCol()
public static final String getLabelCol()
public static final Param<String> featuresCol()
public static final String getFeaturesCol()
public static final Param<String> predictionCol()
public static final String getPredictionCol()
public static Learner setLabelCol(String value)
public static Learner setFeaturesCol(String value)
public static Learner setPredictionCol(String value)
public static M fit(Dataset<?> dataset)
public static StructType transformSchema(StructType schema)
public static final BooleanParam fitIntercept()
public static final boolean getFitIntercept()
public static final IntParam maxIter()
public static final int getMaxIter()
public static final DoubleParam tol()
public static final double getTol()
public static final DoubleParam regParam()
public static final double getRegParam()
public static final Param<String> weightCol()
public static final String getWeightCol()
public static final Param<String> solver()
public static final String getSolver()
public static final Param<String> family()
public static String getFamily()
public static final DoubleParam variancePower()
public static double getVariancePower()
public static final Param<String> link()
public static String getLink()
public static final DoubleParam linkPower()
public static double getLinkPower()
public static final Param<String> linkPredictionCol()
public static String getLinkPredictionCol()
public static StructType validateAndTransformSchema(StructType schema, boolean fitting, DataType featuresDataType)
public static void save(String path) throws java.io.IOException
Throws: java.io.IOException
public static MLWriter write()
public String uid()
An immutable unique ID for the object and its derivatives.
Specified by: uid in interface Identifiable
public GeneralizedLinearRegression setFamily(String value)
Sets the value of param family.
Default is "gaussian".
Parameters: value - (undocumented)
public GeneralizedLinearRegression setVariancePower(double value)
Sets the value of param variancePower.
Used only when family is "tweedie".
Default is 0.0, which corresponds to the "gaussian" family.
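In the Tweedie family the variance function is V(μ) = μ^p, with p given by variancePower. Setting p to 0, 1 or 2 reproduces the Gaussian (constant variance), Poisson (V = μ) and Gamma (V = μ²) variance functions, which is why a default of 0.0 corresponds to "gaussian". A small sketch using a hypothetical `TweedieVariance` helper, not a Spark class:

```java
/** Hypothetical helper: Tweedie variance function V(mu) = mu^p, where p is variancePower. */
final class TweedieVariance {
    static double variance(double mu, double variancePower) {
        return Math.pow(mu, variancePower);
    }

    public static void main(String[] args) {
        double mu = 3.0;
        System.out.println(variance(mu, 0.0)); // 1.0: constant variance, as in the Gaussian family
        System.out.println(variance(mu, 1.0)); // 3.0: variance equals the mean, as in Poisson
        System.out.println(variance(mu, 2.0)); // 9.0: variance grows as mu^2, as in Gamma
    }
}
```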
Parameters: value - (undocumented)
public GeneralizedLinearRegression setLinkPower(double value)
Sets the value of param linkPower.
Used only when family is "tweedie".
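With family "tweedie" the link is a power link g(μ) = μ^q with q given by linkPower, and when linkPower is not set it defaults to 1 - variancePower. The sketch below uses a hypothetical `TweediePowerLink` class and assumes the common convention that q = 0 denotes the log link (the limit of (μ^q - 1)/q as q approaches 0); it illustrates the rule and is not Spark's code.

```java
/** Hypothetical sketch of the Tweedie power link; not Spark's implementation. */
final class TweediePowerLink {
    /** Effective link power: the explicit value if given (non-null), else 1 - variancePower. */
    static double effectiveLinkPower(Double linkPower, double variancePower) {
        return (linkPower != null) ? linkPower : 1.0 - variancePower;
    }

    /** g(mu) = mu^q, treating q == 0 as the log link (assumed convention). */
    static double link(double mu, double q) {
        return q == 0.0 ? Math.log(mu) : Math.pow(mu, q);
    }

    /** Inverse link: eta^(1/q), or exp(eta) when q == 0. */
    static double unlink(double eta, double q) {
        return q == 0.0 ? Math.exp(eta) : Math.pow(eta, 1.0 / q);
    }

    public static void main(String[] args) {
        double q = effectiveLinkPower(null, 1.0);      // variancePower 1.0, linkPower unset
        System.out.println(q);                         // 0.0, i.e. the log link
        System.out.println(unlink(link(5.0, q), q));   // approximately 5.0
    }
}
```

For instance, `setFamily("tweedie").setVariancePower(1.0)` with linkPower left unset would, by the 1 - variancePower rule, use a link power of 0.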
Parameters: value - (undocumented)
public GeneralizedLinearRegression setLink(String value)
Sets the value of param link.
Used only when family is not "tweedie".
Parameters: value - (undocumented)
public GeneralizedLinearRegression setFitIntercept(boolean value)
Sets if we should fit the intercept.
Parameters: value - (undocumented)
public GeneralizedLinearRegression setMaxIter(int value)
Sets the maximum number of iterations (applicable for solver "irls").
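The "irls" solver refers to iteratively reweighted least squares: each iteration forms a working response z = η + (y - μ)·dη/dμ and weights w = (dμ/dη)²/V(μ), then solves a weighted least-squares problem, stopping after maxIter iterations or once the coefficient change falls below tol. A minimal sketch for the Poisson family with its default log link, one feature and an intercept; `PoissonIrls` is hypothetical, ignores regParam and instance weights, and is not Spark's implementation.

```java
/** Hypothetical IRLS sketch: Poisson family, log link, one feature plus intercept. */
final class PoissonIrls {
    /** Returns {intercept, slope}. */
    static double[] fit(double[] x, double[] y, int maxIter, double tol) {
        double meanY = 0.0;
        for (double v : y) meanY += v;
        meanY /= y.length;
        double b0 = Math.log(meanY), b1 = 0.0;   // start from the intercept-only fit

        for (int iter = 0; iter < maxIter; iter++) {
            // Accumulate the weighted normal equations for regressing z on (1, x).
            double sw = 0, swx = 0, swxx = 0, swz = 0, swxz = 0;
            for (int i = 0; i < x.length; i++) {
                double eta = b0 + b1 * x[i];
                double mu = Math.exp(eta);           // inverse of the log link
                double z = eta + (y[i] - mu) / mu;   // working response; d eta / d mu = 1 / mu
                double w = mu;                       // (d mu / d eta)^2 / V(mu) = mu^2 / mu
                sw += w;
                swx += w * x[i];
                swxx += w * x[i] * x[i];
                swz += w * z;
                swxz += w * x[i] * z;
            }
            // Solve the 2x2 system with Cramer's rule.
            double det = sw * swxx - swx * swx;
            double nb0 = (swz * swxx - swx * swxz) / det;
            double nb1 = (sw * swxz - swx * swz) / det;
            boolean converged = Math.abs(nb0 - b0) < tol && Math.abs(nb1 - b1) < tol;
            b0 = nb0;
            b1 = nb1;
            if (converged) break;
        }
        return new double[] {b0, b1};
    }

    public static void main(String[] args) {
        double[] x = new double[10];
        double[] y = new double[10];
        for (int i = 0; i < 10; i++) {
            x[i] = i;
            y[i] = Math.exp(0.5 + 0.3 * i);   // noise-free means, so the MLE is (0.5, 0.3)
        }
        double[] b = fit(x, y, 25, 1e-10);
        System.out.printf("intercept=%.6f slope=%.6f%n", b[0], b[1]);
    }
}
```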
Parameters: value - (undocumented)
public GeneralizedLinearRegression setTol(double value)
Sets the convergence tolerance of iterations.
Parameters: value - (undocumented)
public GeneralizedLinearRegression setRegParam(double value)
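Sets the regularization parameter for L2 regularization. The regularization term is
$$ 0.5 \cdot \mathrm{regParam} \cdot \lVert \mathrm{coefficients} \rVert_2^2 $$
Default is 0.0.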
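As a worked example of the penalty term, coefficients (3, 4) have squared L2 norm 25, so regParam = 0.1 gives 0.5 · 0.1 · 25 = 1.25. The hypothetical helper below computes the same quantity; it simply penalizes whatever coefficient vector it is given and is not a Spark method.

```java
/** Hypothetical helper computing the L2 penalty 0.5 * regParam * ||coefficients||_2^2. */
final class L2Penalty {
    static double penalty(double regParam, double[] coefficients) {
        double sumSq = 0.0;
        for (double c : coefficients) {
            sumSq += c * c;  // squared L2 norm accumulates c^2
        }
        return 0.5 * regParam * sumSq;
    }

    public static void main(String[] args) {
        // ||(3, 4)||^2 = 25, so 0.5 * 0.1 * 25 = 1.25
        System.out.println(penalty(0.1, new double[] {3.0, 4.0}));
    }
}
```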
Parameters: value - (undocumented)
public GeneralizedLinearRegression setWeightCol(String value)
Sets the value of param weightCol.
If this is not set or empty, all instance weights are treated as 1.0.
Default is not set, so all instances have weight one.
In the Binomial family, weights correspond to the number of trials and should be integers.
Non-integer weights are rounded to integers in the AIC calculation.
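To illustrate the rounding rule, the sketch below computes one row's binomial log-likelihood, the quantity summed into an AIC (AIC = -2·Σ log-likelihood + 2·number of parameters). It assumes the label is the observed success proportion and the weight is the number of trials, so the trial count is round(weight) and the success count is round(label · weight); `BinomialWeightedLogLik` is hypothetical and not claimed to match Spark's AIC code exactly.

```java
/** Hypothetical sketch: binomial log-likelihood of one row, using the weight as trial count. */
final class BinomialWeightedLogLik {
    /** log C(n, k) computed as a sum of logs to avoid overflowing factorials. */
    static double logChoose(int n, int k) {
        double s = 0.0;
        for (int i = 1; i <= k; i++) {
            s += Math.log((double) (n - k + i) / i);
        }
        return s;
    }

    /**
     * label: observed success proportion; weight: number of trials (rounded if non-integer);
     * mu: fitted success probability, assumed strictly between 0 and 1.
     */
    static double logLik(double label, double weight, double mu) {
        int trials = (int) Math.round(weight);               // non-integer weights are rounded
        if (trials == 0) {
            return 0.0;                                       // a zero-trial row contributes nothing
        }
        int successes = (int) Math.round(label * weight);
        return logChoose(trials, successes)
            + successes * Math.log(mu)
            + (trials - successes) * Math.log(1.0 - mu);
    }

    public static void main(String[] args) {
        // Weight 2.4 is rounded to 2 trials; label 0.5 gives round(1.2) = 1 success.
        System.out.println(logLik(0.5, 2.4, 0.5));
    }
}
```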
Parameters: value - (undocumented)
public GeneralizedLinearRegression setSolver(String value)
Sets the solver algorithm used for optimization.
Parameters: value - (undocumented)
public GeneralizedLinearRegression setLinkPredictionCol(String value)
Sets the link prediction (linear predictor) column name.
Parameters: value - (undocumented)
public GeneralizedLinearRegression copy(ParamMap extra)
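Creates a copy of this instance with the same UID and some extra params (see Params.defaultCopy()).
Specified by: copy in interface Params
Specified by: copy in class Predictor<Vector,GeneralizedLinearRegression,GeneralizedLinearRegressionModel>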
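Parameters: extra - (undocumented)
public Param<String> family()
Param for the name of family which is a description of the error distribution to be used in the model.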
public String getFamily()
public DoubleParam variancePower()
public double getVariancePower()
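public Param<String> link()
Param for the name of link function which provides the relationship between the linear predictor and the mean of the distribution function. Used only when family is not "tweedie"; the link for the "tweedie" family is specified through param linkPower.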
public String getLink()
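public DoubleParam linkPower()
Param for the index in the power link function. Used only when family is "tweedie". When not set, this value defaults to 1 - variancePower, which matches the R "statmod" package.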
public double getLinkPower()
public Param<String> linkPredictionCol()
public String getLinkPredictionCol()
public boolean hasLinkPredictionCol()
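The link prediction is the linear predictor η = intercept + x·β, while the ordinary prediction is the mean μ = g⁻¹(η). A minimal sketch with a log link, assuming (this page does not say so) that link prediction output is enabled when linkPredictionCol is set to a non-empty name; `LinkPredictionSketch` and its methods are hypothetical, not Spark's API.

```java
import java.util.Optional;

/** Hypothetical sketch of linear predictor vs. mean prediction for a log-link model. */
final class LinkPredictionSketch {
    /** Mirrors the idea behind hasLinkPredictionCol: output only for a set, non-empty name. */
    static boolean hasLinkPredictionCol(Optional<String> linkPredictionCol) {
        return linkPredictionCol.map(name -> !name.isEmpty()).orElse(false);
    }

    /** Linear predictor eta = intercept + x . beta. */
    static double linearPredictor(double intercept, double[] beta, double[] x) {
        double eta = intercept;
        for (int j = 0; j < beta.length; j++) {
            eta += beta[j] * x[j];
        }
        return eta;
    }

    public static void main(String[] args) {
        double eta = linearPredictor(0.5, new double[] {0.3}, new double[] {2.0});
        double prediction = Math.exp(eta);  // inverse of the log link gives the mean
        System.out.println("prediction = " + prediction);
        if (hasLinkPredictionCol(Optional.of("linkPrediction"))) {
            System.out.println("linkPrediction = " + eta);
        }
    }
}
```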
public StructType validateAndTransformSchema(StructType schema, boolean fitting, DataType featuresDataType)
Validates and transforms the input schema with the provided param map.
Parameters:
schema - input schema
fitting - whether this is called during fitting
featuresDataType - SQL DataType for FeaturesType. E.g., VectorUDT for vector features.
public Param<String> labelCol()
public String getLabelCol()
public Param<String> featuresCol()
public String getFeaturesCol()
public Param<String> predictionCol()
public String getPredictionCol()