- class pyspark.mllib.classification.LogisticRegressionModel(weights, intercept, numFeatures, numClasses)[source]#
Classification model trained using Multinomial/Binary Logistic Regression.
New in version 0.9.0.
- Parameters
- weights
Weights computed for every feature.
- intercept : float
Intercept computed for this model. (Only used in Binary Logistic Regression. In Multinomial Logistic Regression, the intercepts are not a single value, so they are stored as part of the weights.)
- numFeatures : int
The dimension of the features.
- numClasses : int
The number of possible outcomes for a k-class classification problem in Multinomial Logistic Regression. The default is 2, which corresponds to binary logistic regression.
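The note that multinomial intercepts are "part of the weights" can be made concrete with a small pure-Python sketch. It assumes a flattened layout of `numClasses - 1` rows, each holding `numFeatures` weights optionally followed by one intercept, with class 0 acting as a pivot whose margin is fixed at 0. That layout, the helper name `multinomial_predict`, and the weight values are illustrative assumptions, not the library's implementation.

```python
from typing import Sequence


def multinomial_predict(weights: Sequence[float], x: Sequence[float],
                        num_features: int, num_classes: int) -> int:
    """Return the class with the largest margin; class 0 is the pivot (margin 0)."""
    # Assumed layout: (num_classes - 1) rows of equal size, flattened row by row.
    row_size = len(weights) // (num_classes - 1)
    # A row one element longer than the feature count carries an intercept.
    has_intercept = row_size == num_features + 1
    best_class, max_margin = 0, 0.0
    for i in range(num_classes - 1):
        row = weights[i * row_size:(i + 1) * row_size]
        margin = sum(w * v for w, v in zip(row[:num_features], x))
        if has_intercept:
            margin += row[num_features]
        if margin > max_margin:
            best_class, max_margin = i + 1, margin
    return best_class


# Hypothetical 3-class, 2-feature model: one [w1, w2, intercept] row
# for class 1 and one for class 2.
weights = [2.0, -1.0, 0.5,
           -1.0, 3.0, -0.5]
print(multinomial_predict(weights, [1.0, 0.0], 2, 3))    # class 1 has the largest margin
print(multinomial_predict(weights, [0.0, 1.0], 2, 3))    # class 2 has the largest margin
print(multinomial_predict(weights, [-1.0, -1.0], 2, 3))  # no positive margin, pivot class 0
```

Under this assumed layout there is no separate intercept value to report for a multinomial model, which is consistent with the description of `intercept` as binary-only.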
Examples
>>> import numpy
>>> from pyspark.mllib.linalg import SparseVector
>>> from pyspark.mllib.regression import LabeledPoint
>>> from pyspark.mllib.classification import (
...     LogisticRegressionModel,
...     LogisticRegressionWithLBFGS,
...     LogisticRegressionWithSGD,
... )
>>> # sc is assumed to be an existing SparkContext.
>>> data = [
...     LabeledPoint(0.0, [0.0, 1.0]),
...     LabeledPoint(1.0, [1.0, 0.0]),
... ]
>>> lrm = LogisticRegressionWithSGD.train(sc.parallelize(data), iterations=10)
>>> lrm.predict([1.0, 0.0])
1
>>> lrm.predict([0.0, 1.0])
0
>>> lrm.predict(sc.parallelize([[1.0, 0.0], [0.0, 1.0]])).collect()
[1, 0]
>>> lrm.clearThreshold()
>>> lrm.predict([0.0, 1.0])
0.279...
>>> sparse_data = [
...     LabeledPoint(0.0, SparseVector(2, {0: 0.0})),
...     LabeledPoint(1.0, SparseVector(2, {1: 1.0})),
...     LabeledPoint(0.0, SparseVector(2, {0: 1.0})),
...     LabeledPoint(1.0, SparseVector(2, {1: 2.0}))
... ]
>>> lrm = LogisticRegressionWithSGD.train(sc.parallelize(sparse_data), iterations=10)
>>> lrm.predict(numpy.array([0.0, 1.0]))
1
>>> lrm.predict(numpy.array([1.0, 0.0]))
0
>>> lrm.predict(SparseVector(2, {1: 1.0}))
1
>>> lrm.predict(SparseVector(2, {0: 1.0}))
0
>>> import os, tempfile
>>> path = tempfile.mkdtemp()
>>> lrm.save(sc, path)
>>> sameModel = LogisticRegressionModel.load(sc, path)
>>> sameModel.predict(numpy.array([0.0, 1.0]))
1
>>> sameModel.predict(SparseVector(2, {0: 1.0}))
0
>>> from shutil import rmtree
>>> try:
...     rmtree(path)
... except BaseException:
...     pass
>>> multi_class_data = [
...     LabeledPoint(0.0, [0.0, 1.0, 0.0]),
...     LabeledPoint(1.0, [1.0, 0.0, 0.0]),
...     LabeledPoint(2.0, [0.0, 0.0, 1.0])
... ]
>>> data = sc.parallelize(multi_class_data)
>>> mcm = LogisticRegressionWithLBFGS.train(data, iterations=10, numClasses=3)
>>> mcm.predict([0.0, 0.5, 0.0])
0
>>> mcm.predict([0.8, 0.0, 0.0])
1
>>> mcm.predict([0.0, 0.0, 0.3])
2
Methods
- clearThreshold(): Clears the threshold so that predict will output raw prediction scores.
- load(sc, path): Load a model from the given path.
- predict(x): Predict values for a single data point or an RDD of points using the trained model.
- save(sc, path): Save this model to the given path.
- setThreshold(value): Sets the threshold that separates positive predictions from negative predictions.
Attributes
- intercept: Intercept computed for this model.
- numClasses: Number of possible outcomes for a k-class classification problem in Multinomial Logistic Regression.
- numFeatures: Dimension of the features.
- threshold: Returns the threshold (if any) used for converting raw prediction scores into 0/1 predictions.
- weights: Weights computed for every feature.
Methods Documentation
- clearThreshold()#
Clears the threshold so that predict will output raw prediction scores. It is used for binary classification only.
New in version 1.4.0.
- predict(x)[source]#
Predict values for a single data point or an RDD of points using the trained model.
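For the binary case, prediction on a single point can be sketched in plain Python as a dot product plus intercept, passed through a sigmoid, then optionally compared with a threshold. The function name `binary_predict`, the example weights, and the default threshold of 0.5 are assumptions made for illustration; the comparison uses "greater than or equal", following the wording of `setThreshold` on this page. This is a sketch of the idea, not the library's code.

```python
import math
from typing import Optional, Sequence, Union


def binary_predict(weights: Sequence[float], intercept: float,
                   x: Sequence[float],
                   threshold: Optional[float] = 0.5) -> Union[int, float]:
    """Return a 0/1 label, or the raw score when threshold is None."""
    margin = sum(w * v for w, v in zip(weights, x)) + intercept
    # Numerically stable sigmoid: never exponentiate a large positive number.
    if margin > 0:
        score = 1.0 / (1.0 + math.exp(-margin))
    else:
        exp_margin = math.exp(margin)
        score = exp_margin / (1.0 + exp_margin)
    if threshold is None:
        # Mirrors a cleared threshold: the caller gets the raw score.
        return score
    return 1 if score >= threshold else 0


# Hypothetical weights for a 2-feature binary model.
w, b = [1.5, -1.5], 0.0
print(binary_predict(w, b, [1.0, 0.0]))                  # label with the default threshold
print(binary_predict(w, b, [0.0, 1.0]))                  # label with the default threshold
print(binary_predict(w, b, [0.0, 1.0], threshold=None))  # raw score between 0 and 1
```

Passing `threshold=None` plays the role of calling `clearThreshold()` on a model: the return value changes from a label to a score.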
New in version 0.9.0.
- setThreshold(value)#
Sets the threshold that separates positive predictions from negative predictions. An example whose prediction score is greater than or equal to this threshold is classified as positive; otherwise it is classified as negative. It is used for binary classification only.
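The effect of moving the threshold can be seen on a fixed list of scores. The sketch below is plain Python with a hypothetical helper `apply_threshold` and made-up scores; it only illustrates the "greater than or equal" rule described above.

```python
from typing import List, Sequence


def apply_threshold(scores: Sequence[float], threshold: float) -> List[int]:
    """Label a score 1 when it is greater than or equal to the threshold, else 0."""
    return [1 if s >= threshold else 0 for s in scores]


# Made-up raw prediction scores for four examples.
scores = [0.2, 0.4, 0.5, 0.8]
for t in (0.3, 0.5, 0.7):
    print(t, apply_threshold(scores, t))
# 0.3 [0, 1, 1, 1]
# 0.5 [0, 0, 1, 1]
# 0.7 [0, 0, 0, 1]
```

A lower threshold labels more examples as positive, and a score exactly equal to the threshold counts as positive under this rule.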
New in version 1.4.0.
Attributes Documentation
- intercept#
Intercept computed for this model.
New in version 1.0.0.
- numClasses#
Number of possible outcomes for a k-class classification problem in Multinomial Logistic Regression.
New in version 1.4.0.
- numFeatures#
Dimension of the features.
New in version 1.4.0.
- threshold#
Returns the threshold (if any) used for converting raw prediction scores into 0/1 predictions. It is used for binary classification only.
New in version 1.4.0.
- weights#
Weights computed for every feature.
New in version 1.0.0.