Я видел похожие сообщения, но не получил полного ответа, поэтому размещал здесь.
Я использую TF-IDF в Spark, чтобы получить слово в документе с максимальным значением tf-idf. Я использую следующий фрагмент кода.
from pyspark.ml.feature import HashingTF, IDF, Tokenizer, CountVectorizer, StopWordsRemover
tokenizer = Tokenizer(inputCol="doc_cln", outputCol="tokens")
remover1 = StopWordsRemover(inputCol="tokens",
outputCol="stopWordsRemovedTokens")
stopwordList =["word1","word2","word3"]
remover2 = StopWordsRemover(inputCol="stopWordsRemovedTokens",
outputCol="filtered" ,stopWords=stopwordList)
hashingTF = HashingTF(inputCol="filtered", outputCol="rawFeatures", numFeatures=2000)
idf = IDF(inputCol="rawFeatures", outputCol="features", minDocFreq=5)
from pyspark.ml import Pipeline
pipeline = Pipeline(stages=[tokenizer, remover1, remover2, hashingTF, idf])
model = pipeline.fit(df)
results = model.transform(df)
results.cache()
Я получаю результаты как
|[a8g4i9g5y, hwcdn] |(2000,[905,1104],[7.34977707433047,7.076179741760428])
где
filtered: array (nullable = true)
features: vector (nullable = true)
Как я могу получить массив, извлеченный из "функции"? В идеале я хотел бы получить слово, соответствующее самому высокому tfidf, например, ниже
|a8g4i9g5y|7.34977707433047
Заранее спасибо!
Ваша feature
столбец имеет тип vector
из пакета pyspark.ml.linalg
. Это может быть либо
На основании (2000,[905,1104],[7.34977707433047,7.076179741760428])
вас данных (2000,[905,1104],[7.34977707433047,7.076179741760428])
, по-видимому, это SparseVector
, и его можно разбить на 3 основных компонента:
size
: 2000
indices
: [905,1104]
values
: [7.34977707433047,7.076179741760428]
И то, что вы ищете, это values
свойств этого вектора.
С другим "литералом" типа PySpark SQL, таким как StringType
или IntegerType
, вы можете получить доступ к его свойствам (и функциям агрегации) с помощью пакета функций SQL (docs). Однако vector
не является буквальным типом SQL и единственным способом доступа к его свойствам является UDF, например:
# Important: 'vector.values' returns ndarray from numpy.
# PySpark doesn't understand ndarray, therefore you'd want to
# convert it to normal Python list using 'tolist'
def extract_values_from_vector(vector):
return vector.values.tolist()
# Just a regular UDF
def extract_values_from_vector_udf(col):
return udf(extract_values_from_vector, ArrayType(DoubleType()))
# And use that UDF to get your values
results.select(extract_values_from_vector_udf('features'), 'features')
a8g4i9g5y
связано с признаком 905 и, следовательно, имеет значение tf-idf 7.34977707433047. Процесс хеширования не обязательно поддерживает порядок слов в этом конкретном предложении. Вы можете быть только уверены, чтоa8g4i9g5y
илиhwcdn
представлены столбцом 905, а другой представлен1104
.